# Source code for mambular.configs.node_config

from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from .base_config import BaseConfig


@dataclass
class DefaultNODEConfig(BaseConfig):
    """Configuration class for the Neural Oblivious Decision Ensemble (NODE) model.

    Parameters
    ----------
    num_layers : int, default=4
        Number of dense layers in the model.
    layer_dim : int, default=128
        Dimensionality of each dense layer.
    tree_dim : int, default=1
        Dimensionality of the output from each tree leaf.
    depth : int, default=6
        Depth of each decision tree in the ensemble.
    norm : str or None, default=None
        Type of normalization to use in the model.
    head_layer_sizes : list, default=[]
        Sizes of the layers in the model's head. Each config instance gets
        its own fresh empty list.
    head_dropout : float, default=0.3
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to skip layers in the head.
    head_activation : callable, default=nn.ReLU()
        Activation function for the head layers. The default module is
        created once at class-definition time and shared by every instance
        that does not override it.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.
    """

    # Architecture Parameters
    num_layers: int = 4
    layer_dim: int = 128
    tree_dim: int = 1
    depth: int = 6
    norm: str | None = None

    # Head Parameters
    # default_factory avoids sharing one mutable list across instances.
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.3
    head_skip_layers: bool = False
    # A call in a dataclass default is evaluated once (hence RUF009). It is
    # kept as a plain default so the class attribute stays accessible.
    head_activation: Callable = nn.ReLU()  # noqa: RUF009
    head_use_batch_norm: bool = False