-
Notifications
You must be signed in to change notification settings - Fork 5
feat(distributed_algorithms): Add FedProx algorithm #284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -364,6 +364,118 @@ def _epoch_minibatch_update( | |
| return local_x | ||
|
|
||
|
|
||
| @tags("federated") | ||
| @dataclass(eq=False) | ||
| class FedProx(FedAlgorithm): | ||
| r""" | ||
| Federated Proximal (FedProx) with local SGD epochs. | ||
|
|
||
| FedProx follows the same communication and aggregation pattern as | ||
| :class:`FedAvg <decent_bench.distributed_algorithms.FedAvg>`, but each client solves a proximalized local | ||
| subproblem around the round's server model: | ||
|
|
||
| .. math:: | ||
| h_i(\mathbf{w}; \mathbf{w}^k) = f_i(\mathbf{w}) + \frac{\mu}{2} \|\mathbf{w} - \mathbf{w}^k\|^2 | ||
|
|
||
| .. math:: | ||
| \mathbf{x}_{i, k}^{(s+1)} = \mathbf{x}_{i, k}^{(s)} - \eta \left( | ||
| \nabla f_i(\mathbf{x}_{i, k}^{(s)}) + \mu (\mathbf{x}_{i, k}^{(s)} - \mathbf{w}^k) \right) | ||
|
|
||
| .. math:: | ||
| \mathbf{w}^{k+1} = \frac{1}{|S_k|} \sum_{i \in S_k} \mathbf{x}_{i, k}^{(E)} | ||
|
|
||
| where :math:`\mathbf{w}^k` is the server model broadcast at the start of round :math:`k` and held fixed | ||
| throughout each selected client's local epochs, :math:`i` indexes the clients, :math:`s` indexes the local | ||
| SGD steps (with :math:`E` steps per round), :math:`\mu \geq 0` is the proximal coefficient, :math:`\eta` is | ||
| the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0`` recovers FedAvg | ||
| exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are strongly encouraged | ||
| to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would remove the sentence
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would also explicitly link FedAvg in the sentence |
||
| As in FedAvg, aggregation uses client weights, defaulting to data-size weights when ``client_weights`` is not | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would not reference FedAvg in the docstring (except to say that it's the same as FedProx when mu=0). I think the docstring should self-contained |
||
| provided, and client selection defaults to uniform sampling with fraction 1.0. For | ||
| :class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size | ||
| :attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs, | ||
| local updates use full-batch gradients. | ||
| """ | ||
|
|
||
| iterations: int = 100 | ||
| step_size: float = 0.001 | ||
| num_local_epochs: int = 1 | ||
| mu: float = 0.01 # Sensible benchmark default; tune over {0.001, 0.01, 0.1, 0.5, 1} for new settings. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would remove the comment given the discussion above |
||
| client_weights: ClientWeights | None = None | ||
| selection_scheme: ClientSelectionScheme | None = field( | ||
| default_factory=lambda: UniformClientSelection(client_fraction=1.0) | ||
| ) | ||
| x0: InitialStates = None | ||
| name: str = "FedProx" | ||
|
|
||
| def __post_init__(self) -> None: | ||
| """ | ||
| Validate hyperparameters. | ||
|
|
||
| Raises: | ||
| ValueError: if hyperparameters are invalid. | ||
|
|
||
| """ | ||
| if self.step_size <= 0: | ||
| raise ValueError("`step_size` must be positive") | ||
| if self.num_local_epochs <= 0: | ||
| raise ValueError("`num_local_epochs` must be positive") | ||
| if self.mu < 0: | ||
| raise ValueError("`mu` must be non-negative") | ||
|
|
||
| def initialize(self, network: FedNetwork) -> None: # noqa: D102 | ||
| self.x0 = alg_helpers.initial_states(self.x0, network) | ||
| network.server().initialize(x=self.x0[network.server()]) | ||
| for client in network.clients(): | ||
| client.initialize(x=self.x0[client]) | ||
|
|
||
| def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102 | ||
| selected_clients = self._selected_clients_for_round(network, iteration) | ||
| if not selected_clients: | ||
| return | ||
|
|
||
| self._sync_server_to_clients(network, selected_clients) | ||
| self._run_local_updates(network, selected_clients) | ||
| self.aggregate(network, selected_clients) | ||
|
|
||
| def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: | ||
| network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) | ||
|
|
||
| def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: | ||
| for client in selected_clients: | ||
| client.x = self._compute_local_update(client, network.server()) | ||
| network.send(sender=client, receiver=network.server(), msg=client.x) | ||
|
|
||
| def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array": | ||
| reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You dont need to copy the reference_x since you're not modifying it |
||
| local_x = iop.copy(reference_x) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If I'm not mistaken, this line should not be here: with `local_x = reference_x` (no copy), the in-place updates keep the two names aliased to the same array, so the regularization term `mu * (local_x - reference_x)` is always zero and we effectively always run FedAvg (I tested, and that's indeed the case). Removing this line should be enough to fix it.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've looked at the paper, and it's not very clear what they do to initialize |
||
| if isinstance(client.cost, EmpiricalRiskCost): | ||
| cost = client.cost | ||
| n_samples = cost.n_samples | ||
| return self._epoch_minibatch_update(cost, local_x, reference_x, cost.batch_size, n_samples) | ||
|
|
||
| for _ in range(self.num_local_epochs): | ||
| grad = client.cost.gradient(local_x) + self.mu * (local_x - reference_x) | ||
| local_x -= self.step_size * grad | ||
| return local_x | ||
|
|
||
| def _epoch_minibatch_update( | ||
| self, | ||
| cost: EmpiricalRiskCost, | ||
| local_x: "Array", | ||
| reference_x: "Array", | ||
| per_client_batch: int, | ||
| n_samples: int, | ||
| ) -> "Array": | ||
| for _ in range(self.num_local_epochs): | ||
| indices = list(range(n_samples)) | ||
| random.shuffle(indices) | ||
| for start in range(0, n_samples, per_client_batch): | ||
| batch_indices = indices[start : start + per_client_batch] | ||
| grad = cost.gradient(local_x, indices=batch_indices) + self.mu * (local_x - reference_x) | ||
| local_x -= self.step_size * grad | ||
|
Comment on lines
+472
to
+475
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want you could pass |
||
| return local_x | ||
|
|
||
|
|
||
| @tags("peer-to-peer", "gradient-based") | ||
| @dataclass(eq=False) | ||
| class DGD(P2PAlgorithm): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,6 +108,10 @@ Federated | |
| :tag: federated | ||
| :module: decent_bench.distributed_algorithms | ||
|
|
||
| FedProx extends FedAvg with a proximal coefficient ``mu``. Setting ``mu=0`` reduces | ||
| FedProx to FedAvg. In practice, ``mu`` should typically be tuned for each problem; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also here I would remove the discussion about tuning |
||
| ``[0.001, 0.01, 0.1, 0.5, 1.0]`` is a reasonable starting grid. | ||
|
|
||
|
|
||
| Available costs | ||
| --------------- | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to write the local update in terms of the gradient of h_k, since you have introduced the regularized cost; otherwise that notation goes unused.