# Source code for mambular.base_models.mambular
import torch
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mamba_utils.mamba_arch import Mamba
from ..arch_utils.mamba_utils.mamba_original import MambaOriginal
from ..arch_utils.mlp_utils import MLPhead
from ..configs.mambular_config import DefaultMambularConfig
from .utils.basemodel import BaseModel
import numpy as np
class Mambular(BaseModel):
    """A Mambular model for tabular data, integrating feature embeddings, Mamba transformations, and a configurable
    architecture for processing categorical and numerical features with pooling and normalization.

    Parameters
    ----------
    feature_information : tuple
        Tuple of (cat_feature_info, num_feature_info, embedding_feature_info) describing the
        categorical, numerical, and precomputed-embedding features (names and dimensions).
    num_classes : int, optional
        The number of output classes or target dimensions for regression, by default 1.
    config : DefaultMambularConfig, optional
        Configuration object with model hyperparameters such as dropout rates, head layer sizes, Mamba version, and
        other architectural configurations, by default DefaultMambularConfig().
    **kwargs : dict
        Additional keyword arguments for the BaseModel class.

    Attributes
    ----------
    embedding_layer : EmbeddingLayer
        Layer for embedding categorical, numerical, and precomputed-embedding features.
    mamba : Mamba or MambaOriginal
        Mamba-based transformation layer based on the version specified in config.
    tabular_head : MLPhead
        MLP head producing the final prediction from the pooled Mamba output.
    perm : torch.Tensor, optional
        Fixed permutation tensor used for shuffling embeddings; only set when
        ``config.shuffle_embeddings`` is enabled.

    Methods
    -------
    forward(*data)
        Perform a forward pass through the model, including embedding, Mamba transformation, pooling,
        and prediction steps.
    """

    def __init__(
        self,
        feature_information: tuple,  # (cat_feature_info, num_feature_info, embedding_feature_info)
        num_classes: int = 1,
        config: DefaultMambularConfig = DefaultMambularConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        # Single prediction per sample; not an ensemble of outputs.
        self.returns_ensemble = False

        # Embed categorical, numerical, and precomputed-embedding features into a token sequence.
        self.embedding_layer = EmbeddingLayer(
            *feature_information,
            config=config,
        )

        # Choose the Mamba implementation: pure-torch re-implementation vs. the original kernels.
        if config.mamba_version == "mamba-torch":
            self.mamba = Mamba(config)
        else:
            self.mamba = MambaOriginal(config)

        self.tabular_head = MLPhead(
            input_dim=self.hparams.d_model,
            config=config,
            output_dim=num_classes,
        )

        # Optional fixed random permutation applied to the token sequence on every forward pass.
        if self.hparams.shuffle_embeddings:
            self.perm = torch.randperm(self.embedding_layer.seq_len)

        # Total number of input features across all feature groups; used to set up pooling.
        # Builtin sum over a generator avoids materializing a list and a numpy call.
        n_inputs = sum(len(info) for info in feature_information)
        self.initialize_pooling_layers(config=config, n_inputs=n_inputs)

    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        data : tuple
            Input tuple of tensors of num_features, cat_features, embeddings.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        x = self.embedding_layer(*data)
        # Shuffle along the sequence dimension with the fixed permutation, if enabled.
        # NOTE(review): assumes x is (batch, seq, dim) — confirm against EmbeddingLayer.
        if self.hparams.shuffle_embeddings:
            x = x[:, self.perm, :]
        x = self.mamba(x)
        x = self.pool_sequence(x)
        preds = self.tabular_head(x)
        return preds