Source code for deeptab.base_models.trompt
import numpy as np
import torch
import torch.nn as nn
from ..arch_utils.get_norm_fn import get_normalization_layer
from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.trompt_utils import TromptCell, TromptDecoder
from ..configs.trompt_config import DefaultTromptConfig
from .utils.basemodel import BaseModel
class Trompt(BaseModel):
    """Trompt model: a stack of ``TromptCell`` modules sharing one ``TromptDecoder``.

    A learnable initial record of shape ``(config.P, config.d_model)`` is
    repeated over the batch, passed through ``config.n_cycles`` cells in
    sequence, and the state produced by every cell is decoded. The per-cycle
    predictions are returned stacked along dimension 1, which is why
    ``returns_ensemble`` is set to ``True``.

    Parameters
    ----------
    feature_information : tuple
        Expected to be ``(num_feature_info, cat_feature_info,
        embedding_feature_info)``; forwarded unchanged to every ``TromptCell``.
    num_classes : int, default=1
        Output size handed to ``TromptDecoder``.
    config : DefaultTromptConfig
        Model configuration. ``n_cycles``, ``P`` and ``d_model`` are read here;
        the whole object is also passed to ``BaseModel`` and to each cell.
    **kwargs
        Forwarded to ``BaseModel.__init__``.
    """

    def __init__(
        self,
        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
        num_classes=1,
        config: DefaultTromptConfig = DefaultTromptConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        # forward() returns one prediction per cycle rather than a single output.
        self.returns_ensemble = True

        # One cell per cycle; all cycles share a single decoder.
        self.cells = nn.ModuleList(TromptCell(feature_information, config) for _ in range(config.n_cycles))
        self.decoder = TromptDecoder(config.d_model, num_classes)

        # Learnable initial record fed to the first cell. ``torch.empty`` only
        # allocates memory, so the values must be set explicitly; otherwise the
        # parameter starts from arbitrary (possibly NaN/Inf) contents.
        self.init_rec = nn.Parameter(torch.empty(config.P, config.d_model))
        # Small-scale normal init; std=0.01 is a conventional choice - adjust if
        # a different initialisation scheme is preferred.
        nn.init.normal_(self.init_rec, std=0.01)

        self.n_cycles = config.n_cycles

    @staticmethod
    def _infer_batch_size(data):
        """Return ``shape[0]`` of the first tensor found in any feature group.

        Falls back to later groups when the first one is ``None`` or empty, so
        inputs without numerical features still work.
        """
        for group in data:
            if group is None:
                continue
            for tensor in group:
                if tensor is not None:
                    return tensor.shape[0]
        raise ValueError("Cannot infer the batch size: no input tensors were provided.")

    def forward(self, *data):
        """Defines the forward pass of the model.

        Parameters
        ----------
        data : tuple
            Input tuple of tensors of num_features, cat_features, embeddings.
            It is passed unchanged to every cell.

        Returns
        -------
        Tensor
            The decoder outputs of all cycles stacked along dimension 1. A
            trailing dimension of size 1 is squeezed away.

        Raises
        ------
        ValueError
            If ``data`` contains no tensor from which the batch size can be read.
        """
        batch_size = self._infer_batch_size(data)

        # Give every sample in the batch its own copy of the initial record.
        state = self.init_rec.unsqueeze(0).repeat(batch_size, 1, 1)

        outputs = []
        for cell in self.cells:
            # Each cell refines the state produced by the previous cycle ...
            state = cell(*data, O=state)
            # ... and every intermediate state contributes one prediction.
            outputs.append(self.decoder(state))

        return torch.stack(outputs, dim=1).squeeze(-1)