Source code for mambular.configs.mambattention_config
from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from .base_config import BaseConfig
@dataclass
class DefaultMambAttentionConfig(BaseConfig):
"""Configuration class for the Default Mambular Attention model with predefined hyperparameters.
Parameters
----------
d_model : int, default=64
Dimensionality of the model.
n_layers : int, default=4
Number of layers in the model.
expand_factor : int, default=2
Expansion factor for the inner dimension of the Mamba blocks.
n_heads : int, default=8
Number of attention heads in the model.
last_layer : str, default="attn"
Type of the last layer (e.g., 'attn').
n_mamba_per_attention : int, default=1
Number of Mamba blocks per attention layer.
bias : bool, default=False
Whether to use bias in the linear layers.
d_conv : int, default=4
Kernel size of the 1D convolution in the Mamba blocks.
conv_bias : bool, default=True
Whether to use bias in the convolutional layers.
dropout : float, default=0.0
Dropout rate for regularization.
attn_dropout : float, default=0.2
Dropout rate for the attention mechanism.
dt_rank : str, default="auto"
Rank of the decision tree.
d_state : int, default=128
Dimensionality of the state in recurrent layers.
dt_scale : float, default=1.0
Scaling factor for the decision tree.
dt_init : str, default="random"
Initialization method for the decision tree.
dt_max : float, default=0.1
Maximum value for decision tree initialization.
dt_min : float, default=1e-04
Minimum value for decision tree initialization.
dt_init_floor : float, default=1e-04
Floor value for decision tree initialization.
norm : str, default="LayerNorm"
Type of normalization used in the model.
activation : callable, default=nn.SiLU()
Activation function for the model.
head_layer_sizes : list, default=[]
Sizes of the fully connected layers in the model's head.
head_dropout : float, default=0.5
Dropout rate for the head layers.
head_skip_layers : bool, default=False
Whether to use skip connections in the head layers.
head_activation : callable, default=nn.SELU()
Activation function for the head layers.
head_use_batch_norm : bool, default=False
Whether to use batch normalization in the head layers.
pooling_method : str, default="avg"
Pooling method to be used ('avg', 'max', etc.).
bidirectional : bool, default=False
Whether to process input sequences bidirectionally.
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through Mamba blocks.
use_cls : bool, default=False
Whether to append a CLS token for sequence pooling.
shuffle_embeddings : bool, default=False
Whether to shuffle embeddings before passing to Mamba layers.
cat_encoding : str, default="int"
Encoding method for categorical features ('int', 'one-hot', etc.).
AD_weight_decay : bool, default=True
Whether weight decay is applied to the A and D matrices of the SSM.
BC_layer_norm : bool, default=False
Whether to apply layer normalization to the B and C matrices of the SSM.
use_pscan : bool, default=False
Whether to use the parallel scan (pscan) implementation for the state-space model.
n_attention_layers : int, default=1
Number of attention layers in the model.
"""
# Architecture Parameters
d_model: int = 64
n_layers: int = 4
expand_factor: int = 2
n_heads: int = 8
last_layer: str = "attn"
n_mamba_per_attention: int = 1
bias: bool = False
d_conv: int = 4
conv_bias: bool = True
dropout: float = 0.0
attn_dropout: float = 0.2
dt_rank: str = "auto"
d_state: int = 128
dt_scale: float = 1.0
dt_init: str = "random"
dt_max: float = 0.1
dt_min: float = 1e-04
dt_init_floor: float = 1e-04
norm: str = "LayerNorm"
activation: Callable = nn.SiLU() # noqa: RUF009
# Head Parameters
head_layer_sizes: list = field(default_factory=list)
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: Callable = nn.SELU() # noqa: RUF009
head_use_batch_norm: bool = False
# Pooling and Categorical Encoding
pooling_method: str = "avg"
bidirectional: bool = False
use_learnable_interaction: bool = False
use_cls: bool = False
shuffle_embeddings: bool = False
cat_encoding: str = "int"
# Additional Features
AD_weight_decay: bool = True
BC_layer_norm: bool = False
use_pscan: bool = False
n_attention_layers: int = 1
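# Usage sketch (editor-added; a minimal example, not part of the module above).
# It relies only on what this file defines: the config is a standard dataclass,
# so its defaults can be taken as-is, overridden via keyword arguments, or
# copied with dataclasses.replace(). This assumes the fields inherited from
# BaseConfig also carry defaults; the model class that would consume the config
# is not shown here and lives elsewhere in the package.
if __name__ == "__main__":
    from dataclasses import replace

    # Default configuration with the predefined hyperparameters.
    default_cfg = DefaultMambAttentionConfig()

    # Override selected hyperparameters at construction time.
    custom_cfg = DefaultMambAttentionConfig(
        d_model=128,
        n_layers=6,
        attn_dropout=0.1,
        head_layer_sizes=[256, 64],
    )

    # Derive a bidirectional variant without mutating the existing config.
    bidir_cfg = replace(custom_cfg, bidirectional=True)

    print(default_cfg.d_model, custom_cfg.n_layers, bidir_cfg.bidirectional)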