Base models and helper classes using PyTorch as the backend.



class BasePyTorchModel(
    model_name: Optional[str] = None,
    model_version: Optional[Union[int, str]] = None,
    stochastic_weight_avg: bool = False,
    logger_config: Optional[LoggerConfig] = None,
    **kwargs: Any,
):

Implements a Neural Network in PyTorch Lightning.


  • model_name: Used for TensorBoard logging. Optional. If not provided, the model name defaults to "default".
  • model_version: Used for TensorBoard logging. Optional. If not specified, the logger inspects the save directory for existing versions and automatically assigns the next available one. If it is a string, it is used as the run-specific subdirectory name; otherwise 'version_${version}' is used.
  • stochastic_weight_avg: Whether to use Stochastic Weight Averaging (SWA). Defaults to False.
  • logger_config: Logger configuration. Optional. Defaults to TensorBoard if not provided.
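The versioning rules for model_name and model_version can be sketched in plain Python. resolve_log_dir is a hypothetical helper, not part of this API, and it assumes runs are stored as <save_dir>/<model_name>/<version>; the real layout is decided by the configured logger.

```python
import os
import re
import tempfile
from typing import Optional, Union


def resolve_log_dir(
    save_dir: str,
    model_name: Optional[str] = None,
    model_version: Optional[Union[int, str]] = None,
) -> str:
    """Hypothetical: return the run directory a TensorBoard-style logger would use."""
    # A missing model name falls back to "default".
    name_dir = os.path.join(save_dir, model_name or "default")
    if model_version is None:
        # Scan for existing "version_<n>" folders and take the next number.
        existing = []
        if os.path.isdir(name_dir):
            for entry in os.listdir(name_dir):
                match = re.fullmatch(r"version_(\d+)", entry)
                if match and os.path.isdir(os.path.join(name_dir, entry)):
                    existing.append(int(match.group(1)))
        model_version = max(existing) + 1 if existing else 0
    # Strings are used verbatim as the run subdirectory; integers get a "version_" prefix.
    if isinstance(model_version, str):
        return os.path.join(name_dir, model_version)
    return os.path.join(name_dir, f"version_{model_version}")


with tempfile.TemporaryDirectory() as root:
    first = resolve_log_dir(root, "my_model")
    os.makedirs(first)  # simulate a completed run
    second = resolve_log_dir(root, "my_model")
    print(os.path.basename(first), os.path.basename(second))  # version_0 version_1
    print(resolve_log_dir(root, None, "baseline") == os.path.join(root, "default", "baseline"))  # True
```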


def configure_optimizers(
    self,
) -> Union[torch.optim.optimizer.Optimizer, Tuple[List[torch.optim.optimizer.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]]]:

Configures the optimizer(s) and scheduler(s) for backpropagation.

Returns Either the optimizer of your choice or a tuple of optimizers and learning rate schedulers.
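Code that consumes configure_optimizers has to accept both return shapes. The sketch below is a hypothetical helper (normalize_optimizer_config is not part of this API), and plain strings stand in for torch optimizer and scheduler objects so it runs without PyTorch.

```python
from typing import Any, List, Tuple


def normalize_optimizer_config(result: Any) -> Tuple[List[Any], List[Any]]:
    """Return (optimizers, schedulers) for either documented return shape."""
    if isinstance(result, tuple):
        # Shape 2: ([optimizer, ...], [scheduler, ...])
        optimizers, schedulers = result
        return list(optimizers), list(schedulers)
    # Shape 1: a single optimizer and no scheduler.
    return [result], []


print(normalize_optimizer_config("adam"))  # (['adam'], [])
print(normalize_optimizer_config((["adam"], ["step_lr"])))  # (['adam'], ['step_lr'])
```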

def deserialize(self, filename: Union[str, os.PathLike]) -> None:

Deserialize model.


  • filename: Path to file containing serialized model.

Do not use this on a model file received across a trust boundary: torch deserializes it with pickle, which can execute arbitrary code while loading.
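The pickle risk is concrete: any object can define __reduce__ so that loading it calls an arbitrary callable. This minimal, harmless sketch uses the standard library pickle module directly (not torch) and evaluates an expression during loading; a malicious file could run any command the same way.

```python
import pickle


class Payload:
    """An object whose unpickling runs a callable chosen by whoever created the file."""

    def __reduce__(self):
        # pickle records "call eval('6 * 7')" instead of the object's state.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
loaded = pickle.loads(blob)  # eval runs here, during loading
print(type(loaded).__name__, loaded)  # int 42
```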

def forward(self, x: Union[ImgFwdTypes, TabFwdTypes]) -> Any:

Performs a forward pass of the underlying model.


  • x: Input to the model.

Returns Output of the model.

def initialise_model(
    self, data: Optional[BaseSource] = None, context: Optional[ModelContext] = None,
) -> None:

Any initialisation of models/dataloaders to be done here.

Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of federated or local training, serialization or deserialization should be done here.


  • data: The datasource for model training. Defaults to None.
  • context: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.

def on_train_batch_start(
    self, batch: _SingleOrMulti[_SingleOrMulti[np.ndarray]], batch_idx: int,
) -> Optional[Literal[-1]]:

Checks if any privacy guarantees have been exceeded and stops training if so.


  • batch: The batch to be trained on.
  • batch_idx: The index of the batch to be trained on from the train dataloader.
  • dataloader_idx: The index of the dataloader from which the batch was taken. This is useful for multi-dataloader training.

Returns -1 if the entire epoch should be skipped, otherwise None.
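The early-stopping contract can be sketched with a toy privacy budget. PrivacyBudget, SketchModel and the epsilon accounting are hypothetical illustrations, not this library's privacy mechanism; the part grounded in PyTorch Lightning is that returning -1 from on_train_batch_start skips the rest of the current epoch.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class PrivacyBudget:
    """Hypothetical tracker of privacy loss (epsilon) spent so far."""

    max_epsilon: float
    spent_epsilon: float = 0.0

    def exceeded(self) -> bool:
        return self.spent_epsilon > self.max_epsilon


class SketchModel:
    """Hypothetical model exposing only the on_train_batch_start hook."""

    def __init__(self, budget: PrivacyBudget) -> None:
        self.budget = budget

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[Literal[-1]]:
        # Lightning treats a return value of -1 as "skip the rest of this epoch".
        if self.budget.exceeded():
            return -1
        return None


model = SketchModel(PrivacyBudget(max_epsilon=3.0))
for idx, cost in enumerate([1.0, 1.0, 1.5, 1.0]):
    if model.on_train_batch_start(batch=None, batch_idx=idx) == -1:
        print(f"budget exceeded, skipping from batch {idx}")  # batch 3
        break
    model.budget.spent_epsilon += cost  # pretend this batch consumed some budget
```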

def predict(self, data: BaseSource, **kwargs: Any) -> numpy.ndarray:

Runs inference on the test data and returns the predictions.

This is done by calling test_step under the hood. You may customise test_step as you please, but it must return a list of predictions and a list of targets. Because this is the prediction function, only the predictions are returned.

Returns A numpy array containing the prediction values.


  • ValueError: Raised if no data is provided to test with.
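The predict flow described above (run a test step over every batch, collect predictions and targets, return only the predictions) can be sketched without PyTorch. run_test_step is a hypothetical stand-in model that thresholds each input at 0.5, and the sketch returns a plain list where the real method returns a numpy array.

```python
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[int]]


def run_test_step(batch: Batch, batch_idx: int) -> Dict[str, List[int]]:
    """Hypothetical test step: a trivial 'model' that thresholds each input at 0.5."""
    inputs, targets = batch
    preds = [1 if x >= 0.5 else 0 for x in inputs]
    return {"predictions": preds, "targets": list(targets)}


def predict(batches: Optional[Iterable[Batch]]) -> List[int]:
    if batches is None:
        raise ValueError("No data provided to test with.")
    preds: List[int] = []
    targs: List[int] = []
    for idx, batch in enumerate(batches):
        out = run_test_step(batch, idx)
        preds.extend(out["predictions"])
        targs.extend(out["targets"])  # collected for evaluation, not returned
    return preds


print(predict([([0.2, 0.7], [0, 1]), ([0.9], [1])]))  # [0, 1, 1]
```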

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Serialize model to file with provided filename.


  • filename: Path to file to save serialized model.

def tensor_precision(self) -> T_DTYPE:

Returns the tensor dtype used by the PyTorch Lightning Trainer.

Returns PyTorch tensor dtype.


Currently only 32-bit training is supported.

def test_dataloader(self) -> Optional[]:

Returns the test dataloader.

def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) -> None:

Aggregates the predictions and targets from the test set.


  • outputs: List of outputs from each test step.

def test_step(
    self,
    batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor],
    batch_idx: int,
) -> Dict[str, Any]:

Make sure to set self.preds and self.targs before returning from this method; they will be returned by the evaluate method.


  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the test dataloader.

Returns A dictionary of predictions and targets. These will be passed to the test_epoch_end method.

def train_dataloader(self) -> Optional[]:

Returns the training dataloader.

def trainer_init(self) -> pytorch_lightning.trainer.trainer.Trainer:

Initialises the Lightning Trainer for this model.

Documentation for pytorch-lightning trainer can be found here:

Returns The PyTorch Lightning trainer.

def training_epoch_end(self, kwargs: Any) -> None:

Extracts the gradient from the weights after training.
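One plausible reading of "gradient" here, common in federated training, is the per-parameter difference between the weights after and before local training, which is what gets shared instead of the raw weights. That interpretation is an assumption; weight_update is a hypothetical helper that works on plain lists instead of tensors.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def weight_update(before: Weights, after: Weights) -> Weights:
    """Hypothetical per-parameter difference: weights after training minus before."""
    if before.keys() != after.keys():
        raise ValueError("Parameter names differ between the two snapshots.")
    return {name: [a - b for a, b in zip(after[name], before[name])] for name in before}


before = {"w": [1.0, 2.0], "b": [0.5]}
after = {"w": [0.75, 2.5], "b": [0.25]}
print(weight_update(before, after))  # {'w': [-0.25, 0.5], 'b': [-0.25]}
```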

def training_step(
    self,
    batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor],
    batch_idx: int,
) -> Optional[torch.Tensor]:

Training step.


  • batch: The batch to be trained on.
  • batch_idx: The index of the batch to be trained on from the train dataloader.

Returns The loss from this batch as a torch.Tensor.

def val_dataloader(self) -> Optional[]:

Returns the validation dataloader.

def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> None:

Called at the end of the validation epoch with all validation step outputs.

Ensures that the average metrics from a validation epoch are stored. Logs the results and also appends them to self.val_stats.


  • outputs: List of outputs from each validation step.

def validation_step(
    self,
    batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor],
    batch_idx: int,
) -> Dict[str, torch.Tensor]:

Validation step.


  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the validation dataloader.

Returns A dictionary of strings and values that should be averaged at the end of every epoch and logged e.g. {"validation_loss": loss}. These will be passed to the validation_epoch_end method.
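The averaging contract between validation_step and validation_epoch_end can be sketched as follows. ValidationSketch is hypothetical, uses plain floats where the real outputs hold torch tensors, and omits the logging step.

```python
from statistics import mean
from typing import Dict, List


class ValidationSketch:
    """Hypothetical stand-in showing only the averaging and bookkeeping."""

    def __init__(self) -> None:
        self.val_stats: List[Dict[str, float]] = []

    def validation_epoch_end(self, outputs: List[Dict[str, float]]) -> None:
        if not outputs:
            return  # nothing was validated this epoch
        # Average every metric key over all validation steps in the epoch.
        averaged = {key: mean(step[key] for step in outputs) for key in outputs[0]}
        self.val_stats.append(averaged)


model = ValidationSketch()
model.validation_epoch_end([{"validation_loss": 0.5}, {"validation_loss": 0.25}])
print(model.val_stats)  # [{'validation_loss': 0.375}]
```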


class BaseTabNetModel(
    virtual_batch_size: int = 64,
    patience: int = 2,
    embedding_sizes: Optional[Union[int, List[int]]] = None,
    num_workers: int = 0,
    inverse_class_weights: bool = True,
    mask_type: str = 'sparsemax',
    decision_prediction_layer_size: int = 8,
    attention_embedding_layer_size: int = 8,
    num_steps: int = 3,
    **kwargs: Any,
):

TabNet Model as described in

This is a wrapper around the implementation from DreamQuark. Documentation can be found here:


  • virtual_batch_size: Virtual batch size used for ghost batch normalization.
  • patience: Number of consecutive epochs without improvement before early stopping.
  • scheduler: Learning rate scheduler. Defaults to None.
  • scheduler_params: Learning rate scheduler params. Defaults to None.
  • embedding_sizes: Embedding sizes. Defaults to None.
  • num_workers: Number of workers for torch DataLoader. Defaults to 0.
  • inverse_class_weights: Inverse class weights (only for classification problems). Defaults to True.
  • mask_type: Mask type. Defaults to "sparsemax".
  • decision_prediction_layer_size: Final feedforward layer size. Defaults to 8.
  • attention_embedding_layer_size: Attention embedding layer size. Defaults to 8.
  • num_steps: Number of decision steps in the architecture. Defaults to 3.
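virtual_batch_size controls ghost batch normalization: the batch is split into virtual batches and each is normalised with its own statistics. This one-dimensional sketch omits the learned scale, shift and running statistics a real normalisation layer keeps; ghost_batch_norm is a hypothetical name.

```python
import math
from statistics import fmean, pvariance
from typing import List


def ghost_batch_norm(batch: List[float], virtual_batch_size: int, eps: float = 1e-5) -> List[float]:
    """Normalise a 1-D batch chunk by chunk (no learned scale or shift)."""
    if virtual_batch_size <= 0:
        raise ValueError("virtual_batch_size must be positive")
    out: List[float] = []
    for start in range(0, len(batch), virtual_batch_size):
        chunk = batch[start:start + virtual_batch_size]
        mu = fmean(chunk)
        var = pvariance(chunk, mu)  # population variance of this virtual batch only
        out.extend((x - mu) / math.sqrt(var + eps) for x in chunk)
    return out


# Two virtual batches on very different scales are each normalised to roughly [-1, 1].
print([round(v, 3) for v in ghost_batch_norm([1.0, 3.0, 10.0, 30.0], virtual_batch_size=2)])
```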


  • ValueError: If virtual batch size > batch size.
  • ValueError: If the model_structure does not match the TabNet model.
  • ValueError: If training is specified in steps rather than epochs. Training steps are not supported.
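Two of the ValueErrors listed above can be mirrored as pre-flight checks. validate_tabnet_config is a hypothetical helper, not part of this API, and it does not attempt the model_structure check.

```python
from typing import Optional


def validate_tabnet_config(
    batch_size: int,
    virtual_batch_size: int,
    epochs: Optional[int] = None,
    steps: Optional[int] = None,
) -> None:
    """Hypothetical checks for the virtual batch size and the epochs-only rule."""
    if virtual_batch_size > batch_size:
        raise ValueError(
            f"virtual_batch_size ({virtual_batch_size}) must not exceed batch_size ({batch_size})"
        )
    if steps is not None:
        # epochs is accepted as-is; only step-based training is rejected.
        raise ValueError("Training must be specified in epochs; steps are not supported.")


validate_tabnet_config(batch_size=256, virtual_batch_size=64, epochs=10)  # passes silently
```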


def deserialize(self, filename: Union[str, os.PathLike]) -> None:

Deserialize model.


  • filename: Path to file containing serialized model.

Do not use this on a model file received across a trust boundary: torch deserializes it with pickle, which can execute arbitrary code while loading.

def get_param_states(self) -> Dict[str, bitfount.types._TensorLike]:

See base class.

def initialise_model(
    self, data: Optional[BaseSource] = None, context: Optional[ModelContext] = None,
) -> None:

Any initialisation of models/dataloaders to be done here.

Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of federated or local training, serialization or deserialization should be done here.


  • data: The datasource for model training. Defaults to None.
  • context: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Serialize model to file with provided filename.


  • filename: Path to file to save serialized model.


class PyTorchClassifierMixIn(
    multilabel: bool = False,
    param_clipping: Optional[Dict[str, int]] = None,
    **kwargs: Any,
):

MixIn for PyTorch classification problems.

PyTorch classification models must have this class in their inheritance hierarchy.
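The multilabel flag typically switches between two decision rules: independent per-label sigmoids (several labels may be on) versus a softmax over mutually exclusive classes (exactly one is on). That this mixin follows those conventions is an assumption; predict_labels and the 0.5 threshold are hypothetical.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(logits: List[float], multilabel: bool, threshold: float = 0.5) -> List[int]:
    if multilabel:
        # Each label is an independent yes/no decision.
        return [int(sigmoid(z) >= threshold) for z in logits]
    # Mutually exclusive classes: one-hot of the most probable class.
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return [int(i == best) for i in range(len(probs))]


print(predict_labels([2.0, -1.0, 0.5], multilabel=True))   # [1, 0, 1]
print(predict_labels([2.0, -1.0, 0.5], multilabel=False))  # [1, 0, 0]
```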


