"""Default configuration for the ModernNCA model (module ``deeptab.configs.modernnca_config``)."""
from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from .base_config import BaseConfig
@dataclass
class DefaultModernNCAConfig(BaseConfig):
    """
    Default configuration for the ModernNCA model.

    Every field below is a dataclass field, so any of them can be overridden
    by keyword when constructing the config, e.g.
    ``DefaultModernNCAConfig(dim=256, dropout=0.2)``.

    Attributes
    ----------
    dim : int
        Hidden dimension for encoding.
    d_block : int
        Block size for MLP layers.
    n_blocks : int
        Number of MLP blocks.
    dropout : float
        Dropout rate.
    temperature : float
        Temperature scaling for distance weighting.
    sample_rate : float
        Fraction of candidate samples used.
    num_embeddings : dict or None
        Optional dictionary for categorical embeddings (``None`` by default).
    optimizer_type : str
        Optimizer type, given by name (default ``"AdamW"``).
    weight_decay : float
        Weight decay for the optimizer.
    learning_rate : float
        Learning rate.
    lr_patience : int
        Patience for the LR scheduler.
    lr_factor : float
        Factor for the LR scheduler.
    head_layer_sizes : list
        Layer sizes for the head; empty by default (presumably meaning no
        hidden head layers -- confirm against the model code).
    head_dropout, head_skip_layers, head_activation, head_use_batch_norm
        Head settings; they are only declared here and are read elsewhere.
    embedding_type : str
        Embedding type identifier (default ``"plr"``).
    plr_lite, n_frequencies, frequencies_init_scale
        Embedding settings, presumably consumed by the ``"plr"`` embedding
        -- verify against the embedding implementation.
    """

    # Architecture Parameters
    dim: int = 128  # Hidden dimension for encoding
    d_block: int = 512  # Block size for MLP layers
    n_blocks: int = 4  # Number of MLP blocks
    dropout: float = 0.1  # Dropout rate
    temperature: float = 0.75  # Temperature scaling for distance weighting
    sample_rate: float = 0.5  # Fraction of candidate samples used
    num_embeddings: dict | None = None  # Dictionary for categorical embeddings

    # Training Parameters
    optimizer_type: str = "AdamW"  # Optimizer type
    weight_decay: float = 1e-5  # Weight decay for optimizer
    learning_rate: float = 1e-02  # Learning rate
    lr_patience: int = 10  # Patience for LR scheduler
    lr_factor: float = 0.1  # Factor for LR scheduler

    # Head Parameters
    # default_factory gives every instance its own list instead of one
    # shared mutable default.
    head_layer_sizes: list = field(default_factory=list)
    head_dropout: float = 0.5
    head_skip_layers: bool = False
    # NOTE: this default is evaluated once, when the class is defined, so all
    # instances that do not override it share the same nn.SELU() object.
    # The noqa silences Ruff's warning about that. Switching to
    # field(default_factory=nn.SELU) would create a distinct module per
    # instance, and two default configs would then compare unequal under
    # the dataclass-generated __eq__.
    head_activation: Callable = nn.SELU()  # noqa: RUF009
    head_use_batch_norm: bool = False

    # Embedding Parameters
    embedding_type: str = "plr"
    plr_lite: bool = True
    n_frequencies: int = 75
    frequencies_init_scale: float = 0.045