bitfount_model
Contains PyTorch implementations of the BitfountModel paradigm.
Classes
PyTorchBitfountModel
class PyTorchBitfountModel( batch_size: int = 32, epochs: Optional[int] = None, steps: Optional[int] = None, **kwargs: Any,):
Blueprint for a PyTorch custom model in the Lightning format.
This class must be subclassed in its own module. A Path to the module containing the subclass can then be passed to BitfountModelReference and on to your Algorithm of choice, which will send the model to Bitfount Hub.
To get started, just implement the abstract methods in this class. More advanced users are free to override or overwrite any variables/methods in their subclass.
Take a look at the pytorch-lightning documentation on how to properly create a LightningModule: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
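As a starting point, a minimal subclass might look like the sketch below. The import path, layer sizes, batch format and loss function are assumptions for illustration only, not part of the documented API.

```python
# A minimal sketch of a custom model, assuming a small tabular classifier.
import torch
import torch.nn.functional as F
from torch import nn

from bitfount import PyTorchBitfountModel  # assumed top-level import path


class MyCustomModel(PyTorchBitfountModel):
    def create_model(self) -> nn.Module:
        # Illustrative architecture; replace with your own.
        return nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x):
        return self._model(x)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        # Skip already-seen batches when training in steps (see training_step below).
        if self.skip_training_batch(batch_idx):
            return None
        x, y = batch  # assumes (features, target) batches
        return F.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"validation_loss": F.cross_entropy(self(x), y)}

    # A test_step is also needed for evaluation/prediction; see the
    # test_step method below for a sketch.
```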
Arguments
- batch_size: The batch size to use for training. Defaults to 32.
- epochs: The number of epochs to train for.
- steps: The number of steps to train for.
- **kwargs: Any additional arguments to pass to parent constructors.
Attributes
- batch_size: The batch size to use for training.
- epochs: The number of epochs to train for.
- steps: The number of steps to train for.
- preds: The predictions from the most recent test run.
- targs: The targets from the most recent test run.
- val_stats: Metrics from the validation set during training.
Raises
- ValueError: If both epochs and steps are specified.
Ancestors
- bitfount.backends.pytorch.federated.mixins._PyTorchDistributedModelMixIn
- bitfount.federated.mixins._DistributedModelMixIn
- BitfountModel
- bitfount.models.base_models._BaseModel
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
- abc.ABC
- typing.Generic
- pytorch_lightning.core.lightning.LightningModule
- pytorch_lightning.core.mixins.device_dtype_mixin.DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.saving.ModelIO
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
Methods
def configure_optimizers( self,) ‑> Union[torch.optim.optimizer.Optimizer, Tuple[List[torch.optim.optimizer.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]]]:
Configures the optimizer(s) and scheduler(s) for backpropagation.
Returns
Either the optimizer of your choice or a tuple of optimizers and learning rate schedulers.
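A sketch, assuming AdamW with a step-based learning rate schedule (both choices are illustrative):

```python
import torch

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    # Returning just `optimizer` is also valid; the tuple form attaches a scheduler.
    return [optimizer], [scheduler]
```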
def create_model(self) ‑> torch.nn.modules.module.Module:
Creates and returns the underlying pytorch model.
Returns
Underlying PyTorch model. This is set to self._model.
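For example, a small feed-forward network (the sizes are illustrative):

```python
from torch import nn

def create_model(self) -> nn.Module:
    # The returned module is what ends up on self._model via initialise_model.
    return nn.Sequential(
        nn.Linear(16, 64),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 2),
    )
```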
def deserialize(self, filename: Union[str, os.PathLike]) ‑> None:
Deserialize model.
Arguments
- filename: Path to file containing serialized model.
danger
This should not be used on a model file that has been received across a trust boundary due to the underlying use of pickle by torch.
def forward(self, x: Any) ‑> Any:
Forward method of the model - just like a regular torch.nn.Module class.
Arguments
- x: Input to the model.
Returns
Output of the model.
tip
This will depend on your model but could be as simple as:
return self._model(x)
def initialise_model( self, data: Optional[BaseSource] = None, context: Optional[ModelContext] = None,) ‑> None:
Any initialisation of models/dataloaders to be done here.
Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of federated or local training, serialization or deserialization should be done here.
Arguments
- data: The datasource for model training. Defaults to None.
- context: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.
def predict( self, data: BaseSource, **kwargs: Any,) ‑> numpy.ndarray:
This method runs inference on the test data and returns predictions.
This is done by calling test_step under the hood. Customise this method as you please, but it must return a list of predictions and a list of targets. Note that as this is the prediction function, only the predictions are returned.
Returns
A numpy array containing the prediction values.
Raises
- ValueError: If no data is provided to test with.
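A hypothetical usage sketch, assuming a trained model instance and a CSV-backed datasource (the CSVSource name and file path are assumptions here):

```python
from bitfount import CSVSource  # assumed datasource; any BaseSource works

datasource = CSVSource("test_data.csv")
predictions = model.predict(datasource)  # numpy.ndarray of prediction values
print(predictions.shape)
```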
def serialize(self, filename: Union[str, os.PathLike]) ‑> None:
Serialize model to file with provided filename.
Arguments
- filename: Path to file to save serialized model.
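An illustrative save/load round trip, assuming the hypothetical MyCustomModel subclass from the earlier sketch:

```python
# Save the trained weights, then restore them into a fresh instance.
model.serialize("my_model.pt")

restored = MyCustomModel(batch_size=32, epochs=1)
restored.initialise_model()  # build self._model before loading the weights
restored.deserialize("my_model.pt")
```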
def skip_training_batch(self, batch_idx: int) ‑> bool:
Checks if the current batch from the training set should be skipped.
This is a workaround for the fact that PyTorch Lightning starts the Dataloader iteration from the beginning every time fit is called. This means that if we are training in steps, we are always training on the same batches. So this method needs to be called at the beginning of every training_step to skip to the right batch index.
Arguments
- batch_idx: The index of the batch from training_step.
Returns
True if the batch should be skipped, otherwise False.
def tensor_precision(self) ‑> +T_DTYPE:
Returns tensor dtype used by Pytorch Lightning Trainer.
Returns
PyTorch tensor dtype.
note
Currently only 32-bit training is supported.
def test_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns test dataloader.
def test_step( self, batch: Any, batch_idx: int,) ‑> Union[torch.Tensor, Dict[str, Any], None]:
Performs test step and must set self.preds and self.targs. They will be returned by the evaluate method.
Arguments
- batch: The batch to be evaluated.
- batch_idx: The index of the batch to be evaluated from the test dataloader.
Returns
Any object or value of interest. These will be passed to the test_epoch_end method. If returning None, testing will skip to the next batch.
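A sketch of a test_step for a classifier. The batch format, and the assumption that self.preds and self.targs can be accumulated as lists, are illustrative:

```python
import torch.nn.functional as F

def test_step(self, batch, batch_idx):
    x, y = batch  # assumes (features, target) batches
    preds = F.softmax(self(x), dim=1)
    # Record predictions and targets so evaluate can return them
    # (accumulating them as lists here is an assumption, not the only option).
    self.preds.extend(preds.tolist())
    self.targs.extend(y.tolist())
    return {"predictions": preds, "targets": y}
```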
def train_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns training dataloader.
def trainer_init(self) ‑> pytorch_lightning.trainer.trainer.Trainer:
Initialises the Lightning Trainer for this model.
Documentation for pytorch-lightning trainer can be found here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
Returns
The pytorch lightning trainer.
tip
Override this method to choose your own Trainer arguments.
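For example (the Trainer arguments shown are illustrative; use whatever your installed pytorch-lightning version supports):

```python
import pytorch_lightning as pl

def trainer_init(self) -> pl.Trainer:
    # Keep max_epochs consistent with the epochs the model was constructed with.
    return pl.Trainer(
        max_epochs=self.epochs,
        logger=False,
        deterministic=True,
    )
```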
def training_step( self, batch: Any, batch_idx: int,) ‑> Union[torch.Tensor, Dict[str, Any]]:
Training step.
Arguments
- batch: The batch to be trained on.
- batch_idx: The index of the batch to be trained on from the train dataloader.
Returns
The loss from this batch as a torch.Tensor, or a dictionary which includes the key loss and the loss as a torch.Tensor.
caution
If iterations have been specified in terms of steps, the default behaviour of pytorch lightning is to train on the first n steps of the dataloader every time fit is called. This default behaviour is not desirable and has been dealt with in the built-in Bitfount models but, until this bug gets fixed by the pytorch lightning team, this needs to be implemented by the user for custom models.
tip
Take a look at the skip_training_batch method for one way to deal with this. It can be used as follows:
if self.skip_training_batch(batch_idx):
return None
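Putting that together, a training_step for a classifier might look like this sketch (the batch unpacking and loss function are assumptions):

```python
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    # Skip batches that have already been trained on when training in steps.
    if self.skip_training_batch(batch_idx):
        return None
    x, y = batch  # assumes (features, target) batches
    loss = F.cross_entropy(self(x), y)
    return {"loss": loss}
```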
def val_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns validation dataloader.
def validation_epoch_end(self, outputs: List[Dict[str, Any]]) ‑> None:
Called at the end of the validation epoch with all validation step outputs.
Ensures that the average metrics from a validation epoch are stored. Logs results and also appends to self.val_stats.
Arguments
- outputs: List of outputs from each validation step.
def validation_step( self, batch: Any, batch_idx: int,) ‑> Union[torch.Tensor, Dict[str, Any]]:
Validation step.
Arguments
- batch: The batch to be evaluated.
- batch_idx: The index of the batch to be evaluated from the validation dataloader.
Returns
A dictionary of strings and values that should be averaged at the end of every epoch and logged e.g. {"validation_loss": loss}. These will be passed to the validation_epoch_end method.
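For example (the metric names and loss function are illustrative):

```python
import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch  # assumes (features, target) batches
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean()
    # These values are averaged and logged in validation_epoch_end.
    return {"validation_loss": loss, "validation_accuracy": acc}
```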
Variables
- static train_dl : bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader