Source code for mambular.base_models.saint

from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
from ..arch_utils.transformer_utils import RowColTransformer
from ..configs.saint_config import DefaultSAINTConfig
from .utils.basemodel import BaseModel
import numpy as np


[docs]class SAINT(BaseModel): """A Feature Transformer model for tabular data with categorical and numerical features, using embedding, transformer encoding, and pooling to produce final predictions. Parameters ---------- cat_feature_info : dict Dictionary containing information about categorical features, including their names and dimensions. num_feature_info : dict Dictionary containing information about numerical features, including their names and dimensions. num_classes : int, optional The number of output classes or target dimensions for regression, by default 1. config : DefaultSAINTConfig, optional Configuration object containing model hyperparameters such as dropout rates, hidden layer sizes, transformer settings, and other architectural configurations, by default DefaultSAINTConfig(). **kwargs : dict Additional keyword arguments for the BaseModel class. Attributes ---------- pooling_method : str The pooling method to aggregate features after transformer encoding. cat_feature_info : dict Stores categorical feature information. num_feature_info : dict Stores numerical feature information. embedding_layer : EmbeddingLayer Layer for embedding categorical and numerical features. norm_f : nn.Module Normalization layer for the transformer output. encoder : nn.TransformerEncoder Transformer encoder for sequential processing of embedded features. tabular_head : MLPhead MLPhead layer to produce the final prediction based on the output of the transformer encoder. Methods ------- forward(num_features, cat_features) Perform a forward pass through the model, including embedding, transformer encoding, pooling, and prediction steps. """ def __init__( self, feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info) num_classes=1, config: DefaultSAINTConfig = DefaultSAINTConfig(), # noqa: B008 **kwargs, ): super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["feature_information"]) self.returns_ensemble = False n_inputs = np.sum([len(info) for info in feature_information]) if getattr(config, "use_cls", True): n_inputs += 1 # embedding layer self.embedding_layer = EmbeddingLayer( *feature_information, config=config, ) # transformer encoder self.norm_f = get_normalization_layer(config) self.encoder = RowColTransformer( config=config, n_features=n_inputs, ) self.tabular_head = MLPhead( input_dim=self.hparams.d_model, config=config, output_dim=num_classes, ) # pooling self.initialize_pooling_layers(config=config, n_inputs=n_inputs)
[docs] def forward(self, *data): """Defines the forward pass of the model. Parameters ---------- data : tuple Input tuple of tensors of num_features, cat_features, embeddings. Returns ------- torch.Tensor Output tensor. """ x = self.embedding_layer(*data) x = self.encoder(x) x = self.pool_sequence(x) if self.norm_f is not None: x = self.norm_f(x) preds = self.tabular_head(x) return preds