# Source code for mambular.base_models.resnet
import torch
import torch.nn as nn
import numpy as np
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.resnet_utils import ResidualBlock
from ..configs.resnet_config import DefaultResNetConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel
class ResNet(BaseModel):
    """A ResNet model for tabular data, combining feature embeddings, residual blocks, and customizable architecture for
    processing categorical and numerical features.

    Parameters
    ----------
    feature_information : tuple
        Tuple of feature-information dictionaries, expected as
        ``(num_feature_info, cat_feature_info, embedding_feature_info)``. It is
        unpacked into ``EmbeddingLayer`` when embeddings are enabled, and into
        ``get_feature_dimensions`` otherwise.
    num_classes : int, optional
        The number of output classes or target dimensions for regression, by default 1.
    config : DefaultResNetConfig, optional
        Configuration object containing model hyperparameters such as layer sizes, number of residual blocks,
        dropout rates, activation functions, and normalization settings, by default DefaultResNetConfig().
    **kwargs : dict
        Additional keyword arguments for the BaseModel class.

    Attributes
    ----------
    returns_ensemble : bool
        Always ``False``. This model produces a single prediction.
    embedding_layer : EmbeddingLayer
        Embedding layer for all features. Only created when
        ``hparams.use_embeddings`` is true.
    initial_layer : nn.Linear
        Projects the flattened input features to ``layer_sizes[0]`` units.
    blocks : nn.ModuleList
        ``num_blocks`` residual blocks. Block ``i`` maps the current hidden width
        to ``layer_sizes[i + 1]``, clamped to the last entry of ``layer_sizes``.
    output_layer : nn.Linear
        Maps the width produced by the last residual block to ``num_classes``.

    Notes
    -----
    Hyperparameters are read from ``self.hparams`` (``use_embeddings``,
    ``d_model``, ``layer_sizes``, ``num_blocks``, ``activation``, ``norm``,
    ``dropout``). They are presumably populated from ``config`` by ``BaseModel`` /
    ``save_hyperparameters``; confirm against ``BaseModel``.

    Methods
    -------
    forward(*data)
        Perform a forward pass through the model: embedding (if enabled) or
        concatenation of the raw feature tensors, residual blocks, and the
        output projection.
    """

    def __init__(
        self,
        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes: int = 1,
        config: DefaultResNetConfig = DefaultResNetConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        if self.hparams.use_embeddings:
            self.embedding_layer = EmbeddingLayer(
                *feature_information,
                config=config,
            )
            # forward() flattens the embedded features, so the input width is
            # (total number of features) * d_model. Builtin sum yields a plain int.
            input_dim = sum(
                len(info) * self.hparams.d_model for info in feature_information
            )
        else:
            input_dim = get_feature_dimensions(*feature_information)

        layer_sizes = self.hparams.layer_sizes
        last_index = len(layer_sizes) - 1

        self.initial_layer = nn.Linear(input_dim, layer_sizes[0])

        # Track the running hidden width so that each block's input matches the
        # previous block's output. Extra blocks beyond len(layer_sizes) - 1 keep
        # the last width instead of indexing past the end of layer_sizes.
        self.blocks = nn.ModuleList()
        hidden_dim = layer_sizes[0]
        for i in range(self.hparams.num_blocks):
            output_dim = layer_sizes[min(i + 1, last_index)]
            self.blocks.append(
                ResidualBlock(
                    hidden_dim,
                    output_dim,
                    self.hparams.activation,
                    self.hparams.norm,
                    self.hparams.dropout,
                )
            )
            hidden_dim = output_dim

        # Size the head from the width actually produced by the residual stack.
        # This equals layer_sizes[-1] whenever num_blocks >= len(layer_sizes) - 1,
        # and avoids a shape mismatch in forward() when fewer blocks are configured.
        self.output_layer = nn.Linear(hidden_dim, num_classes)

    def forward(self, *data):
        """Forward pass of the ResNet model.

        Parameters
        ----------
        data : tuple
            Input tuple of num_features, cat_features and embeddings. Each
            element is itself an iterable of tensors.

        Returns
        -------
        torch.Tensor
            Output tensor produced by ``output_layer``; its last dimension is
            ``num_classes``.
        """
        if self.hparams.use_embeddings:
            x = self.embedding_layer(*data)
            # The embedding output is rank 3, presumably
            # (batch, n_features, d_model). Flatten it to one vector per sample.
            B, S, D = x.shape
            x = x.reshape(B, S * D)
        else:
            # Concatenate every tensor from every group along dim 1. Assumes the
            # tensors are (batch, width); confirm against the data module.
            x = torch.cat([t for tensors in data for t in tensors], dim=1)

        x = self.initial_layer(x)
        for block in self.blocks:
            x = block(x)
        return self.output_layer(x)