Source code for mambular.base_models.ndtf

import numpy as np
import torch
import torch.nn as nn

from ..arch_utils.neural_decision_tree import NeuralDecisionTree
from ..configs.ndtf_config import DefaultNDTFConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel


[docs]class NDTF(BaseModel): """A Neural Decision Tree Forest (NDTF) model for tabular data, composed of an ensemble of neural decision trees with convolutional feature interactions, capable of producing predictions and penalty-based regularization. Parameters ---------- cat_feature_info : dict Dictionary containing information about categorical features, including their names and dimensions. num_feature_info : dict Dictionary containing information about numerical features, including their names and dimensions. num_classes : int, optional The number of output classes or target dimensions for regression, by default 1. config : DefaultNDTFConfig, optional Configuration object containing model hyperparameters such as the number of ensembles, tree depth, penalty factor, sampling settings, and temperature, by default DefaultNDTFConfig(). **kwargs : dict Additional keyword arguments for the BaseModel class. Attributes ---------- cat_feature_info : dict Stores categorical feature information. num_feature_info : dict Stores numerical feature information. penalty_factor : float Scaling factor for the penalty applied during training, specified in the self.hparams. input_dimensions : list of int List of input dimensions for each tree in the ensemble, with random sampling. trees : nn.ModuleList List of neural decision trees used in the ensemble. conv_layer : nn.Conv1d Convolutional layer for feature interactions before passing inputs to trees. tree_weights : nn.Parameter Learnable parameter to weight each tree's output in the ensemble. Methods ------- forward(num_features, cat_features) -> torch.Tensor Perform a forward pass through the model, producing predictions based on an ensemble of neural decision trees. penalty_forward(num_features, cat_features) -> tuple of torch.Tensor Perform a forward pass with penalty regularization, returning predictions and the calculated penalty term. """ def __init__( self, feature_information: tuple, # Expecting (num_feature_info, cat_feature_info, embedding_feature_info) num_classes: int = 1, config: DefaultNDTFConfig = DefaultNDTFConfig(), # noqa: B008 **kwargs, ): super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["feature_information"]) self.returns_ensemble = False input_dim = get_feature_dimensions(*feature_information) self.input_dimensions = [input_dim] for _ in range(self.hparams.n_ensembles - 1): self.input_dimensions.append(np.random.randint(1, input_dim)) self.trees = nn.ModuleList( [ NeuralDecisionTree( input_dim=self.input_dimensions[idx], depth=np.random.randint( self.hparams.min_depth, self.hparams.max_depth ), output_dim=num_classes, lamda=self.hparams.lamda, temperature=self.hparams.temperature + np.abs(np.random.normal(0, 0.1)), node_sampling=self.hparams.node_sampling, ) for idx in range(self.hparams.n_ensembles) ] ) self.conv_layer = nn.Conv1d( in_channels=self.input_dimensions[0], out_channels=1, # Single channel output if one feature interaction is desired # Choose appropriate kernel size kernel_size=self.input_dimensions[0], # To keep output size the same as input_dim if desired padding=self.input_dimensions[0] - 1, bias=True, ) self.tree_weights = nn.Parameter( torch.full((self.hparams.n_ensembles, 1), 1.0 / self.hparams.n_ensembles), requires_grad=True, )
[docs] def forward(self, *data) -> torch.Tensor: """Forward pass of the NDTF model. Parameters ---------- data : tuple Input tuple of tensors of num_features, cat_features, embeddings. Returns ------- torch.Tensor Output tensor. """ x = torch.cat([t for tensors in data for t in tensors], dim=1) x = self.conv_layer(x.unsqueeze(2)) x = x.transpose(1, 2).squeeze(-1) preds = [] for idx, tree in enumerate(self.trees): tree_input = x[:, : self.input_dimensions[idx]] preds.append(tree(tree_input, return_penalty=False)) preds = torch.stack(preds, dim=1).squeeze(-1) return preds @ self.tree_weights
[docs] def penalty_forward(self, *data) -> torch.Tensor: """Forward pass of the NDTF model. Parameters ---------- data : tuple Input tuple of tensors of num_features, cat_features, embeddings. Returns ------- torch.Tensor Output tensor. """ x = torch.cat([t for tensors in data for t in tensors], dim=1) x = self.conv_layer(x.unsqueeze(2)) x = x.transpose(1, 2).squeeze(-1) penalty = 0.0 preds = [] # Iterate over trees and collect predictions and penalties for idx, tree in enumerate(self.trees): # Select subset of features for the current tree tree_input = x[:, : self.input_dimensions[idx]] # Get prediction and penalty from the current tree pred, pen = tree(tree_input, return_penalty=True) preds.append(pred) penalty += pen # Stack predictions and calculate mean across trees preds = torch.stack(preds, dim=1).squeeze(-1) return preds @ self.tree_weights, self.hparams.penalty_factor * penalty # type: ignore