Source code for deeptab.configs.resnet_config

from collections.abc import Callable
from dataclasses import dataclass, field

import torch.nn as nn

from .base_config import BaseConfig


@dataclass
class DefaultResNetConfig(BaseConfig):
    """Configuration class for the default ResNet model with predefined hyperparameters.

    Parameters
    ----------
    layer_sizes : list, default=[256, 128, 32]
        Sizes of the layers in the ResNet. A new list is built for every
        config instance, so mutating one instance's list does not affect
        other instances.
    activation : callable, default=nn.SELU()
        Activation function for the ResNet layers. The default is a single
        ``nn.SELU`` object created once and shared by every config that does
        not override it.
    skip_layers : bool, default=False
        Whether to skip layers in the ResNet.
    dropout : float, default=0.5
        Dropout rate for regularization.
    norm : bool, default=False
        Whether to use normalization in the ResNet.
    use_glu : bool, default=False
        Whether to use Gated Linear Units (GLU) in the ResNet.
    skip_connections : bool, default=True
        Whether to use skip connections in the ResNet.
    num_blocks : int, default=3
        Number of residual blocks in the ResNet.
    average_embeddings : bool, default=True
        Whether to average embeddings during the forward pass.
    """

    # model params
    # default_factory gives each instance its own list instead of one shared
    # mutable default.
    layer_sizes: list = field(default_factory=lambda: [256, 128, 32])
    # NOTE(review): this nn.SELU() is evaluated once at class definition and
    # the same object is shared by every instance using the default. SELU
    # appears to hold no learnable parameters, so sharing looks harmless.
    # Switching to field(default_factory=nn.SELU) would make two default
    # configs compare unequal under dataclass __eq__ (nn.Module presumably
    # compares by identity) — confirm callers don't rely on that first.
    activation: Callable = nn.SELU()  # noqa: RUF009
    skip_layers: bool = False
    dropout: float = 0.5
    norm: bool = False
    use_glu: bool = False
    skip_connections: bool = True
    num_blocks: int = 3

    # embedding params
    average_embeddings: bool = True