# Source code for mambular.base_models.tabularnn

from dataclasses import replace
import torch
import torch.nn as nn

from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.mlp_utils import MLPhead
from ..arch_utils.rnn_utils import ConvRNN
from ..configs.tabularnn_config import DefaultTabulaRNNConfig
from .utils.basemodel import BaseModel


class TabulaRNN(BaseModel):
    """Recurrent model for tabular data.

    Inputs are embedded by ``EmbeddingLayer``, run through a ``ConvRNN``,
    pooled with ``pool_sequence`` and added to a linear projection of the mean
    embedding (a skip path). The sum is optionally normalized and mapped to the
    outputs by an ``MLPhead``.

    Parameters
    ----------
    feature_information : tuple
        Expected to be ``(num_feature_info, cat_feature_info,
        embedding_feature_info)``. The elements are passed positionally to
        ``EmbeddingLayer``, and the ``len`` of each one is used to set up the
        pooling layers.
    num_classes : int, default=1
        Output dimension of the prediction head.
    config : DefaultTabulaRNNConfig, optional
        Model configuration. When omitted, a new ``DefaultTabulaRNNConfig()``
        is created for this instance rather than sharing a single default
        object between all instances.
    **kwargs
        Forwarded to ``BaseModel.__init__``.
    """

    def __init__(
        self,
        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: "DefaultTabulaRNNConfig | None" = None,
        **kwargs,
    ):
        if config is None:
            # Create the default per call; a default evaluated in the
            # signature would be one object shared by every instance.
            config = DefaultTabulaRNNConfig()
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        self.returns_ensemble = False

        self.rnn = ConvRNN(config)
        self.embedding_layer = EmbeddingLayer(
            *feature_information,
            config=config,
        )
        self.tabular_head = MLPhead(
            input_dim=self.hparams.dim_feedforward,
            config=config,
            output_dim=num_classes,
        )
        # Skip path used in forward: projects the mean embedding from
        # d_model to dim_feedforward so it can be added to the pooled RNN
        # output.
        self.linear = nn.Linear(
            self.hparams.d_model,
            self.hparams.dim_feedforward,
        )

        # norm_f is applied to dim_feedforward-sized vectors in forward, so
        # it is built from a copy of the config whose d_model is set to
        # dim_feedforward (presumably get_normalization_layer sizes the
        # layer from d_model — TODO confirm). The original config is left
        # unmodified.
        temp_config = replace(config, d_model=config.dim_feedforward)
        self.norm_f = get_normalization_layer(temp_config)

        # pooling
        n_inputs = [len(info) for info in feature_information]
        self.initialize_pooling_layers(config=config, n_inputs=n_inputs)

    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        *data : Tensor
            Feature tensors, passed unchanged and positionally to
            ``self.embedding_layer``. Presumably these are the feature groups
            described by ``feature_information`` (numerical, categorical,
            embedding), in that order — TODO confirm against
            ``EmbeddingLayer``.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        embeddings = self.embedding_layer(*data)

        # RNN forward pass; the second return value is discarded.
        rnn_out, _ = self.rnn(embeddings)

        # Skip path: average the embeddings over dim 1 and project
        # d_model -> dim_feedforward. Assumes embeddings are
        # (batch, n_tokens, d_model) — TODO confirm.
        skip = self.linear(torch.mean(embeddings, dim=1))

        x = self.pool_sequence(rnn_out) + skip
        if self.norm_f is not None:
            x = self.norm_f(x)
        return self.tabular_head(x)