import numpy as np
import torch
import torch.nn as nn
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..configs.tangos_config import DefaultTangosConfig
from ..utils.get_feature_dimensions import get_feature_dimensions
from .utils.basemodel import BaseModel
[docs]class Tangos(BaseModel):
"""
A Multi-Layer Perceptron (MLP) model with optional GLU activation, batch normalization, layer normalization, and dropout. # noqa: W505
It includes a penalty term for specialization and orthogonality.
Parameters
----------
feature_information : tuple
A tuple containing feature information for numerical and categorical features.
num_classes : int, optional (default=1)
The number of output classes.
config : DefaultTangosConfig, optional (default=DefaultTangosConfig())
Configuration object defining model hyperparameters.
**kwargs : dict
Additional arguments for the base model.
Attributes
----------
returns_ensemble : bool
Whether the model returns an ensemble of predictions.
lamda1 : float
Regularization weight for the specialization loss.
lamda2 : float
Regularization weight for the orthogonality loss.
subsample : float
Proportion of neuron pairs to use for orthogonality loss calculation.
embedding_layer : EmbeddingLayer or None
Optional embedding layer for categorical features.
layers : nn.ModuleList
The main MLP layers including linear, normalization, and activation layers.
head : nn.Linear
The final output layer.
"""
def __init__(
self,
feature_information: tuple,
num_classes=1,
config: DefaultTangosConfig = DefaultTangosConfig(), # noqa: B008
**kwargs,
):
super().__init__(config=config, **kwargs)
self.save_hyperparameters(ignore=["feature_information"])
self.returns_ensemble = False
self.lamda1 = config.lamda1
self.lamda2 = config.lamda2
self.subsample = config.subsample
input_dim = get_feature_dimensions(*feature_information)
# Initialize layers
self.layers = nn.ModuleList()
# Input layer
self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0]))
if self.hparams.batch_norm:
self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[0]))
if self.hparams.use_glu:
self.layers.append(nn.GLU())
else:
self.layers.append(self.hparams.activation)
if self.hparams.dropout > 0.0:
self.layers.append(nn.Dropout(self.hparams.dropout))
# Hidden layers
for i in range(1, len(self.hparams.layer_sizes)):
self.layers.append(nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i]))
if self.hparams.batch_norm:
self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[i]))
if self.hparams.layer_norm:
self.layers.append(nn.LayerNorm(self.hparams.layer_sizes[i]))
if self.hparams.use_glu:
self.layers.append(nn.GLU())
else:
self.layers.append(self.hparams.activation)
if self.hparams.dropout > 0.0:
self.layers.append(nn.Dropout(self.hparams.dropout))
# Output layer
self.head = nn.Linear(self.hparams.layer_sizes[-1], num_classes)
[docs] def repr_forward(self, x) -> torch.Tensor:
"""
Computes the forward pass for feature representations.
This method processes the input through the MLP layers, optionally using
skip connections.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, feature_dim).
Returns
-------
torch.Tensor
Output tensor after passing through the representation layers.
"""
x = x.unsqueeze(0)
for i in range(len(self.layers)):
if isinstance(self.layers[i], nn.Linear):
out = self.layers[i](x)
if self.hparams.skip_connections and x.shape == out.shape:
x = x + out
else:
x = out
else:
x = self.layers[i](x)
return x
[docs] def forward(self, *data) -> torch.Tensor:
"""
Performs a forward pass of the MLP model.
This method concatenates all input tensors before applying MLP layers.
Parameters
----------
data : tuple
A tuple containing lists of numerical, categorical, and embedded feature tensors.
Returns
-------
torch.Tensor
The output tensor of shape (batch_size, num_classes).
"""
x = torch.cat([t for tensors in data for t in tensors], dim=1)
for i in range(len(self.layers)):
if isinstance(self.layers[i], nn.Linear):
out = self.layers[i](x)
if self.hparams.skip_connections and x.shape == out.shape:
x = x + out
else:
x = out
else:
x = self.layers[i](x)
x = self.head(x)
return x
[docs] def penalty_forward(self, *data):
"""
Computes both the model predictions and a penalty term.
The penalty term includes:
- **Specialization loss**: Measures feature importance concentration.
- **Orthogonality loss**: Encourages diversity among learned features.
The method uses `jacrev` to compute the Jacobian of the representation function.
Parameters
----------
data : tuple
A tuple containing lists of numerical, categorical, and embedded feature tensors.
Returns
-------
tuple
- predictions : torch.Tensor
Model predictions of shape (batch_size, num_classes).
- penalty : torch.Tensor
The computed penalty term for regularization.
"""
x = torch.cat([t for tensors in data for t in tensors], dim=1)
batch_size = x.shape[0]
subsample = np.int32(self.subsample * batch_size)
# Flatten before passing to jacrev
flat_data = torch.cat([t for tensors in data for t in tensors], dim=1)
# Compute Jacobian
jacobian = torch.func.vmap(torch.func.jacrev(self.repr_forward), randomness="different")(flat_data)
jacobian = jacobian.squeeze()
neuron_attr = jacobian.swapaxes(0, 1)
h_dim = neuron_attr.shape[0]
if len(neuron_attr.shape) > 3:
# h_dim x batch_size x features
neuron_attr = neuron_attr.flatten(start_dim=2)
# calculate specialization loss component
spec_loss = torch.norm(neuron_attr, p=1) / (batch_size * h_dim * neuron_attr.shape[2])
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
orth_loss = torch.tensor(0.0, requires_grad=True).to(x.device)
# apply subsampling routine for orthogonalization loss
if self.subsample > 0 and self.subsample < h_dim * (h_dim - 1) / 2:
tensor_pairs = [list(np.random.choice(h_dim, size=(2), replace=False)) for i in range(subsample)]
for tensor_pair in tensor_pairs:
pairwise_corr = cos(neuron_attr[tensor_pair[0], :, :], neuron_attr[tensor_pair[1], :, :]).norm(p=1)
orth_loss = orth_loss + pairwise_corr
orth_loss = orth_loss / (batch_size * self.subsample)
else:
for neuron_i in range(1, h_dim):
for neuron_j in range(0, neuron_i):
pairwise_corr = cos(neuron_attr[neuron_i, :, :], neuron_attr[neuron_j, :, :]).norm(p=1)
orth_loss = orth_loss + pairwise_corr
num_pairs = h_dim * (h_dim - 1) / 2
orth_loss = orth_loss / (batch_size * num_pairs)
penalty = self.lamda1 * spec_loss + self.lamda2 * orth_loss
predictions = self.forward(*data)
return predictions, penalty