Source code for mambular.data_utils.dataset

import numpy as np
import torch
from torch.utils.data import Dataset


class MambularDataset(Dataset):
    """Map-style dataset for tabular data split into feature groups.

    Categorical features, numerical features and optional embeddings are kept
    as separate lists of per-feature containers. The same class serves
    regression, binary classification and multi-class classification, as well
    as label-free prediction.

    Parameters
    ----------
    cat_features_list : list of Tensors
        One tensor per categorical feature, indexed by sample along dim 0.
    num_features_list : list of Tensors
        One tensor per numerical feature, indexed by sample along dim 0.
        May be empty when the data has no numerical features.
    embeddings_list : list of Tensors, optional
        One tensor per embedding input. ``None`` when no embeddings are used.
    labels : Tensor, optional
        Target values. If ``None`` the dataset is used for prediction and
        ``__getitem__`` returns features only. Array-likes accepted by
        ``torch.as_tensor`` are converted to tensors.
    regression : bool, optional
        Whether the task is regression. Defaults to True.

    Attributes
    ----------
    num_classes : int
        Only set when ``labels`` is given. ``1`` for regression and for
        classification with at most two distinct label values; otherwise the
        number of distinct label values.
    """

    def __init__(
        self,
        cat_features_list,
        num_features_list,
        embeddings_list=None,
        labels=None,
        regression=True,
    ):
        self.cat_features_list = cat_features_list  # Categorical feature tensors
        self.num_features_list = num_features_list  # Numerical feature tensors
        self.embeddings_list = embeddings_list  # Embedding tensors (optional)
        self.regression = regression

        if labels is None:
            self.labels = None  # No labels in prediction mode
            return

        # Normalise labels the same way features are normalised in
        # __getitem__; this is a no-op when a tensor is passed in.
        labels = torch.as_tensor(labels)

        if self.regression:
            self.labels = labels
            self.num_classes = 1
            return

        self.num_classes = len(np.unique(labels))
        if self.num_classes > 2:
            # Multi-class targets are stored flat, one class index per sample.
            self.labels = labels.view(-1)
        else:
            # Binary targets are treated as a single output.
            self.num_classes = 1
            self.labels = labels

    def __len__(self):
        """Return the number of samples.

        The length is taken from the first numerical feature. If there are no
        numerical features it falls back to the first categorical feature,
        then the first embedding, then the labels, so that datasets without
        numerical columns still report a length. Returns 0 if none of these
        is available.
        """
        for group in (
            self.num_features_list,
            self.cat_features_list,
            self.embeddings_list,
        ):
            # Explicit None/len checks: truth-testing would be ambiguous if a
            # group were passed as a tensor instead of a list.
            if group is not None and len(group) > 0:
                return len(group[0])
        if self.labels is not None:
            return len(self.labels)
        return 0

    def __getitem__(self, idx):
        """Retrieve the features and, if available, the label at ``idx``.

        Parameters
        ----------
        idx : int
            The index of the data point.

        Returns
        -------
        tuple
            ``((num_features, cat_features, embeddings), label)`` when labels
            are present, otherwise ``(num_features, cat_features, embeddings)``.
            ``num_features`` and ``embeddings`` are lists of float32 tensors
            (``embeddings`` is ``None`` when no embeddings were given);
            ``cat_features`` entries are returned as stored. The label is
            float32 for regression and binary classification, and long for
            multi-class classification.
        """
        cat_features = [feature[idx] for feature in self.cat_features_list]
        num_features = [
            torch.as_tensor(feature[idx]).clone().detach().to(torch.float32)
            for feature in self.num_features_list
        ]

        if self.embeddings_list is None:
            embeddings = None
        else:
            embeddings = [
                torch.as_tensor(embedding[idx]).clone().detach().to(torch.float32)
                for embedding in self.embeddings_list
            ]

        features = (num_features, cat_features, embeddings)
        if self.labels is None:
            return features

        # Only multi-class targets are integer class indices; regression and
        # binary targets are float.
        if self.regression or self.num_classes == 1:
            label_dtype = torch.float32
        else:
            label_dtype = torch.long
        label = torch.as_tensor(self.labels[idx]).clone().detach().to(label_dtype)
        return features, label