mambular.base_models#

class mambular.base_models.Mambular(feature_information, num_classes=1, config=DefaultMambularConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=64, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SiLU, cat_encoding='int', n_layers=4, d_conv=4, dilation=1, expand_factor=2, bias=False, dropout=0.0, dt_rank='auto', d_state=128, dt_scale=1.0, dt_init='random', dt_max=0.1, dt_min=0.0001, dt_init_floor=0.0001, norm='RMSNorm', conv_bias=False, AD_weight_decay=True, BC_layer_norm=False, shuffle_embeddings=False, head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='avg', bidirectional=False, use_learnable_interaction=False, use_cls=False, use_pscan=False, mamba_version='mamba-torch'), **kwargs)[source]#

A Mambular model for tabular data, integrating feature embeddings, Mamba transformations, and a configurable architecture for processing categorical and numerical features with pooling and normalization.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultMambularConfig, optional) – Configuration object with model hyperparameters such as dropout rates, head layer sizes, Mamba version, and other architectural configurations, by default DefaultMambularConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

pooling_method#

Pooling method to aggregate features after the Mamba layer.

Type:

str

shuffle_embeddings#

Flag indicating if embeddings should be shuffled, as specified in the configuration.

Type:

bool

embedding_layer#

Layer for embedding categorical and numerical features.

Type:

EmbeddingLayer

mamba#

Mamba transformation layer, selected according to the mamba_version specified in the config.

Type:

Mamba or MambaOriginal

norm_f#

Normalization layer for the processed features.

Type:

nn.Module

tabular_head#

MLP layer to produce the final prediction based on the output of the Mamba layer.

Type:

MLP

perm#

Permutation tensor used for shuffling embeddings, if enabled.

Type:

torch.Tensor, optional

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding, Mamba transformation, pooling, and prediction steps.

forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

The output predictions of the model.

Return type:

Tensor
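
A minimal sketch of two of the documented options, pooling_method='avg' and shuffle_embeddings (with its perm tensor), applied to a generic batch of feature-token embeddings. Tensor names and shapes are illustrative assumptions, not the library's internals.

    import torch

    # hypothetical embedding output of shape (batch, n_tokens, d_model)
    batch, n_tokens, d_model = 8, 5, 64
    embeddings = torch.randn(batch, n_tokens, d_model)

    # shuffle_embeddings=True: apply a fixed random permutation along the token axis
    perm = torch.randperm(n_tokens)
    shuffled = embeddings[:, perm, :]

    # pooling_method='avg': aggregate the token axis before the tabular head
    pooled = shuffled.mean(dim=1)   # shape (batch, d_model)
    print(pooled.shape)             # torch.Size([8, 64])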

class mambular.base_models.MLP(feature_information, num_classes=1, config=DefaultMLPConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=32, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.ReLU, cat_encoding='int', layer_sizes=[256, 128, 32], skip_layers=False, dropout=0.2, use_glu=False, skip_connections=False), **kwargs)[source]#

A multi-layer perceptron (MLP) model for tabular data processing, with options for embedding, normalization, skip connections, and customizable activation functions.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultMLPConfig, optional) – Configuration object with model hyperparameters such as layer sizes, dropout rates, activation functions, embedding settings, and normalization options, by default DefaultMLPConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

layer_sizes#

List specifying the number of units in each layer of the MLP.

Type:

list of int

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

layers#

List containing the layers of the MLP, including linear layers, normalization layers, and activations.

Type:

nn.ModuleList

skip_connections#

Flag indicating whether skip connections are enabled between layers.

Type:

bool

use_glu#

Flag indicating if gated linear units (GLU) should be used as the activation function.

Type:

bool

activation#

Activation function applied between layers.

Type:

callable

use_embeddings#

Flag indicating if embeddings should be used for categorical and numerical features.

Type:

bool

embedding_layer#

Embedding layer for features, used if use_embeddings is enabled.

Type:

EmbeddingLayer, optional

norm_f#

Normalization layer applied to the output of the first layer, if specified in the configuration.

Type:

nn.Module, optional

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding (if enabled), linear transformations, activation, normalization, and prediction steps.

forward(*data)[source]#

Forward pass of the MLP model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor
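
The layer_sizes, skip_connections, and use_glu options described above can be pictured with a small stand-alone PyTorch module; this is a conceptual sketch under assumed shapes, not the library's MLP implementation.

    import torch
    import torch.nn as nn

    class TinyMLP(nn.Module):
        def __init__(self, d_in, layer_sizes=(256, 128, 32), d_out=1,
                     skip_connections=False, use_glu=False):
            super().__init__()
            self.skip_connections = skip_connections
            dims = [d_in, *layer_sizes]
            # each linear doubles its width when GLU is used, since GLU halves it again
            self.layers = nn.ModuleList(
                nn.Linear(dims[i], dims[i + 1] * (2 if use_glu else 1))
                for i in range(len(layer_sizes))
            )
            self.act = nn.GLU(dim=-1) if use_glu else nn.ReLU()
            self.out = nn.Linear(layer_sizes[-1], d_out)

        def forward(self, x):
            for layer in self.layers:
                h = self.act(layer(x))
                # skip connection only when input and output widths match
                x = x + h if self.skip_connections and h.shape == x.shape else h
            return self.out(x)

    print(TinyMLP(d_in=10)(torch.randn(4, 10)).shape)  # torch.Size([4, 1])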

class mambular.base_models.ResNet(feature_information, num_classes=1, config=DefaultResNetConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=32, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SELU, cat_encoding='int', layer_sizes=[256, 128, 32], skip_layers=False, dropout=0.5, norm=False, use_glu=False, skip_connections=True, num_blocks=3, average_embeddings=True), **kwargs)[source]#

A ResNet model for tabular data, combining feature embeddings, residual blocks, and customizable architecture for processing categorical and numerical features.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultResNetConfig, optional) – Configuration object containing model hyperparameters such as layer sizes, number of residual blocks, dropout rates, activation functions, and normalization settings, by default DefaultResNetConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

layer_sizes#

List specifying the number of units in each layer of the ResNet.

Type:

list of int

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

activation#

Activation function used in the residual blocks.

Type:

callable

use_embeddings#

Flag indicating if embeddings should be used for categorical and numerical features.

Type:

bool

embedding_layer#

Embedding layer for features, used if use_embeddings is enabled.

Type:

EmbeddingLayer, optional

initial_layer#

Initial linear layer to project input features into the model’s hidden dimension.

Type:

nn.Linear

blocks#

List of residual blocks to process the hidden representations.

Type:

nn.ModuleList

output_layer#

Output layer that produces the final prediction.

Type:

nn.Linear

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding (if enabled), residual blocks, and prediction steps.

forward(*data)[source]#

Forward pass of the ResNet model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor
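
The initial_layer / blocks / output_layer pipeline can be illustrated with a plain residual block over hidden vectors; the block below is an assumption-level sketch, not the library's residual block.

    import torch
    import torch.nn as nn

    class ResidualBlock(nn.Module):
        def __init__(self, d_hidden, dropout=0.5):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(d_hidden, d_hidden), nn.SELU(), nn.Dropout(dropout),
                nn.Linear(d_hidden, d_hidden),
            )

        def forward(self, x):
            return x + self.net(x)   # identity skip connection around the block

    d_hidden = 32
    x = torch.randn(16, 10)
    h = nn.Linear(10, d_hidden)(x)       # initial_layer: project into the hidden dim
    for _ in range(3):                   # num_blocks=3
        h = ResidualBlock(d_hidden)(h)   # blocks
    y = nn.Linear(d_hidden, 1)(h)        # output_layer
    print(y.shape)                       # torch.Size([16, 1])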

class mambular.base_models.FTTransformer(feature_information, num_classes=1, config=DefaultFTTransformerConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=128, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SELU, cat_encoding='int', n_layers=4, n_heads=8, attn_dropout=0.2, ff_dropout=0.1, norm='LayerNorm', transformer_activation=torch.nn.Module, transformer_dim_feedforward=256, norm_first=False, bias=True, head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='avg', use_cls=False), **kwargs)[source]#

A Feature Transformer model for tabular data with categorical and numerical features, using embedding, transformer encoding, and pooling to produce final predictions.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultFTTransformerConfig, optional) – Configuration object containing model hyperparameters such as dropout rates, hidden layer sizes, transformer settings, and other architectural configurations, by default DefaultFTTransformerConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

pooling_method#

The pooling method to aggregate features after transformer encoding.

Type:

str

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

embedding_layer#

Layer for embedding categorical and numerical features.

Type:

EmbeddingLayer

norm_f#

Normalization layer for the transformer output.

Type:

nn.Module

encoder#

Transformer encoder for sequential processing of embedded features.

Type:

nn.TransformerEncoder

tabular_head#

MLPhead layer to produce the final prediction based on the output of the transformer encoder.

Type:

MLPhead

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding, transformer encoding, pooling, and prediction steps.

forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

The output predictions of the model.

Return type:

Tensor
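
The embedding -> encoder -> pooling -> head flow maps naturally onto standard PyTorch modules; the sketch below uses nn.TransformerEncoder with assumed shapes and hyperparameters and is not the library's exact implementation.

    import torch
    import torch.nn as nn

    batch, n_tokens, d_model = 8, 6, 128
    tokens = torch.randn(batch, n_tokens, d_model)   # stand-in for embedding_layer output

    encoder_layer = nn.TransformerEncoderLayer(
        d_model=d_model, nhead=8, dim_feedforward=256,
        dropout=0.1, batch_first=True,
    )
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)  # encoder
    norm_f = nn.LayerNorm(d_model)                                # norm_f
    head = nn.Linear(d_model, 1)                                  # simplified tabular_head

    h = norm_f(encoder(tokens))
    pooled = h.mean(dim=1)        # pooling_method='avg'
    print(head(pooled).shape)     # torch.Size([8, 1])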

class mambular.base_models.TabTransformer(feature_information, num_classes=1, config=DefaultTabTransformerConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=128, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SELU, cat_encoding='int', n_layers=4, n_heads=8, attn_dropout=0.2, ff_dropout=0.1, norm='LayerNorm', transformer_activation=torch.nn.Module, transformer_dim_feedforward=512, norm_first=True, bias=True, head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='avg'), **kwargs)[source]#

A TabTransformer model for tabular tasks, utilizing the Transformer architecture and various normalization techniques.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features.

  • num_feature_info (dict) – Dictionary containing information about numerical features.

  • num_classes (int, optional) – Number of output classes (default is 1).

  • config (DefaultTabTransformerConfig, optional) – Configuration object containing default hyperparameters for the model, by default DefaultTabTransformerConfig().

  • **kwargs (dict) – Additional keyword arguments.

lr#

Learning rate.

Type:

float

lr_patience#

Patience for learning rate scheduler.

Type:

int

weight_decay#

Weight decay for optimizer.

Type:

float

lr_factor#

Factor by which the learning rate will be reduced.

Type:

float

pooling_method#

Method to pool the features.

Type:

str

cat_feature_info#

Dictionary containing information about categorical features.

Type:

dict

num_feature_info#

Dictionary containing information about numerical features.

Type:

dict

embedding_activation#

Activation function for embeddings.

Type:

callable

encoder#

Stack of N transformer encoder layers.

Type:

callable

norm_f#

Normalization layer.

Type:

nn.Module

num_embeddings#

Module list for numerical feature embeddings.

Type:

nn.ModuleList

cat_embeddings#

Module list for categorical feature embeddings.

Type:

nn.ModuleList

tabular_head#

Multi-layer perceptron head for tabular data.

Type:

MLPhead

cls_token#

Class token parameter.

Type:

nn.Parameter

embedding_norm#

Layer normalization applied after embedding if specified.

Type:

nn.Module, optional

forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

The output predictions of the model.

Return type:

Tensor
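
The cls_token attribute can be illustrated as a learnable token prepended to the embedded feature sequence so that its encoded state summarizes the row; shapes and the encoder configuration below are assumptions for illustration only.

    import torch
    import torch.nn as nn

    batch, n_tokens, d_model = 8, 6, 128
    tokens = torch.randn(batch, n_tokens, d_model)

    cls_token = nn.Parameter(torch.zeros(1, 1, d_model))            # cls_token
    tokens = torch.cat([cls_token.expand(batch, -1, -1), tokens], dim=1)

    layer = nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=512,
                                       batch_first=True, norm_first=True)
    encoder = nn.TransformerEncoder(layer, num_layers=4)
    encoded = encoder(tokens)

    cls_state = encoded[:, 0]     # read the class-token position for the head
    print(cls_state.shape)        # torch.Size([8, 128])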

class mambular.base_models.TabulaRNN(feature_information, num_classes=1, config=DefaultTabulaRNNConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=128, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SELU, cat_encoding='int', model_type='RNN', n_layers=4, rnn_dropout=0.2, norm='RMSNorm', residuals=False, head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='avg', norm_first=False, bias=True, rnn_activation='relu', dim_feedforward=256, d_conv=4, dilation=1, conv_bias=True), **kwargs)[source]#
forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

The output predictions of the model.

Return type:

Tensor
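
A rough picture of the configured recurrence (model_type='RNN', rnn_activation='relu', n_layers, pooling_method='avg') over embedded feature tokens; this is an illustrative sketch with assumed shapes, not the library's TabulaRNN internals.

    import torch
    import torch.nn as nn

    batch, n_tokens, d_model = 8, 6, 128
    tokens = torch.randn(batch, n_tokens, d_model)   # stand-in for embedded features

    rnn = nn.RNN(input_size=d_model, hidden_size=d_model, num_layers=4,
                 nonlinearity='relu', dropout=0.2, batch_first=True)
    out, _ = rnn(tokens)          # (batch, n_tokens, d_model)
    pooled = out.mean(dim=1)      # pooling_method='avg'
    print(pooled.shape)           # torch.Size([8, 128])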

class mambular.base_models.MambAttention(feature_information, num_classes=1, config=DefaultMambAttentionConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=64, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.SiLU, cat_encoding='int', n_layers=4, expand_factor=2, n_heads=8, last_layer='attn', n_mamba_per_attention=1, bias=False, d_conv=4, conv_bias=True, dropout=0.0, attn_dropout=0.2, dt_rank='auto', d_state=128, dt_scale=1.0, dt_init='random', dt_max=0.1, dt_min=0.0001, dt_init_floor=0.0001, norm='LayerNorm', head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='avg', bidirectional=False, use_learnable_interaction=False, use_cls=False, shuffle_embeddings=False, AD_weight_decay=True, BC_layer_norm=False, use_pscan=False, n_attention_layers=1), **kwargs)[source]#

A MambAttention model for tabular data, integrating feature embeddings, attention-based Mamba transformations, and a customizable architecture for handling categorical and numerical features.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultMambAttentionConfig, optional) – Configuration object with model hyperparameters such as dropout rates, head layer sizes, attention settings, and other architectural configurations, by default DefaultMambAttentionConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

pooling_method#

Pooling method to aggregate features after the Mamba attention layer.

Type:

str

shuffle_embeddings#

Flag indicating if embeddings should be shuffled, as specified in the configuration.

Type:

bool

mamba#

Mamba attention layer to process embedded features.

Type:

MambAttn

norm_f#

Normalization layer for the processed features.

Type:

nn.Module

embedding_layer#

Layer for embedding categorical and numerical features.

Type:

EmbeddingLayer

tabular_head#

MLPhead layer to produce the final prediction based on the output of the Mamba attention layer.

Type:

MLPhead

perm#

Permutation tensor used for shuffling embeddings, if enabled.

Type:

torch.Tensor, optional

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding, Mamba attention transformation, pooling, and prediction steps.

forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor
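
The n_mamba_per_attention and last_layer options suggest an interleaving of sequence blocks and attention blocks. The sketch below shows only that interleaving pattern; SequenceBlock is a hypothetical stand-in for a Mamba block, not the real layer.

    import torch
    import torch.nn as nn

    class SequenceBlock(nn.Module):        # placeholder stand-in for a Mamba block
        def __init__(self, d_model):
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(d_model, d_model), nn.SiLU())

        def forward(self, x):
            return x + self.ff(x)

    class AttnBlock(nn.Module):
        def __init__(self, d_model, n_heads=8):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        def forward(self, x):
            out, _ = self.attn(x, x, x)
            return x + out

    d_model = 64
    tokens = torch.randn(8, 6, d_model)
    # n_mamba_per_attention=1, last_layer='attn': alternate blocks, end on attention
    blocks = nn.Sequential(SequenceBlock(d_model), AttnBlock(d_model),
                           SequenceBlock(d_model), AttnBlock(d_model))
    pooled = blocks(tokens).mean(dim=1)    # pooling_method='avg'
    print(pooled.shape)                    # torch.Size([8, 64])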

class mambular.base_models.TabM(feature_information, num_classes=1, config=DefaultTabMConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=32, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.ReLU, cat_encoding='int', layer_sizes=[256, 256, 128], dropout=0.5, norm=None, use_glu=False, ensemble_size=32, ensemble_scaling_in=True, ensemble_scaling_out=True, ensemble_bias=True, scaling_init='ones', average_ensembles=False, model_type='mini', average_embeddings=True), **kwargs)[source]#
forward(*data)[source]#

Forward pass of the TabM model with batch ensembling.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor
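
The ensemble_size, ensemble_scaling_in/out, ensemble_bias, and average_ensembles options point at a BatchEnsemble-style layer: one shared weight matrix with cheap per-member input/output scalings. The layer below is an illustrative sketch of that idea, not the library's code.

    import torch
    import torch.nn as nn

    class BatchEnsembleLinear(nn.Module):
        def __init__(self, d_in, d_out, ensemble_size=32):
            super().__init__()
            self.shared = nn.Linear(d_in, d_out, bias=False)          # shared weights
            self.r = nn.Parameter(torch.ones(ensemble_size, d_in))    # per-member input scaling
            self.s = nn.Parameter(torch.ones(ensemble_size, d_out))   # per-member output scaling
            self.bias = nn.Parameter(torch.zeros(ensemble_size, d_out))

        def forward(self, x):                          # x: (batch, d_in)
            h = self.shared(x.unsqueeze(1) * self.r)   # broadcast to (batch, k, d_out)
            return h * self.s + self.bias

    layer = BatchEnsembleLinear(10, 1, ensemble_size=32)
    preds = layer(torch.randn(8, 10))       # (8, 32, 1): one prediction per member
    print(preds.mean(dim=1).shape)          # average over ensembles: torch.Size([8, 1])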

class mambular.base_models.NODE(feature_information, num_classes=1, config=DefaultNODEConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=32, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.ReLU, cat_encoding='int', num_layers=4, layer_dim=128, tree_dim=1, depth=6, norm=None, head_layer_sizes=[], head_dropout=0.3, head_skip_layers=False, head_activation=torch.nn.ReLU, head_use_batch_norm=False), **kwargs)[source]#

A Neural Oblivious Decision Ensemble (NODE) model for tabular data, integrating feature embeddings, dense blocks, and customizable heads for predictions.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultNODEConfig, optional) – Configuration object containing model hyperparameters such as the number of dense layers, layer dimensions, tree depth, embedding settings, and head layer configurations, by default DefaultNODEConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

use_embeddings#

Flag indicating if embeddings should be used for categorical and numerical features.

Type:

bool

embedding_layer#

Embedding layer for features, used if use_embeddings is enabled.

Type:

EmbeddingLayer, optional

d_out#

The output dimension, usually set to num_classes.

Type:

int

block#

Dense block layer for feature transformations based on the NODE approach.

Type:

DenseBlock

tabular_head#

MLPhead layer to produce the final prediction based on the output of the dense block.

Type:

MLPhead

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding (if enabled), dense transformations, and prediction steps.

forward(*data)[source]#

Forward pass through the NODE model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Model output of shape [batch_size, num_classes].

Return type:

torch.Tensor
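
The DenseBlock is built from soft oblivious decision trees: each tree of depth d shares one soft feature choice and threshold per level and routes every sample to all 2**d leaves with soft probabilities. The function below is a simplified, single-tree sketch (softmax in place of NODE's entmax), not the library's implementation.

    import torch

    def soft_oblivious_tree(x, feature_logits, thresholds, leaf_values, temperature=0.1):
        # x: (batch, n_features); feature_logits: (depth, n_features); thresholds: (depth,)
        chosen = torch.softmax(feature_logits, dim=-1)         # soft feature selection per level
        f = x @ chosen.T                                       # (batch, depth) selected values
        right = torch.sigmoid((f - thresholds) / temperature)  # probability of the right branch
        probs = torch.stack([1 - right, right], dim=-1)        # (batch, depth, 2)
        leaf_probs = probs[:, 0]
        for level in range(1, probs.shape[1]):                 # product over levels -> 2**depth leaves
            leaf_probs = (leaf_probs.unsqueeze(-1) * probs[:, level].unsqueeze(1)).flatten(1)
        return leaf_probs @ leaf_values                        # (batch, tree_dim)

    depth, n_features = 3, 10
    out = soft_oblivious_tree(torch.randn(8, n_features),
                              torch.randn(depth, n_features),
                              torch.zeros(depth),
                              torch.randn(2 ** depth, 1))
    print(out.shape)   # torch.Size([8, 1])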

class mambular.base_models.NDTF(feature_information, num_classes=1, config=DefaultNDTFConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=32, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.ReLU, cat_encoding='int', min_depth=4, max_depth=16, temperature=0.1, node_sampling=0.3, lamda=0.3, n_ensembles=12, penalty_factor=1e-08), **kwargs)[source]#

A Neural Decision Tree Forest (NDTF) model for tabular data, composed of an ensemble of neural decision trees with convolutional feature interactions, producing predictions along with a penalty term for regularization.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultNDTFConfig, optional) – Configuration object containing model hyperparameters such as the number of ensembles, tree depth, penalty factor, sampling settings, and temperature, by default DefaultNDTFConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

penalty_factor#

Scaling factor for the penalty applied during training, as specified in self.hparams.

Type:

float

input_dimensions#

List of input dimensions for each tree in the ensemble, with random sampling.

Type:

list of int

trees#

List of neural decision trees used in the ensemble.

Type:

nn.ModuleList

conv_layer#

Convolutional layer for feature interactions before passing inputs to trees.

Type:

nn.Conv1d

tree_weights#

Learnable parameter to weight each tree’s output in the ensemble.

Type:

nn.Parameter

forward(num_features, cat_features) → torch.Tensor[source]#

Perform a forward pass through the model, producing predictions based on an ensemble of neural decision trees.

penalty_forward(num_features, cat_features) → tuple of torch.Tensor[source]#

Perform a forward pass with penalty regularization, returning predictions and the calculated penalty term.

forward(*data)[source]#

Forward pass of the NDTF model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor

penalty_forward(*data)[source]#

Forward pass of the NDTF model with penalty-based regularization.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Tuple containing the model predictions and the computed penalty term.

Return type:

tuple of torch.Tensor
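
The split between forward and penalty_forward amounts to returning a regularization term alongside the prediction so the training loop can add it to the loss. The sketch below shows that interface with a weighted tree ensemble and a hypothetical penalty; the library's actual penalty computation is not reproduced here.

    import torch

    def penalty_forward(tree_outputs, tree_weights, penalty_factor=1e-08):
        # tree_outputs: (n_ensembles, batch, 1); tree_weights: (n_ensembles,)
        w = torch.softmax(tree_weights, dim=0)
        preds = (w.view(-1, 1, 1) * tree_outputs).sum(dim=0)   # weighted ensemble prediction
        penalty = penalty_factor * tree_weights.pow(2).sum()   # illustrative penalty term only
        return preds, penalty

    outputs = torch.randn(12, 8, 1)          # n_ensembles=12 trees, batch of 8
    preds, penalty = penalty_forward(outputs, torch.randn(12))
    print(preds.shape, penalty.item())       # torch.Size([8, 1]) and a small scalar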

class mambular.base_models.SAINT(feature_information, num_classes=1, config=DefaultSAINTConfig(lr=0.0001, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, use_embeddings=False, embedding_activation=torch.nn.Identity, embedding_type='linear', embedding_bias=False, layer_norm_after_embedding=False, d_model=128, plr_lite=False, n_frequencies=48, frequencies_init_scale=0.01, embedding_projection=True, batch_norm=False, layer_norm=False, layer_norm_eps=1e-05, activation=torch.nn.GELU, cat_encoding='int', n_layers=1, n_heads=2, attn_dropout=0.2, ff_dropout=0.1, norm='LayerNorm', norm_first=False, bias=True, head_layer_sizes=[], head_dropout=0.5, head_skip_layers=False, head_activation=torch.nn.SELU, head_use_batch_norm=False, pooling_method='cls', use_cls=True), **kwargs)[source]#

A SAINT model for tabular data with categorical and numerical features, using embedding, transformer encoding, and pooling to produce final predictions.

Parameters:
  • cat_feature_info (dict) – Dictionary containing information about categorical features, including their names and dimensions.

  • num_feature_info (dict) – Dictionary containing information about numerical features, including their names and dimensions.

  • num_classes (int, optional) – The number of output classes or target dimensions for regression, by default 1.

  • config (DefaultSAINTConfig, optional) – Configuration object containing model hyperparameters such as dropout rates, hidden layer sizes, transformer settings, and other architectural configurations, by default DefaultSAINTConfig().

  • **kwargs (dict) – Additional keyword arguments for the BaseModel class.

pooling_method#

The pooling method to aggregate features after transformer encoding.

Type:

str

cat_feature_info#

Stores categorical feature information.

Type:

dict

num_feature_info#

Stores numerical feature information.

Type:

dict

embedding_layer#

Layer for embedding categorical and numerical features.

Type:

EmbeddingLayer

norm_f#

Normalization layer for the transformer output.

Type:

nn.Module

encoder#

Transformer encoder for sequential processing of embedded features.

Type:

nn.TransformerEncoder

tabular_head#

MLPhead layer to produce the final prediction based on the output of the transformer encoder.

Type:

MLPhead

forward(num_features, cat_features)[source]#

Perform a forward pass through the model, including embedding, transformer encoding, pooling, and prediction steps.

forward(*data)[source]#

Defines the forward pass of the model.

Parameters:

data (tuple) – Input tuple of tensors of num_features, cat_features, embeddings.

Returns:

Output tensor.

Return type:

torch.Tensor
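
SAINT's distinguishing idea is intersample attention: standard self-attention mixes feature tokens within a row, while intersample attention flattens each row's tokens and attends across the rows of a batch. The sketch below illustrates that pattern with assumed shapes and plain nn.MultiheadAttention; it is not the library's encoder.

    import torch
    import torch.nn as nn

    batch, n_tokens, d_model = 8, 6, 128
    tokens = torch.randn(batch, n_tokens, d_model)

    # self-attention over feature tokens, independently per row
    self_attn = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
    tokens, _ = self_attn(tokens, tokens, tokens)

    # intersample attention: treat each row as one long token, attend across the batch
    rows = tokens.reshape(1, batch, n_tokens * d_model)
    inter_attn = nn.MultiheadAttention(n_tokens * d_model, num_heads=2, batch_first=True)
    rows, _ = inter_attn(rows, rows, rows)
    tokens = rows.reshape(batch, n_tokens, d_model)

    row_repr = tokens[:, 0]    # with use_cls=True, 'cls' pooling would read this position
    print(row_repr.shape)      # torch.Size([8, 128])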