Source code for deeptab.base_models.mambatab

import torch
import torch.nn as nn

from ..arch_utils.layer_utils.normalization_layers import LayerNorm
from ..arch_utils.mamba_utils.mamba_arch import Mamba
from ..arch_utils.mamba_utils.mamba_original import MambaOriginal
from ..arch_utils.mlp_utils import MLPhead
from ..configs.mambatab_config import DefaultMambaTabConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel


[docs]class MambaTab(BaseModel): """A MambaTab model for tabular data processing, integrating feature embeddings, normalization, and a configurable architecture for flexible deployment of Mamba-based feature transformation layers. Parameters ---------- cat_feature_info : dict Dictionary containing information about categorical features, including their names and dimensions. num_feature_info : dict Dictionary containing information about numerical features, including their names and dimensions. num_classes : int, optional The number of output classes or target dimensions for regression, by default 1. config : DefaultMambaTabConfig, optional Configuration object with model hyperparameters such as dropout rates, hidden layer sizes, Mamba version, and other architectural configurations, by default DefaultMambaTabConfig(). **kwargs : dict Additional keyword arguments for the BaseModel class. Attributes ---------- cat_feature_info : dict Stores categorical feature information. num_feature_info : dict Stores numerical feature information. initial_layer : nn.Linear Linear layer for the initial transformation of concatenated feature embeddings. norm_f : LayerNorm Layer normalization applied after the initial transformation. embedding_activation : callable Activation function applied to the embedded features. axis : int Axis used to adjust the shape of features during transformation. tabular_head : MLPhead MLPhead layer to produce the final prediction based on transformed features. mamba : Mamba or MambaOriginal Mamba-based feature transformation layer based on the version specified in config. Methods ------- forward(num_features, cat_features) Perform a forward pass through the model, including feature concatenation, initial transformation, Mamba processing, and prediction steps. """ def __init__( self, feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info) num_classes=1, config: DefaultMambaTabConfig = DefaultMambaTabConfig(), # noqa: B008 **kwargs, ): super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["feature_information"]) input_dim = get_feature_dimensions(*feature_information) self.returns_ensemble = False self.initial_layer = nn.Linear(input_dim, config.d_model) self.norm_f = LayerNorm(config.d_model) self.embedding_activation = self.hparams.embedding_activation self.axis = config.axis self.tabular_head = MLPhead( input_dim=self.hparams.d_model, config=config, output_dim=num_classes, ) if config.mamba_version == "mamba-torch": self.mamba = Mamba(config) else: self.mamba = MambaOriginal(config)
[docs] def forward(self, *data): """Forward pass of the Mambatab model Parameters ---------- data : tuple Input tuple of tensors of num_features, cat_features, embeddings. Returns ------- torch.Tensor Output tensor. """ x = torch.cat([t for tensors in data for t in tensors], dim=1) x = self.initial_layer(x) if self.axis == 1: x = x.unsqueeze(1) else: x = x.unsqueeze(0) x = self.norm_f(x) x = self.embedding_activation(x) if self.axis == 1: x = x.squeeze(1) else: x = x.squeeze(0) preds = self.tabular_head(x) return preds