Source code for mambular.base_models.mambattn

import torch
import numpy as np
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mamba_utils.mambattn_arch import MambAttn
from ..arch_utils.mlp_utils import MLPhead
from ..configs.mambattention_config import DefaultMambAttentionConfig
from .utils.basemodel import BaseModel


class MambAttention(BaseModel):
    """A MambAttention model for tabular data, integrating feature embeddings, attention-based Mamba
    transformations, and a customizable architecture for handling categorical and numerical features.

    Parameters
    ----------
    feature_information : tuple
        Tuple of (num_feature_info, cat_feature_info, embedding_feature_info), each a dictionary
        describing the corresponding features' names and dimensions.
    num_classes : int, optional
        The number of output classes or target dimensions for regression, by default 1.
    config : DefaultMambAttentionConfig, optional
        Configuration object with model hyperparameters such as dropout rates, head layer sizes,
        attention settings, and other architectural configurations, by default DefaultMambAttentionConfig().
    **kwargs : dict
        Additional keyword arguments for the BaseModel class.

    Attributes
    ----------
    pooling_method : str
        Pooling method used to aggregate features after the Mamba attention layer.
    shuffle_embeddings : bool
        Flag indicating whether embeddings should be shuffled, as specified in the configuration.
    mamba : MambAttn
        Mamba attention layer that processes the embedded features.
    norm_f : nn.Module
        Normalization layer for the processed features.
    embedding_layer : EmbeddingLayer
        Layer for embedding categorical and numerical features.
    tabular_head : MLPhead
        MLPhead layer that produces the final prediction from the output of the Mamba attention layer.
    perm : torch.Tensor, optional
        Permutation tensor used for shuffling embeddings, if enabled.

    Methods
    -------
    forward(*data)
        Perform a forward pass through the model, including embedding, Mamba attention transformation,
        pooling, and prediction steps.
    """

    def __init__(
        self,
        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        self.returns_ensemble = False

        try:
            self.pooling_method = self.hparams.pooling_method
        except AttributeError:
            self.pooling_method = config.pooling_method
        try:
            self.shuffle_embeddings = self.hparams.shuffle_embeddings
        except AttributeError:
            self.shuffle_embeddings = config.shuffle_embeddings

        self.mamba = MambAttn(config)
        self.norm_f = get_normalization_layer(config)

        # embedding layer
        self.embedding_layer = EmbeddingLayer(
            *feature_information,
            config=config,
        )

        try:
            head_activation = self.hparams.head_activation
        except AttributeError:
            head_activation = config.head_activation

        try:
            input_dim = self.hparams.d_model
        except AttributeError:
            input_dim = config.d_model

        self.tabular_head = MLPhead(
            input_dim=input_dim,
            config=config,
            output_dim=num_classes,
        )

        if self.shuffle_embeddings:
            self.perm = torch.randperm(self.embedding_layer.seq_len)

        # pooling
        n_inputs = np.sum([len(info) for info in feature_information])
        self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        data : tuple
            Input tuple of tensors of num_features, cat_features, embeddings.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        x = self.embedding_layer(*data)

        if self.shuffle_embeddings:
            x = x[:, self.perm, :]

        x = self.mamba(x)

        x = self.pool_sequence(x)

        x = self.norm_f(x)  # type: ignore
        preds = self.tabular_head(x)

        return preds
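
For orientation, the following is a minimal, self-contained sketch of the tensor shapes moving through forward. It uses plain torch operations as stand-ins for the mambular building blocks: the embedding layer, the MambAttn block, pool_sequence, and MLPhead are replaced by placeholders, and the batch, sequence, and model dimensions are illustrative assumptions rather than values read from DefaultMambAttentionConfig.

import torch
import torch.nn as nn

# Illustrative dimensions only; real values come from the config and the encoded features.
batch_size, seq_len, d_model, num_classes = 32, 8, 64, 1

# Stand-in for the output of EmbeddingLayer: one embedding per feature token.
x = torch.randn(batch_size, seq_len, d_model)

# Optional shuffle of the token order, as done when shuffle_embeddings is enabled.
perm = torch.randperm(seq_len)
x = x[:, perm, :]                      # (batch, seq_len, d_model)

# The MambAttn block keeps the sequence shape; an identity stands in for it here.
x = x                                  # (batch, seq_len, d_model)

# Sequence pooling; a mean over the token dimension stands in for pool_sequence.
x = x.mean(dim=1)                      # (batch, d_model)

# Normalization and prediction head; LayerNorm and a single Linear stand in for
# get_normalization_layer(config) and MLPhead.
norm_f = nn.LayerNorm(d_model)
head = nn.Linear(d_model, num_classes)
preds = head(norm_f(x))                # (batch, num_classes)
print(preds.shape)                     # torch.Size([32, 1])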