
feat(distributed_algorithms): Add FedProx algorithm #284

Open

adrianardv wants to merge 2 commits into team-decent:main from adrianardv:feat/fedprox

Conversation

@adrianardv (Contributor)

This PR adds FedProx as a new federated algorithm.

Implemented as a proper FedAlgorithm subclass that preserves the existing federated round structure used by FedAvg. The only behavioral difference is the local client update: each selected client optimizes the FedProx subproblem with a proximal term anchored at the round’s server model.
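A minimal sketch of that local update (illustrative names only, not the PR's actual API; the local cost is assumed to expose a gradient callable):

```python
import numpy as np

def fedprox_local_update(grad_f, server_x, mu=0.01, step_size=0.001, num_steps=1):
    """One client's FedProx update (sketch): gradient descent on
    f(x) + (mu / 2) * ||x - server_x||^2, anchored at the server model."""
    x = np.array(server_x, dtype=float, copy=True)
    for _ in range(num_steps):
        grad = grad_f(x) + mu * (x - server_x)  # gradient of the proximal term added
        x -= step_size * grad
    return x
```

With mu=0 the proximal gradient vanishes and this reduces to the plain local gradient step used by FedAvg.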

- Add a FedProx federated algorithm that mirrors the existing FedAvg round flow while applying a proximal term during local client updates.
- Add lightweight instantiation coverage for FedProx.
- Document that mu=0 recovers FedAvg and that mu=0.01 is a sensible benchmark default. Encourage tuning mu and recommend the grid [0.001, 0.01, 0.1, 0.5, 1.0]. Add the same guidance to the user guide's federated algorithms section.
@nicola-bastianello (Member) left a comment

thanks! just a few minor comments, and one on the algorithm implementation.

:class:`FedAvg <decent_bench.distributed_algorithms.FedAvg>`, but each client solves a proximalized local
subproblem around the round's server model:

.. math::

it might be better to write the local update in terms of the gradient of h_k, since you have introduced the regularized cost; otherwise that notation is unused
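For concreteness, the suggested form might look like this (assuming :math:`h_k` denotes the regularized local cost with anchor :math:`x^t`, the round's server model):

```latex
% Assumed definition of the regularized local cost:
%   h_k(x) = f_k(x) + (mu / 2) * ||x - x^t||^2
% so its gradient is nabla f_k(x) + mu (x - x^t), and the local step becomes:
\[
  x \leftarrow x - \eta \, \nabla h_k(x)
    = x - \eta \left( \nabla f_k(x) + \mu \left( x - x^{t} \right) \right)
\]
```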

throughout each selected client's local epochs, :math:`\mu \geq 0` is the proximal coefficient,
:math:`\eta` is the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0``
recovers FedAvg exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are
strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``.

I would remove the sentence The default ``mu=0.01`` is a sensible benchmark starting point, but users are strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``. This is because: (1) that grid is potentially not the best starting point for some problems, since the best value of mu depends on the problem's features; and (2) we never mention tuning the hyperparameters, as it is implicit, so I don't think it's needed here.


I would also explicitly link FedAvg in the sentence Setting ``mu=0.0`` recovers FedAvg exactly.

:math:`\eta` is the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0``
recovers FedAvg exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are
strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``.
As in FedAvg, aggregation uses client weights, defaulting to data-size weights when ``client_weights`` is not

I would not reference FedAvg in the docstring (except to say that it's the same as FedProx when mu=0). I think the docstring should be self-contained.

iterations: int = 100
step_size: float = 0.001
num_local_epochs: int = 1
mu: float = 0.01 # Sensible benchmark default; tune over {0.001, 0.01, 0.1, 0.5, 1} for new settings.

I would remove the comment given the discussion above

:module: decent_bench.distributed_algorithms

FedProx extends FedAvg with a proximal coefficient ``mu``. Setting ``mu=0`` reduces
FedProx to FedAvg. In practice, ``mu`` should typically be tuned for each problem;

also here I would remove the discussion about tuning


def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
local_x = iop.copy(reference_x)

if I'm not mistaken, this line should not be here, because it effectively removes the regularization term (``local_x = reference_x``). the result is that we always run FedAvg (I tested, and that's indeed the case)

removing this line should be enough to fix it
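A toy numpy check of the effect (illustrative names, not the PR's code): when ``local_x`` starts at the anchor, the proximal gradient is exactly zero at the first step, so with a single local epoch the FedProx update coincides with the FedAvg one.

```python
import numpy as np

# Toy gradient for one client's local cost f(x) = ||x||^2.
grad = lambda x: 2.0 * x

reference_x = np.array([1.0, -2.0])  # server model received this round
mu, step_size = 0.01, 0.1

local_x = reference_x.copy()  # the line under discussion: reset to the anchor
prox_grad = grad(local_x) + mu * (local_x - reference_x)  # proximal part is exactly 0
fedprox_step = local_x - step_size * prox_grad
fedavg_step = reference_x - step_size * grad(reference_x)

assert np.allclose(fedprox_step, fedavg_step)  # first local step is identical
```

With more than one local step the two trajectories would in principle diverge, since the proximal term becomes nonzero once ``local_x`` moves away from the anchor.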


I've looked at the paper, and it's not very clear what they do to initialize local_x. I think maybe it should be left at its previous value? I tried the current implementation also with num_local_epochs > 1 and I get the same results for both algorithms. I don't quite understand

@Simpag (Contributor) left a comment

Just some comments, no changes needed. LGTM

network.send(sender=client, receiver=network.server(), msg=client.x)

def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)

You don't need to copy reference_x since you're not modifying it

Comment on lines +472 to +475
for start in range(0, n_samples, per_client_batch):
batch_indices = indices[start : start + per_client_batch]
grad = cost.gradient(local_x, indices=batch_indices) + self.mu * (local_x - reference_x)
local_x -= self.step_size * grad

If you want, you could pass indices="all" here, since the cost functions themselves should handle batching. But I'm not sure if that's what you want here, so it's just a comment
