feat(distributed_algorithms): Add FedProx algorithm #284
adrianardv wants to merge 2 commits into team-decent:main from
Conversation
Add a FedProx federated algorithm that mirrors the existing FedAvg round flow while applying a proximal term during local client updates. Add lightweight instantiation coverage for FedProx.
Document that mu=0 recovers FedAvg and that mu=0.01 is a sensible benchmark default. Encourage tuning mu and recommend the grid [0.001, 0.01, 0.1, 0.5, 1.0]. Add the same guidance to the user guide's federated algorithms section.
nicola-bastianello
left a comment
thanks! just a few minor comments, and one on the algorithm implementation
:class:`FedAvg <decent_bench.distributed_algorithms.FedAvg>`, but each client solves a proximalized local
subproblem around the round's server model:

.. math::
it might be better to write the local update in terms of h_k's gradient, since you have introduced the regularized cost. otherwise the notation is unused
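For reference, with :math:`h_k(x) = f_k(x) + \frac{\mu}{2}\|x - x^t\|^2` denoting the regularized local cost (my notation for the server model :math:`x^t`, adapt to whatever the docs use), the local step could then read:

```rst
.. math::

   x_k^{(j+1)} = x_k^{(j)} - \eta \nabla h_k\bigl(x_k^{(j)}\bigr)
               = x_k^{(j)} - \eta \Bigl( \nabla f_k\bigl(x_k^{(j)}\bigr) + \mu \bigl(x_k^{(j)} - x^t\bigr) \Bigr)
```

which is term-by-term the gradient computed in the implementation below.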
throughout each selected client's local epochs, :math:`\mu \geq 0` is the proximal coefficient,
:math:`\eta` is the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0``
recovers FedAvg exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are
strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``.
I would remove the sentence "The default ``mu=0.01`` is a sensible benchmark starting point, but users are strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``." This is because: that grid is potentially not the best starting point for some problems (as the best value of mu depends on the problem features), and we never mention tuning the hyperparameters, as it is implicit, so I don't think it's needed to mention it here
I would also explicitly link FedAvg in the sentence Setting ``mu=0.0`` recovers FedAvg exactly.
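Concretely, the sentence could reuse the cross-reference already used in the intro (path assumed to match the module above):

```rst
Setting ``mu=0.0`` recovers :class:`FedAvg <decent_bench.distributed_algorithms.FedAvg>` exactly.
```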
:math:`\eta` is the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0``
recovers FedAvg exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are
strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``.
As in FedAvg, aggregation uses client weights, defaulting to data-size weights when ``client_weights`` is not
I would not reference FedAvg in the docstring (except to say that it's the same as FedProx when mu=0). I think the docstring should be self-contained
iterations: int = 100
step_size: float = 0.001
num_local_epochs: int = 1
mu: float = 0.01  # Sensible benchmark default; tune over {0.001, 0.01, 0.1, 0.5, 1} for new settings.
I would remove the comment given the discussion above
:module: decent_bench.distributed_algorithms

FedProx extends FedAvg with a proximal coefficient ``mu``. Setting ``mu=0`` reduces
FedProx to FedAvg. In practice, ``mu`` should typically be tuned for each problem;
also here I would remove the discussion about tuning
def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
    reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
    local_x = iop.copy(reference_x)
if I'm not mistaken, this should not be here, because `local_x = reference_x` effectively removes the regularization term. the result is that we always run FedAvg (I tested and that's indeed the case)
removing this line should be enough to fix it
I've looked at the paper, and it's not very clear what they do to initialize `local_x`. I think maybe it should be left to the previous value? I tried the current implementation also with `num_local_epochs > 1` and I get the same results for both algorithms. I don't quite understand
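To add a data point: a quick toy check in plain numpy (hypothetical stand-in helper, not the PR's API) suggests that starting `local_x` at the anchor should only zero out the proximal term on the *first* step; with more than one local step the `mu > 0` and `mu = 0` runs should diverge, so identical results may point at something else than the initialization alone:

```python
import numpy as np

def local_update(x0, reference_x, grad_f, mu, step_size, num_steps):
    # Hypothetical stand-in for _compute_local_update's inner loop:
    # gradient steps on h(x) = f(x) + (mu / 2) * ||x - reference_x||^2.
    x = np.array(x0, dtype=float)
    for _ in range(num_steps):
        x -= step_size * (grad_f(x) + mu * (x - reference_x))
    return x

# Toy quadratic cost f(x) = 0.5 * ||x - b||^2, so grad_f(x) = x - b.
b = np.array([2.0, -1.0])
grad_f = lambda x: x - b
anchor = np.zeros(2)

with_prox = local_update(anchor, anchor, grad_f, mu=0.5, step_size=0.1, num_steps=5)
fedavg_like = local_update(anchor, anchor, grad_f, mu=0.0, step_size=0.1, num_steps=5)

# Even when local_x starts at reference_x, the proximal term is only zero on
# the very first step, so multi-step runs should not coincide with mu=0:
print(np.allclose(with_prox, fedavg_like))  # False
```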
Simpag
left a comment
Just some comments, no changes needed. LGTM
network.send(sender=client, receiver=network.server(), msg=client.x)

def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
    reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
You don't need to copy `reference_x` since you're not modifying it
for start in range(0, n_samples, per_client_batch):
    batch_indices = indices[start : start + per_client_batch]
    grad = cost.gradient(local_x, indices=batch_indices) + self.mu * (local_x - reference_x)
    local_x -= self.step_size * grad
If you want you could pass `indices="all"` here since the cost functions themselves should handle batching. But I'm not sure if that is what you want here, so it's just a comment
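For anyone reading along: the slicing in the quoted loop walks a shuffled index array in fixed-size chunks, with a possibly smaller final chunk. A toy numpy illustration (variable names mirror the diff but the values are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, per_client_batch = 10, 4
indices = rng.permutation(n_samples)  # shuffled sample indices for one local epoch

# Same pattern as the diff: slice the permutation into mini-batches;
# the last batch is smaller when per_client_batch does not divide n_samples.
batches = [indices[start : start + per_client_batch]
           for start in range(0, n_samples, per_client_batch)]
print([len(b) for b in batches])  # [4, 4, 2]
```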
This PR adds FedProx as a new federated algorithm.
Implemented as a proper FedAlgorithm subclass that preserves the existing federated round structure used by FedAvg. The only behavioral difference is the local client update: each selected client optimizes the FedProx subproblem with a proximal term anchored at the round’s server model.