Source code for deeptab.configs.tabm_config
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Literal
import torch.nn as nn
from .base_config import BaseConfig
@dataclass
class DefaultTabMConfig(BaseConfig):
    """Configuration class for the TabM model with batch ensembling and predefined hyperparameters.

    Only the fields declared directly on this class are listed below; the
    documented defaults mirror the values assigned in the class body.

    Parameters
    ----------
    layer_sizes : list, default=[256, 256, 128]
        Sizes of the layers in the model. A fresh list is created for every
        config instance.
    activation : callable, default=nn.ReLU()
        Activation function for the model layers.
    dropout : float, default=0.5
        Dropout rate for regularization.
    norm : str or None, default=None
        Normalization method to be used, if any.
    use_glu : bool, default=False
        Whether to use Gated Linear Units (GLU) in the model.
    ensemble_size : int, default=32
        Number of ensemble members for batch ensembling.
    ensemble_scaling_in : bool, default=True
        Whether to use input scaling for each ensemble member.
    ensemble_scaling_out : bool, default=True
        Whether to use output scaling for each ensemble member.
    ensemble_bias : bool, default=True
        Whether to use a unique bias term for each ensemble member.
    scaling_init : {"ones", "random-signs", "normal"}, default="ones"
        Initialization method for scaling weights.
    average_ensembles : bool, default=False
        Whether to average the outputs of the ensembles.
    model_type : {"mini", "full"}, default="mini"
        Model type to use ('mini' for reduced version, 'full' for complete model).
    average_embeddings : bool, default=True
        NOTE(review): this field was previously undocumented. Presumably it
        controls whether feature embeddings are averaged before being passed
        to the network -- confirm against the model implementation.
    """

    # Architecture parameters
    # A default_factory is used so that config instances never share one list.
    layer_sizes: list = field(default_factory=lambda: [256, 256, 128])
    # The default module is instantiated once, when the class is defined, and
    # is shared by every config that does not override it (hence the noqa).
    activation: Callable = nn.ReLU()  # noqa: RUF009
    dropout: float = 0.5
    norm: str | None = None
    use_glu: bool = False

    # Batch ensembling specific configurations
    ensemble_size: int = 32
    ensemble_scaling_in: bool = True
    ensemble_scaling_out: bool = True
    ensemble_bias: bool = True
    scaling_init: Literal["ones", "random-signs", "normal"] = "ones"
    average_ensembles: bool = False
    model_type: Literal["mini", "full"] = "mini"
    average_embeddings: bool = True