import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
[docs]class TorchTabularTextDataset(TorchDataset):
"""
:obj:`TorchDataset` wrapper for text dataset with categorical features
and numerical features
Parameters:
encodings (:class:`transformers.BatchEncoding`):
The output from encode_plus() and batch_encode() methods (tokens, attention_masks, etc) of
a transformers.PreTrainedTokenizer
categorical_feats (:class:`numpy.ndarray`, of shape :obj:`(n_examples, categorical feat dim)`, `optional`, defaults to :obj:`None`):
An array containing the preprocessed categorical features
numerical_feats (:class:`numpy.ndarray`, of shape :obj:`(n_examples, numerical feat dim)`, `optional`, defaults to :obj:`None`):
An array containing the preprocessed numerical features
labels (:class: list` or `numpy.ndarray`, `optional`, defaults to :obj:`None`):
The labels of the training examples
class_weights (:class:`numpy.ndarray`, of shape (n_classes), `optional`, defaults to :obj:`None`):
Class weights used for cross entropy loss for classification
df (:class:`pandas.DataFrame`, `optional`, defaults to :obj:`None`):
Model configuration class with all the parameters of the model.
This object must also have a tabular_config member variable that is a
TabularConfig instance specifying the configs for TabularFeatCombiner
"""
def __init__(self,
encodings,
categorical_feats,
numerical_feats,
labels=None,
df=None,
label_list=None,
class_weights=None
):
self.df = df
self.encodings = encodings
self.cat_feats = categorical_feats
self.numerical_feats = numerical_feats
self.labels = labels
self.class_weights = class_weights
self.label_list = label_list if label_list is not None else [i for i in range(len(np.unique(labels)))]
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx])
for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx]) if self.labels is not None else None
item['cat_feats'] = torch.tensor(self.cat_feats[idx]).float() \
if self.cat_feats is not None else torch.zeros(0)
item['numerical_feats'] = torch.tensor(self.numerical_feats[idx]).float()\
if self.numerical_feats is not None else torch.zeros(0)
return item
def __len__(self):
return len(self.labels)
[docs] def get_labels(self):
"""returns the label names for classification"""
return self.label_list