Source code for mambular.configs.tabularnn_config

from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from .base_config import BaseConfig


@dataclass
class DefaultTabulaRNNConfig(BaseConfig):
    """Configuration class for the TabulaRNN model with predefined hyperparameters.

    Parameters
    ----------
    model_type : str, default="RNN"
        Type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM".
    n_layers : int, default=4
        Number of layers in the RNN.
    rnn_dropout : float, default=0.2
        Dropout rate for the RNN layers.
    d_model : int, default=128
        Dimensionality of embeddings or model representations.
    norm : str, default="RMSNorm"
        Normalization method to be used.
    activation : callable, default=nn.SELU()
        Activation function for the RNN layers.
    residuals : bool, default=False
        Whether to include residual connections in the RNN.
    head_layer_sizes : list, default=[]
        Sizes of the layers in the head of the model.
    head_dropout : float, default=0.5
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to skip layers in the head.
    head_activation : callable, default=nn.SELU()
        Activation function for the head layers.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.
    pooling_method : str, default="avg"
        Pooling method to be used ('avg', 'cls', etc.).
    norm_first : bool, default=False
        Whether to apply normalization before other operations in each block.
    layer_norm_eps : float, default=1e-05
        Epsilon value for layer normalization.
    bias : bool, default=True
        Whether to use bias in the linear layers.
    rnn_activation : str, default="relu"
        Activation function used inside the RNN cells.
    dim_feedforward : int, default=256
        Size of the feedforward network.
    d_conv : int, default=4
        Size of the convolutional layer for embedding features.
    dilation : int, default=1
        Dilation factor for the convolution.
    conv_bias : bool, default=True
        Whether to use bias in the convolutional layers.
    """

    # Architecture params
    model_type: str = "RNN"
    d_model: int = 128
    n_layers: int = 4
    rnn_dropout: float = 0.2
    norm: str = "RMSNorm"
    activation: Callable = nn.SELU()  # noqa: RUF009
    residuals: bool = False

    # Head params
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    head_activation: Callable = nn.SELU()  # noqa: RUF009
    head_use_batch_norm: bool = False

    # Pooling and normalization
    pooling_method: str = "avg"
    norm_first: bool = False
    layer_norm_eps: float = 1e-05

    # Additional params
    bias: bool = True
    rnn_activation: str = "relu"
    dim_feedforward: int = 256
    d_conv: int = 4
    dilation: int = 1
    conv_bias: bool = True
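

# Usage sketch (not part of the module above): overriding a few of the predefined
# hyperparameters when constructing the config. Only the dataclass fields defined
# above are assumed; how the config is passed to a TabulaRNN model depends on the
# rest of the mambular API and is not shown here.
if __name__ == "__main__":
    config = DefaultTabulaRNNConfig(
        model_type="LSTM",          # switch the recurrent cell type
        d_model=64,                 # smaller embedding dimension
        rnn_dropout=0.1,            # less dropout in the RNN layers
        head_layer_sizes=[128, 64], # two-layer prediction head
    )
    print(config.model_type, config.d_model, config.head_layer_sizes)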