base_models
Base models and helper classes using PyTorch as the backend.
Classes
BasePyTorchModel
class BasePyTorchModel( model_name: Optional[str] = None, model_version: Optional[Union[int, str]] = None, stochastic_weight_avg: bool = False, logger_config: Optional[LoggerConfig] = None, **kwargs: Any,):
Implements a Neural Network in PyTorch Lightning.
Arguments
model_name
: Used for tensorboard logging. If not provided, the model name is left blank and defaults to "default". Optional.
model_version
: Used for tensorboard logging. If no version is specified, the logger inspects the save directory for existing versions and automatically assigns the next available one. If it is a string, it is used as the run-specific subdirectory name; otherwise 'version_${version}' is used. Optional.
stochastic_weight_avg
: Whether to use Stochastic Weight Averaging (SWA).
logger_config
: Logger configuration. Optional. Will default to Tensorboard if not provided.
Attributes
epochs
: Number of epochs to train for when calling fit.
model_name
: Name of the model.
model_version
: Version of the model.
steps
: Number of steps to train for when calling fit.
stochastic_weight_avg
: Whether to use Stochastic Weight Averaging (SWA).
Ancestors
- bitfount.backends.pytorch.federated.mixins._PyTorchDistributedModelMixIn
- bitfount.federated.mixins._DistributedModelMixIn
- bitfount.backends.pytorch.models.base_models._PyTorchNeuralNetworkMixIn
- NeuralNetworkMixIn
- bitfount.federated.privacy.differential._DifferentiallyPrivate
- bitfount.models.base_models._BaseModel
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
- abc.ABC
- typing.Generic
- pytorch_lightning.core.lightning.LightningModule
- pytorch_lightning.core.mixins.device_dtype_mixin.DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.saving.ModelIO
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
Methods
def configure_optimizers( self,) ‑> Union[torch.optim.optimizer.Optimizer, Tuple[List[torch.optim.optimizer.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]]]:
Configures the optimizer(s) and scheduler(s) for backpropagation.
Returns Either the optimizer of your choice or a tuple of optimizers and learning rate schedulers.
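A concrete subclass could override this hook in the usual PyTorch Lightning way. A minimal sketch, assuming the underlying network is stored on self._model (as described under initialise_model); the optimizer, scheduler and learning rate choices are purely illustrative:

```python
import torch

def configure_optimizers(self):
    # Illustrative optimizer; the actual choice depends on the subclass.
    optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-3)

    # Optionally pair it with a scheduler, matching the documented return type.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return [optimizer], [scheduler]
```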
def deserialize(self, filename: Union[str, os.PathLike]) ‑> None:
Deserialize model.
Arguments
filename
: Path to file containing serialized model.
danger
This should not be used on a model file that has been received across a trust boundary due to the underlying use of pickle by torch.
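A round-trip sketch together with serialize (documented below). It assumes model and new_model are instances of a concrete subclass and that initialise_model has been called before deserializing, per the initialise_model docstring; only load files from a trusted source:

```python
from pathlib import Path

model_file = Path("trained_model.pt")

# Save a trained model to disk...
model.serialize(model_file)

# ...and load the weights back into a freshly initialised instance.
# Only do this with files you trust: torch's loading relies on pickle.
new_model.initialise_model()
new_model.deserialize(model_file)
```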
def forward(self, x: Union[ImgFwdTypes, TabFwdTypes]) ‑> Any:
Performs a forward pass of the underlying model.
Arguments
x
: Input to the model.
Returns Output of the model.
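In a concrete subclass this is the standard nn.Module forward pass. A minimal sketch for a tabular input; the layer attributes are illustrative and not part of this class, and image or mixed inputs (ImgFwdTypes) would be handled analogously:

```python
import torch

def forward(self, x):
    # x is assumed here to be a batch of tabular features.
    hidden = torch.relu(self.hidden_layer(x))
    return self.output_layer(hidden)
```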
def initialise_model( self, data: Optional[BaseSource] = None, context: Optional[ModelContext] = None,) ‑> None:
Any initialisation of models/dataloaders to be done here.
Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of federated or local training, serialization or deserialization should be done here.
Arguments
data
: The datasource for model training. Defaults to None.
context
: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.
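A usage sketch, assuming datasource is a concrete BaseSource; the ModelContext member shown is an assumption for illustration:

```python
# Local (non-federated) initialisation: builds the dataloaders and the
# underlying network before training, serialization or deserialization.
model.initialise_model(data=datasource)

# In a federated setting, the context distinguishes modeller from worker.
# ModelContext.WORKER is assumed here for illustration.
model.initialise_model(data=datasource, context=ModelContext.WORKER)
```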
def on_train_batch_start( self, batch: _SingleOrMulti[_SingleOrMulti[np.ndarray]], batch_idx: int,) ‑> Optional[Literal[-1]]:
Checks if any privacy guarantees have been exceeded and stops training if so.
Arguments
batch
: The batch to be trained on.
batch_idx
: The index of the batch to be trained on from the train dataloader.
Returns -1 if the entire epoch should be skipped, otherwise None.
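Returning -1 follows the PyTorch Lightning convention for skipping the remainder of the epoch. A sketch of what an override adding an extra check might look like; the _dp_budget_exceeded helper is hypothetical, standing in for the differential-privacy machinery this class inherits:

```python
def on_train_batch_start(self, batch, batch_idx):
    # Hypothetical check: stop training once the privacy budget is spent.
    if self._dp_budget_exceeded():
        return -1  # skip the rest of the epoch
    return None
```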
def predict(self, data: BaseSource, **kwargs: Any) ‑> numpy.ndarray:
Runs inference on the test data and returns predictions.
This is done by calling test_step under the hood. Customise this method as you please, but it must return a list of predictions and a list of targets. Note that as this is the prediction function, only the predictions are returned.
Returns A numpy array containing the prediction values.
Raises
ValueError
: If no data is provided to test with.
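A usage sketch, assuming datasource is a concrete BaseSource with test data defined:

```python
import numpy as np

predictions: np.ndarray = model.predict(data=datasource)
print(predictions.shape)
```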
def serialize(self, filename: Union[str, os.PathLike]) ‑> None:
Serialize model to file with provided filename.
Arguments
filename
: Path to file to save serialized model.
def tensor_precision(self) ‑> +T_DTYPE:
Returns tensor dtype used by the PyTorch Lightning Trainer.
Returns PyTorch tensor dtype.
note
Currently only 32-bit training is supported.
def test_dataloader(self) ‑> Optional[bitfount.data.dataloaders._BitfountDataLoader]:
Returns test dataloader.
def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) ‑> None:
Aggregates the predictions and targets from the test set.
Arguments
outputs
: List of outputs from each test step.
def test_step( self, batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor], batch_idx: int,) ‑> Dict[str, Any]:
Make sure to set self.preds and self.targs before returning in this method. They will be returned by the evaluate method.
Arguments
batch
: The batch to be evaluated.
batch_idx
: The index of the batch to be evaluated from the test dataloader.
Returns
A dictionary of predictions and targets. These will be passed to the test_epoch_end method.
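A sketch of an override that honours the contract above; the dictionary keys and the way predictions are computed are illustrative assumptions:

```python
def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)

    # Stash predictions and targets so they can be returned by evaluate().
    self.preds = y_hat
    self.targs = y

    # Passed on to test_epoch_end for aggregation.
    return {"predictions": y_hat, "targets": y}
```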
def train_dataloader(self) ‑> Optional[bitfount.data.dataloaders._BitfountDataLoader]:
Returns training dataloader.
def trainer_init(self) ‑> pytorch_lightning.trainer.trainer.Trainer:
Initialises the Lightning Trainer for this model.
Documentation for pytorch-lightning trainer can be found here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
Returns The pytorch lightning trainer.
def training_epoch_end(self, kwargs: Any) ‑> None:
Extract gradient from weights after training.
def training_step( self, batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor], batch_idx: int,) ‑> Optional[torch.Tensor]:
Training step.
Arguments
batch
: The batch to be trained on.
batch_idx
: The index of the batch to be trained on from the train dataloader.
Returns
The loss from this batch as a torch.Tensor.
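A minimal sketch of a training step consistent with the signature above; the loss function is illustrative only:

```python
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    self.log("train_loss", loss)
    return loss  # returned to Lightning for backpropagation
```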
def val_dataloader(self) ‑> Optional[bitfount.data.dataloaders._BitfountDataLoader]:
Returns validation dataloader.
def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) ‑> None:
Called at the end of the validation epoch with all validation step outputs.
Ensures that the average metrics from a validation epoch are stored. Logs results and also appends to self.val_stats.
Arguments
outputs
: List of outputs from each validation step.
def validation_step( self, batch: Tuple[Union[TabxorImgBatch, ImgAndTabBatch], torch.Tensor], batch_idx: int,) ‑> Dict[str, torch.Tensor]:
Validation step.
Arguments
batch
: The batch to be evaluated.
batch_idx
: The index of the batch to be evaluated from the validation dataloader.
Returns
A dictionary of strings and values that should be averaged at the end of every epoch and logged, e.g. {"validation_loss": loss}. These will be passed to the validation_epoch_end method.
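A sketch mirroring the documented return contract; the loss and accuracy computations are illustrative:

```python
import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    accuracy = (y_hat.argmax(dim=1) == y).float().mean()

    # Averaged across the epoch and logged by validation_epoch_end.
    return {"validation_loss": loss, "validation_accuracy": accuracy}
```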
Variables
- static fields_dict : ClassVar[T_FIELDS_DICT]
- static train_dl : _BasePyTorchBitfountDataLoader
Inherited members
BaseTabNetModel
class BaseTabNetModel( virtual_batch_size: int = 64, patience: int = 2, embedding_sizes: Optional[Union[int, List[int]]] = None, num_workers: int = 0, inverse_class_weights: bool = True, mask_type: str = 'sparsemax', decision_prediction_layer_size: int = 8, attention_embedding_layer_size: int = 8, num_steps: int = 3, **kwargs: Any,):
TabNet Model as described in https://arxiv.org/abs/1908.07442.
This is a wrapper around the implementation from DreamQuark. Documentation can be found here: https://dreamquark-ai.github.io/tabnet.
Arguments
virtual_batch_size
: Virtual batch size used for ghost batch normalization.
patience
: Number of epochs before early stopping.
scheduler
: Learning rate scheduler. Defaults to None.
scheduler_params
: Learning rate scheduler params. Defaults to None.
embedding_sizes
: Embedding sizes. Defaults to None.
num_workers
: Number of workers for the torch DataLoader. Defaults to 0.
inverse_class_weights
: Inverse class weights (only for classification problems). Defaults to True.
mask_type
: Mask type. Defaults to "sparsemax".
decision_prediction_layer_size
: Final feedforward layer size. Defaults to 8.
attention_embedding_layer_size
: Attention embedding layer size. Defaults to 8.
num_steps
: Number of steps. Defaults to 3.
Raises
ValueError
: If virtual batch size > batch size.
ValueError
: If the model_structure does not match the TabNet model.
ValueError
: If training is specified in steps rather than epochs. Training steps are not supported.
Ancestors
- bitfount.backends.pytorch.federated.mixins._PyTorchDistributedModelMixIn
- bitfount.federated.mixins._DistributedModelMixIn
- bitfount.backends.pytorch.models.base_models._PyTorchNeuralNetworkMixIn
- NeuralNetworkMixIn
- bitfount.models.base_models._BaseModel
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
- abc.ABC
- typing.Generic
Subclasses
Methods
def deserialize(self, filename: Union[str, os.PathLike]) ‑> None:
Deserialize model.
Arguments
filename
: Path to file containing serialized model.
danger
This should not be used on a model file that has been received across a trust boundary due to the underlying use of pickle by torch.
def get_param_states(self) ‑> Dict[str, bitfount.types._TensorLike]:
See base class.
def initialise_model( self, data: Optional[BaseSource] = None, context: Optional[ModelContext] = None,) ‑> None:
Any initialisation of models/dataloaders to be done here.
Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of federated or local training, serialization or deserialization should be done here.
Arguments
data
: The datasource for model training. Defaults to None.
context
: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.
def serialize(self, filename: Union[str, os.PathLike]) ‑> None:
Serialize model to file with provided filename.
Arguments
filename
: Path to file to save serialized model.
Variables
- static fields_dict : ClassVar[T_FIELDS_DICT]
- static train_dl : _BasePyTorchBitfountDataLoader
Inherited members
PyTorchClassifierMixIn
class PyTorchClassifierMixIn( multilabel: bool = False, param_clipping: Optional[Dict[str, int]] = None, **kwargs: Any,):
MixIn for PyTorch classification problems.
PyTorch classification models must have this class in their inheritance hierarchy.
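A sketch of how the mixin fits into a model's inheritance hierarchy, following the requirement above; the subclass name and body are hypothetical, and the import path is inferred from the module layout shown on this page:

```python
from bitfount.backends.pytorch.models.base_models import (
    BasePyTorchModel,
    PyTorchClassifierMixIn,
)

# The mixin must appear in the MRO of any PyTorch classification model.
# The required abstract members of the base classes are omitted here.
class MyPyTorchClassifier(PyTorchClassifierMixIn, BasePyTorchModel):
    ...
```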
Ancestors
- ClassifierMixIn
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
Subclasses
Variables
- static fields_dict : ClassVar[Dict[str, marshmallow.fields.Field]]
- static nested_fields : ClassVar[Dict[str, Mapping[str, Any]]]