federated_averaging
Federated Averaging protocol.
Classes
FederatedAveraging
class FederatedAveraging(*, algorithm: _FederatedAveragingCompatibleAlgoFactory, aggregator: Optional[_BaseAggregatorFactory] = None, steps_between_parameter_updates: Optional[int] = None, epochs_between_parameter_updates: Optional[int] = None, auto_eval: bool = True, secure_aggregation: bool = False, **kwargs: Any):
Original Federated Averaging algorithm by McMahan et al. (2017).
This protocol performs a predetermined number of epochs or steps of training on each remote Pod before sending the updated model parameters to the Modeller. These parameters are then averaged and sent back to the Pods, and this cycle is repeated for as many federated iterations as the Modeller specifies.
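The sketch below is a conceptual illustration of a single federated-averaging round; it is not the bitfount implementation. Each Pod trains locally from the current global parameters, and the results are combined with a data-size-weighted average. The `pod.train()` interface and the use of NumPy arrays are assumptions made purely for illustration.

```python
# Conceptual sketch of one Federated Averaging round (illustration only,
# not the bitfount implementation). Each "pod" is assumed to expose a
# train() method that runs the configured local epochs/steps and returns
# (updated_parameters, number_of_local_samples).
import numpy as np


def fedavg_round(global_params: np.ndarray, pods) -> np.ndarray:
    updates, sample_counts = [], []
    for pod in pods:
        # Local training starts from a copy of the current global parameters.
        local_params, n_samples = pod.train(global_params.copy())
        updates.append(local_params)
        sample_counts.append(n_samples)
    # Average the Pods' parameters, weighted by how much data each Pod holds.
    return np.average(np.stack(updates), axis=0, weights=sample_counts)
```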
Arguments
algorithm
: The algorithm to use for training. This must be compatible with the FederatedAveraging protocol.
aggregator
: The aggregator to use for updating the model parameters across all Pods participating in the task. This argument takes priority over the secure_aggregation argument.
steps_between_parameter_updates
: The number of steps between parameter updates, i.e. the number of rounds of local training before parameters are updated. If epochs_between_parameter_updates is provided, steps_between_parameter_updates cannot be provided. Defaults to None.
epochs_between_parameter_updates
: The number of epochs between parameter updates, i.e. the number of rounds of local training before parameters are updated. If steps_between_parameter_updates is provided, epochs_between_parameter_updates cannot be provided. Defaults to None.
auto_eval
: Whether to automatically evaluate the model on the validation dataset. Defaults to True.
secure_aggregation
: Whether to use secure aggregation. This argument is overridden by the aggregator argument.
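For orientation, here is a hedged construction sketch. The keyword arguments mirror the signature above; FederatedModelTraining and my_model are assumptions standing in for a compatible algorithm factory and a distributed model, and import paths may differ between bitfount versions.

```python
# Sketch only: FederatedModelTraining and my_model are assumed stand-ins for
# a FederatedAveraging-compatible algorithm factory and a distributed model.
from bitfount import FederatedAveraging, FederatedModelTraining

algorithm = FederatedModelTraining(model=my_model)

# Provide either epochs_between_parameter_updates or
# steps_between_parameter_updates, never both.
protocol = FederatedAveraging(
    algorithm=algorithm,
    epochs_between_parameter_updates=1,  # average after each local epoch
    auto_eval=True,                      # evaluate on validation data each round
)
```

The resulting protocol object is what the Modeller runs against the participating Pods; see the bitfount tutorials for the exact run call in your version.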
Attributes
name
: The name of the protocol.
algorithm
: The algorithm to use for training.
aggregator
: The aggregator to use for updating the model parameters.
steps_between_parameter_updates
: The number of steps between parameter updates.
epochs_between_parameter_updates
: The number of epochs between parameter updates.
auto_eval
: Whether to automatically evaluate the model on the validation dataset.
Raises
TypeError
: If the algorithm is not compatible with the protocol.
tip
For more information, take a look at the seminal paper: https://arxiv.org/abs/1602.05629
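The aggregation step described in that paper is the data-weighted average of the Pods' locally trained parameters:

w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k

where w_{t+1}^{k} are the parameters produced by Pod k after its local epochs or steps and n_k is the number of samples held by Pod k.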
Ancestors
- bitfount.federated.protocols.base._BaseProtocolFactory
- abc.ABC
- bitfount.federated.roles._RolesMixIn
- bitfount.types._BaseSerializableObjectMixIn
Methods
def dump(self) -> SerializedProtocol:
Returns the JSON-serializable representation of the protocol.
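A minimal usage sketch, reusing the protocol object constructed above; per the description, the returned mapping can be serialized directly.

```python
import json

# dump() yields a JSON-serializable mapping describing the protocol,
# so it can be persisted or sent over the wire.
serialized = protocol.dump()
print(json.dumps(serialized, indent=2))
```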
def modeller(self, mailbox: _ModellerMailbox, early_stopping: Optional[FederatedEarlyStopping] = None, **kwargs: Any) -> bitfount.federated.protocols.model_protocols.federated_averaging._ModellerSide:
Returns the modeller side of the FederatedAveraging protocol.
def worker(self, mailbox: _WorkerMailbox, hub: BitfountHub, **kwargs: Any) -> bitfount.federated.protocols.model_protocols.federated_averaging._WorkerSide:
Returns the worker side of the FederatedAveraging protocol.
Raises
TypeError
: If the mailbox is not compatible with the aggregator.
Variables
- static algorithm : bitfount.federated.protocols.model_protocols.federated_averaging._FederatedAveragingCompatibleAlgoFactory
- static fields_dict : ClassVar[Dict[str, marshmallow.fields.Field]]
- static nested_fields : ClassVar[Dict[str, Mapping[str, Any]]]