Source code for mambular.base_models.tabtransformer

import torch
import torch.nn as nn
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
from .utils.basemodel import BaseModel


class TabTransformer(BaseModel):
    """A PyTorch model for tabular tasks utilizing the Transformer architecture and various
    normalization techniques.

    Following the TabTransformer design, only the categorical (and precomputed embedding)
    features are contextualized by the transformer encoder; the numerical features are
    normalized and concatenated with the pooled encoder output before the MLP head.

    Parameters
    ----------
    feature_information : tuple
        Tuple of dictionaries (num_feature_info, cat_feature_info, embedding_feature_info)
        containing information about the numerical, categorical, and embedding features.
    num_classes : int, optional
        Number of output classes (default is 1).
    config : DefaultTabTransformerConfig, optional
        Configuration object containing default hyperparameters for the model
        (default is DefaultTabTransformerConfig()).
    **kwargs : dict
        Additional keyword arguments.

    Attributes
    ----------
    embedding_layer : EmbeddingLayer
        Embedding layer for the categorical and embedding features.
    norm_f : nn.Module
        Normalization layer, used both inside the encoder and on the numerical features.
    encoder : nn.TransformerEncoder
        Stack of N transformer encoder layers.
    tabular_head : MLPhead
        Multi-layer perceptron head for the final prediction.
    returns_ensemble : bool
        Whether the model returns an ensemble of predictions (always False here).
    """

    def __init__(
        self,
        feature_information: tuple,  # expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: DefaultTabTransformerConfig = DefaultTabTransformerConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        num_feature_info, cat_feature_info, emb_feature_info = feature_information
        if cat_feature_info == {}:
            raise ValueError(
                "You are trying to fit a TabTransformer with no categorical features. "
                "Try using a different model that is better suited for tasks without categorical features."
            )
        self.returns_ensemble = False

        # embedding layer for the categorical and embedding features only;
        # numerical features bypass the transformer entirely
        self.embedding_layer = EmbeddingLayer(
            *({}, cat_feature_info, emb_feature_info),
            config=config,
        )

        # transformer encoder
        self.norm_f = get_normalization_layer(config)
        encoder_layer = CustomTransformerEncoderLayer(config=config)
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.hparams.n_layers,
            norm=self.norm_f,
        )

        # the MLP head receives the pooled encoder output (d_model) concatenated
        # with all numerical features
        mlp_input_dim = self.hparams.d_model
        for info in num_feature_info.values():
            mlp_input_dim += info["dimension"]

        self.tabular_head = MLPhead(
            input_dim=mlp_input_dim,
            config=config,
            output_dim=num_classes,
        )

        # pooling
        n_inputs = [len(info) for info in feature_information]
        self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
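    # Illustrative arithmetic for the head input size (assumed numbers, not from
    # the source): with two numerical features of dimension 1 each and
    # d_model = 128, the head receives mlp_input_dim = 1 + 1 + 128 = 130 features.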
    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        data : tuple
            Input tuple of tensors of num_features, cat_features, embeddings.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        num_features, cat_features, emb_features = data

        # contextualize only the categorical / embedding features
        cat_embeddings = self.embedding_layer(*(None, cat_features, emb_features))

        # numerical features are normalized and bypass the transformer
        num_features = torch.cat(num_features, dim=1)
        num_embeddings = self.norm_f(num_features)  # type: ignore

        x = self.encoder(cat_embeddings)
        x = self.pool_sequence(x)

        # concatenate the pooled encoder output with the numerical features
        x = torch.cat((x, num_embeddings), dim=1)  # type: ignore
        preds = self.tabular_head(x)
        return preds
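Below is a minimal usage sketch, not part of the library source. It shows how the model might be instantiated and called; the feature-info schema (beyond the "dimension" key that __init__ reads) and the per-feature tensor layout consumed by EmbeddingLayer are assumptions for illustration, not guarantees of this file.

if __name__ == "__main__":
    # assumed schema: one dict entry per feature
    num_feature_info = {"age": {"dimension": 1}, "income": {"dimension": 1}}
    cat_feature_info = {"city": {"dimension": 1, "categories": 10}}  # assumed keys
    emb_feature_info = {}

    model = TabTransformer(
        feature_information=(num_feature_info, cat_feature_info, emb_feature_info),
        num_classes=1,
    )

    # assumed layout: a list with one tensor per feature, as suggested by the
    # torch.cat(num_features, dim=1) call in forward()
    batch_size = 8
    num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]
    cat_features = [torch.randint(0, 10, (batch_size, 1)) for _ in cat_feature_info]
    emb_features = []

    preds = model(num_features, cat_features, emb_features)
    print(preds.shape)  # expected: torch.Size([8, 1])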