Source code for deeptab.models.utils.sklearn_base_lss

import warnings
from collections.abc import Callable

import lightning as pl
import numpy as np
import pandas as pd
import properscoring as ps
import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pretab.preprocessor import Preprocessor
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score, mean_squared_error
from torch.utils.data import DataLoader
from tqdm import tqdm

from ...base_models.utils.lightning_wrapper import TaskModel
from ...data_utils.datamodule import MambularDataModule
from ...utils.distributional_metrics import (
    beta_brier_score,
    dirichlet_error,
    gamma_deviance,
    inverse_gamma_loss,
    negative_binomial_deviance,
    poisson_deviance,
    student_t_loss,
)
from ...utils.distributions import (
    BetaDistribution,
    CategoricalDistribution,
    DirichletDistribution,
    GammaDistribution,
    InverseGammaDistribution,
    JohnsonSuDistribution,
    NegativeBinomialDistribution,
    NormalDistribution,
    PoissonDistribution,
    Quantile,
    StudentTDistribution,
)


class SklearnBaseLSS(BaseEstimator):
    """Scikit-learn style base estimator for distributional (LSS) models.

    The estimator wires together a ``Preprocessor``, a ``MambularDataModule``
    and a Lightning ``TaskModel`` (built with ``lss=True``) and exposes the
    usual ``fit`` / ``predict`` / ``score`` workflow on top of them.

    Parameters
    ----------
    model : type
        Model class that is handed to ``TaskModel`` as ``model_class`` when
        the model is built.
    config : Callable
        Configuration factory. It is called with every keyword argument that
        is neither a preprocessor argument nor prefixed with ``optimizer``.
    **kwargs : dict
        Mixed configuration, preprocessor (see ``preprocessor_arg_names``)
        and ``optimizer_*`` arguments.
    """

    def __init__(self, model, config, **kwargs):
        self.preprocessor_arg_names = [
            "n_bins",
            "feature_preprocessing",
            "numerical_preprocessing",
            "categorical_preprocessing",
            "use_decision_tree_bins",
            "binning_strategy",
            "task",
            "cat_cutoff",
            "treat_all_integers_as_numerical",
            "degree",
            "scaling_strategy",
            "n_knots",
            "use_decision_tree_knots",
            "knots_strategy",
            "spline_implementation",
        ]

        self.config_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in self.preprocessor_arg_names and not k.startswith("optimizer")
        }
        # Keep the factory around: ``set_params`` rebuilds the config from it
        # when ``self.config`` is None.
        self.config_class = config
        self.config = config(**self.config_kwargs)

        preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names}
        self.preprocessor = Preprocessor(**preprocessor_kwargs)

        self.task_model = None
        # ``self.estimator`` is overwritten with ``task_model.estimator`` in
        # ``build_model``; remember the original class so the model can be
        # rebuilt on a later ``fit`` call.
        self._model_class = model
        self.estimator = model
        self.built = False

        # Raise a warning if task is set to 'classification'
        if preprocessor_kwargs.get("task") == "classification":
            warnings.warn(
                "The task is set to 'classification'. Be aware of your preferred distribution, "
                "that this might lead to unsatisfactory results.",
                UserWarning,
                stacklevel=2,
            )

        self.optimizer_type = kwargs.get("optimizer_type", "Adam")
        # Every ``optimizer_*`` argument except the optimizer name itself is
        # forwarded to the optimizer.
        self.optimizer_kwargs = {
            k: v for k, v in kwargs.items() if k.startswith("optimizer_") and k != "optimizer_type"
        }

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, the preprocessor parameters are included as well, each
            key prefixed with ``prepro__``.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        params = dict(self.config_kwargs)
        if deep:
            preprocessor_params = self.preprocessor.get_params()  # type: ignore[attr-defined]
            params.update({f"prepro__{key}": value for key, value in preprocessor_params.items()})
        return params

    def set_params(self, **parameters):
        """Set the parameters of this estimator.

        Parameters
        ----------
        **parameters : dict
            Estimator parameters. Keys starting with ``prepro__`` are routed
            to the preprocessor (with the prefix stripped); all other keys
            update the model configuration.

        Returns
        -------
        self : object
            Estimator instance.
        """
        prefix = "prepro__"
        config_params = {k: v for k, v in parameters.items() if not k.startswith(prefix)}
        # Strip only the leading prefix so keys that themselves contain "__"
        # are passed on intact.
        preprocessor_params = {k.removeprefix(prefix): v for k, v in parameters.items() if k.startswith(prefix)}

        if config_params:
            self.config_kwargs.update(config_params)
            if self.config is not None:
                for key, value in config_params.items():
                    setattr(self.config, key, value)
            else:
                self.config = self.config_class(**self.config_kwargs)  # type: ignore

        if preprocessor_params:
            self.preprocessor.set_params(**preprocessor_params)  # type: ignore[attr-defined]

        return self

    def _resolve_family(self, family, distributional_kwargs=None):
        """Instantiate the distribution for ``family`` and store it on the estimator.

        Sets ``self.family`` (the distribution object) and ``self.family_name``
        (the string key, used later to pick default evaluation metrics).

        Raises
        ------
        ValueError
            If ``family`` is not one of the supported keys.
        """
        distribution_classes = {
            "normal": NormalDistribution,
            "poisson": PoissonDistribution,
            "gamma": GammaDistribution,
            "beta": BetaDistribution,
            "dirichlet": DirichletDistribution,
            "studentt": StudentTDistribution,
            "negativebinom": NegativeBinomialDistribution,
            "inversegamma": InverseGammaDistribution,
            "categorical": CategoricalDistribution,
            "quantile": Quantile,
            "johnsonsu": JohnsonSuDistribution,
        }
        if distributional_kwargs is None:
            distributional_kwargs = {}

        if family not in distribution_classes:
            raise ValueError(f"Unsupported family: {family}")

        self.family = distribution_classes[family](**distributional_kwargs)
        self.family_name = family

    def build_model(
        self,
        X,
        y,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs: dict | None = None,
        family: str | None = None,
        distributional_kwargs: dict | None = None,
    ):
        """Builds the model using the provided training data.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values (real numbers).
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data
            is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=128
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        lr : float, optional
            Learning rate for the optimizer. Falls back to ``config.lr`` when None.
        lr_patience : int, optional
            Number of epochs with no improvement on the validation loss to wait before reducing
            the learning rate. Falls back to ``config.lr_patience`` when None.
        lr_factor : float, optional
            Factor by which the learning rate will be reduced. Falls back to
            ``config.lr_factor`` when None.
        weight_decay : float, optional
            Weight decay (L2 penalty) coefficient. Falls back to ``config.weight_decay``
            when None.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        dataloader_kwargs : dict, default=None
            The kwargs for the pytorch dataloader class. None means no extra kwargs.
        family : str, optional
            Name of the distribution family. Only needed when ``build_model`` is called
            before ``fit``; ``fit`` sets the family itself.
        distributional_kwargs : dict, optional
            Arguments that are specific to the chosen distribution. Only used together
            with ``family``.

        Returns
        -------
        self : object
            The built distributional regressor.

        Raises
        ------
        ValueError
            If no distribution family is available, or ``family`` is unsupported.
        """
        if family is not None:
            self._resolve_family(family, distributional_kwargs)
        if getattr(self, "family", None) is None:
            raise ValueError(
                "No distribution family has been set. Pass `family=...` to build_model() or call fit()."
            )
        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if isinstance(y, pd.Series):
            y = y.values
        if X_val is not None:
            if not isinstance(X_val, pd.DataFrame):
                X_val = pd.DataFrame(X_val)
            if isinstance(y_val, pd.Series):
                y_val = y_val.values

        self.data_module = MambularDataModule(
            preprocessor=self.preprocessor,
            batch_size=batch_size,
            shuffle=shuffle,
            X_val=X_val,
            y_val=y_val,
            val_size=val_size,
            random_state=random_state,
            regression=False,
            **dataloader_kwargs,
        )

        self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state)

        self.task_model = TaskModel(
            # Use the class captured in __init__: ``self.estimator`` is replaced below, so
            # it can no longer serve as ``model_class`` on a rebuild.
            model_class=getattr(self, "_model_class", self.estimator),  # type: ignore
            num_classes=self.family.param_count,
            family=self.family,
            config=self.config,
            feature_information=(
                self.data_module.num_feature_info,
                self.data_module.cat_feature_info,
                self.data_module.embedding_feature_info,
            ),
            lr=lr if lr is not None else self.config.lr,
            lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience),
            lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
            weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay),
            lss=True,
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            optimizer_type=self.optimizer_type,
            optimizer_args=self.optimizer_kwargs,
        )

        self.built = True
        self.estimator = self.task_model.estimator

        return self

    def get_number_of_params(self, requires_grad=True):
        """Calculate the number of parameters in the model.

        Parameters
        ----------
        requires_grad : bool, optional
            If True, only count the parameters that require gradients (trainable parameters).
            If False, count all parameters. Default is True.

        Returns
        -------
        int
            The total number of parameters in the model.

        Raises
        ------
        ValueError
            If the model has not been built prior to calling this method.
        """
        if not self.built:
            raise ValueError("The model must be built before the number of parameters can be estimated")

        parameters = self.task_model.parameters()  # type: ignore
        if requires_grad:
            return sum(p.numel() for p in parameters if p.requires_grad)
        return sum(p.numel() for p in parameters)

    def fit(
        self,
        X,
        y,
        family,
        val_size: float = 0.2,
        X_val=None,
        y_val=None,
        max_epochs: int = 100,
        random_state: int = 101,
        batch_size: int = 128,
        shuffle: bool = True,
        patience: int = 15,
        monitor: str = "val_loss",
        mode: str = "min",
        lr: float | None = None,
        lr_patience: int | None = None,
        lr_factor: float | None = None,
        weight_decay: float | None = None,
        checkpoint_path="model_checkpoints",
        distributional_kwargs=None,
        train_metrics: dict[str, Callable] | None = None,
        val_metrics: dict[str, Callable] | None = None,
        dataloader_kwargs: dict | None = None,
        rebuild=True,
        **trainer_kwargs,
    ):
        """Trains the regression model using the provided training data.

        Optionally, a separate validation set can be used.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_targets)
            The target values (real numbers).
        family : str
            The name of the distribution family to use for the loss function. One of
            'normal', 'poisson', 'gamma', 'beta', 'dirichlet', 'studentt', 'negativebinom',
            'inversegamma', 'categorical', 'quantile', 'johnsonsu'.
        val_size : float, default=0.2
            The proportion of the dataset to include in the validation split if `X_val` is None.
            Ignored if `X_val` is provided.
        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
            The validation input samples. If provided, `X` and `y` are not split and this data
            is used for validation.
        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
            The validation target values. Required if `X_val` is provided.
        max_epochs : int, default=100
            Maximum number of epochs for training.
        random_state : int, default=101
            Controls the shuffling applied to the data before applying the split.
        batch_size : int, default=128
            Number of samples per gradient update.
        shuffle : bool, default=True
            Whether to shuffle the training data before each epoch.
        patience : int, default=15
            Number of epochs with no improvement on the monitored metric to wait before
            early stopping.
        monitor : str, default="val_loss"
            The metric to monitor for early stopping.
        mode : str, default="min"
            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
        lr : float, optional
            Learning rate for the optimizer. Falls back to ``config.lr`` when None.
        lr_patience : int, optional
            Number of epochs with no improvement on the validation loss to wait before reducing
            the learning rate. Falls back to ``config.lr_patience`` when None.
        lr_factor : float, optional
            Factor by which the learning rate will be reduced. Falls back to
            ``config.lr_factor`` when None.
        weight_decay : float, optional
            Weight decay (L2 penalty) coefficient. Falls back to ``config.weight_decay``
            when None.
        checkpoint_path : str, default="model_checkpoints"
            Path where the checkpoints are being saved.
        distributional_kwargs : dict, default=None
            Any arguments that are specific for a certain distribution.
        train_metrics : dict, default=None
            torch.metrics dict to be logged during training.
        val_metrics : dict, default=None
            torch.metrics dict to be logged during validation.
        dataloader_kwargs : dict, default=None
            The kwargs for the pytorch dataloader class. None means no extra kwargs.
        rebuild : bool, default=True
            If True, the data module and model are (re)built from `X` and `y` before
            training. If False, a previously built model is trained further.
        **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.

        Returns
        -------
        self : object
            The fitted regressor.

        Raises
        ------
        ValueError
            If `family` is unsupported, or `rebuild` is False and the model was never built.
        """
        self._resolve_family(family, distributional_kwargs)

        if rebuild:
            self.build_model(
                X=X,
                y=y,
                val_size=val_size,
                X_val=X_val,
                y_val=y_val,
                random_state=random_state,
                batch_size=batch_size,
                shuffle=shuffle,
                lr=lr,
                lr_patience=lr_patience,
                lr_factor=lr_factor,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                weight_decay=weight_decay,
                dataloader_kwargs=dataloader_kwargs,
            )
        elif not self.built:
            raise ValueError(
                "The model must be built before calling the fit method. "
                "Either call .build_model() or set rebuild=True"
            )

        early_stop_callback = EarlyStopping(
            monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
        )

        # NOTE(review): the checkpoint always tracks "val_loss"/"min", independent of the
        # `monitor`/`mode` used for early stopping above.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",  # Adjust according to your validation metric
            mode="min",
            save_top_k=1,
            dirpath=checkpoint_path,  # Specify the directory to save checkpoints
            filename="best_model",
        )

        # Initialize the trainer and train the model
        self.trainer = pl.Trainer(
            max_epochs=max_epochs,
            callbacks=[
                early_stop_callback,
                checkpoint_callback,
                ModelSummary(max_depth=2),
            ],
            **trainer_kwargs,
        )
        self.trainer.fit(self.task_model, self.data_module)  # type: ignore

        # Restore the weights of the best checkpoint, if one was written.
        self.best_model_path = checkpoint_callback.best_model_path
        if self.best_model_path:
            torch.serialization.add_safe_globals([type(self.config)])
            # NOTE(security): weights_only=False unpickles arbitrary objects. The path comes
            # from the checkpoint callback of this very run; never point it at untrusted files.
            checkpoint = torch.load(self.best_model_path, weights_only=False)
            self.task_model.load_state_dict(checkpoint["state_dict"])  # type: ignore

        self.is_fitted_ = True
        return self

    def predict(self, X, raw=False, device=None):
        """Predicts distribution parameters for the given input samples.

        Parameters
        ----------
        X : DataFrame or array-like, shape (n_samples, n_features)
            The input samples for which to predict.
        raw : bool, default=False
            If True, return the model outputs as they are. If False, the outputs are first
            passed through the distribution family (``task_model.family(predictions)``).
        device : optional
            Currently unused; accepted for interface compatibility.

        Returns
        -------
        predictions : ndarray
            The predictions as a NumPy array.

        Raises
        ------
        ValueError
            If the model, data module or trainer is not available yet.
        """
        # Ensure model and data module are initialized
        if self.task_model is None or getattr(self, "data_module", None) is None:
            raise ValueError("The model or data module has not been fitted yet.")
        # The trainer only exists after ``fit``; fail with a clear message otherwise.
        trainer = getattr(self, "trainer", None)
        if trainer is None:
            raise ValueError("The model has not been fitted yet. Call fit() before predict().")

        # Preprocess the data using the data module
        self.data_module.assign_predict_dataset(X)

        # Set model to evaluation mode
        self.task_model.eval()

        # Perform inference using PyTorch Lightning's predict function
        predictions_list = trainer.predict(self.task_model, self.data_module)

        # Concatenate predictions from all batches
        predictions = torch.cat(predictions_list, dim=0)  # type: ignore[arg-type]

        # Average over the ensemble dimension when the estimator returns an ensemble
        if getattr(self.estimator, "returns_ensemble", False):
            predictions = predictions.mean(dim=1)

        if raw:
            return predictions.cpu().numpy()
        return self.task_model.family(predictions).cpu().numpy()  # type: ignore

    def evaluate(self, X, y_true, metrics=None, distribution_family=None):
        """Evaluate the model on the given data using specified metrics.

        Parameters
        ----------
        X : array-like or pd.DataFrame of shape (n_samples, n_features)
            The input samples to predict.
        y_true : array-like of shape (n_samples,)
            The true target values against which to evaluate the predictions.
        metrics : dict, optional
            Maps metric names to callables invoked as ``metric(y_true, predictions)``.
            If None, the defaults of ``get_default_metrics`` are used.
        distribution_family : str, optional
            The distribution family used to select the default metrics. If None, it is
            taken from ``task_model.distribution_family`` when present, otherwise from the
            family name recorded by ``fit``/``build_model``, otherwise "normal".

        Returns
        -------
        scores : dict
            A dictionary with metric names as keys and their corresponding scores as values.
        """
        # Infer distribution family from model settings if not provided
        if distribution_family is None:
            distribution_family = getattr(self.task_model, "distribution_family", None)
        if distribution_family is None:
            distribution_family = getattr(self, "family_name", "normal")

        # Setup default metrics if none are provided
        if metrics is None:
            metrics = self.get_default_metrics(distribution_family)

        # Make predictions
        predictions = self.predict(X, raw=False)

        # Compute each metric
        return {metric_name: metric_func(y_true, predictions) for metric_name, metric_func in metrics.items()}

    def get_default_metrics(self, distribution_family):
        """Provides default metrics based on the distribution family.

        Parameters
        ----------
        distribution_family : str
            The distribution family for which to provide default metrics.

        Returns
        -------
        metrics : dict
            A dictionary of default metric functions; empty if the family has no defaults.
        """

        def _normal_crps(y, pred):
            # Vectorised CRPS: column 0 is used as the mean and the square root of
            # column 1 as the standard deviation. Working on a plain array keeps the
            # lookup positional even when ``y`` is a pandas Series.
            targets = np.asarray(y)
            # Broadcast the per-sample parameters over any trailing target axes.
            param_shape = (-1,) + (1,) * (targets.ndim - 1)
            mu = np.asarray(pred[:, 0]).reshape(param_shape)
            sigma = np.sqrt(np.asarray(pred[:, 1])).reshape(param_shape)
            return np.mean(ps.crps_gaussian(targets, mu=mu, sig=sigma))

        default_metrics = {
            "normal": {
                "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]),
                "CRPS": _normal_crps,
            },
            "poisson": {"Poisson Deviance": poisson_deviance},
            "gamma": {"Gamma Deviance": gamma_deviance},
            "beta": {"Brier Score": beta_brier_score},
            "dirichlet": {"Dirichlet Error": dirichlet_error},
            "studentt": {"Student-T Loss": student_t_loss},
            "negativebinom": {"Negative Binomial Deviance": negative_binomial_deviance},
            "inversegamma": {"Inverse Gamma Loss": inverse_gamma_loss},
            "categorical": {"Accuracy": accuracy_score},
        }
        return default_metrics.get(distribution_family, {})

    def score(self, X, y, metric="NLL"):
        """Calculate the score of the model using the specified metric.

        Parameters
        ----------
        X : array-like or pd.DataFrame of shape (n_samples, n_features)
            The input samples to predict.
        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The true target values against which to evaluate the predictions.
        metric : str, default="NLL"
            So far, only negative log-likelihood is supported; the argument is not
            inspected.

        Returns
        -------
        score : float
            The value returned by the family's ``evaluate_nll``.
        """
        predictions = self.predict(X)
        return self.task_model.family.evaluate_nll(y, predictions)  # type: ignore

    def encode(self, X, batch_size=64):
        """Encodes input data using the trained model's ``encode`` method.

        Parameters
        ----------
        X : array-like or DataFrame
            Input data to be encoded.
        batch_size : int, optional, default=64
            Batch size for encoding.

        Returns
        -------
        torch.Tensor
            Encoded representations of the input data, concatenated along dim 0.

        Raises
        ------
        ValueError
            If the model or data module is not fitted.
        """
        # Ensure model and data module are initialized
        if self.task_model is None or getattr(self, "data_module", None) is None:
            raise ValueError("The model or data module has not been fitted yet.")

        encoded_dataset = self.data_module.preprocess_new_data(X)
        data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False)

        # Inference only: switch to eval mode (as ``predict`` does) and skip autograd
        # bookkeeping.
        self.task_model.eval()
        encoded_batches = []
        with torch.no_grad():
            for num_features, cat_features in tqdm(data_loader):
                embeddings = self.task_model.estimator.encode(num_features, cat_features)  # type: ignore[union-attr]
                encoded_batches.append(embeddings)

        # Concatenate all encoded outputs
        return torch.cat(encoded_batches, dim=0)

    def optimize_hparams(
        self,
        X,
        y,
        X_val=None,
        y_val=None,
        time=100,
        max_epochs=200,
        prune_by_epoch=True,
        prune_epoch=5,
        fixed_params=None,
        custom_search_space=None,
        **optimize_kwargs,
    ):
        """Optimizes hyperparameters using Bayesian optimization with optional pruning.

        This method only forwards to ``super().optimize_hparams`` with
        ``regression=False``; the implementation must be supplied by another class
        in the MRO.

        Parameters
        ----------
        X : array-like
            Training data.
        y : array-like
            Training labels.
        X_val, y_val : array-like, optional
            Validation data and labels.
        time : int
            The number of optimization trials to run.
        max_epochs : int
            Maximum number of epochs for training.
        prune_by_epoch : bool
            Whether to prune based on a specific epoch (True) or the best validation loss (False).
        prune_epoch : int
            The specific epoch to prune by when prune_by_epoch is True.
        fixed_params : dict, optional
            Forwarded unchanged to the parent implementation. If None, a fresh copy of
            the default settings below is used.
        custom_search_space : optional
            Forwarded unchanged to the parent implementation.
        **optimize_kwargs : dict
            Additional keyword arguments passed to the fit method.

        Returns
        -------
        best_hparams : list
            Best hyperparameters found during optimization.
        """
        if fixed_params is None:
            # Built per call so the defaults are never shared between invocations.
            fixed_params = {
                "pooling_method": "avg",
                "head_skip_layers": False,
                "head_layer_size_length": 0,
                "cat_encoding": "int",
                "head_skip_layer": False,
                "use_cls": False,
            }

        return super().optimize_hparams(  # type: ignore[attr-defined]
            X,
            y,
            regression=False,
            X_val=X_val,
            y_val=y_val,
            time=time,
            max_epochs=max_epochs,
            prune_by_epoch=prune_by_epoch,
            prune_epoch=prune_epoch,
            fixed_params=fixed_params,
            custom_search_space=custom_search_space,
            **optimize_kwargs,
        )