112 changes: 112 additions & 0 deletions decent_bench/distributed_algorithms.py
@@ -364,6 +364,118 @@ def _epoch_minibatch_update(
return local_x


@tags("federated")
@dataclass(eq=False)
class FedProx(FedAlgorithm):
r"""
Federated Proximal (FedProx) with local SGD epochs.

FedProx follows the same communication and aggregation pattern as
:class:`FedAvg <decent_bench.distributed_algorithms.FedAvg>`, but each client solves a proximalized local
subproblem around the round's server model:

.. math::
    h_k(\mathbf{w}; \mathbf{w}^t) = F_k(\mathbf{w}) + \frac{\mu}{2} \|\mathbf{w} - \mathbf{w}^t\|^2

> **Member:** It might be better to write the local update in terms of :math:`h_k`'s gradient, since you have introduced the regularized cost; otherwise the notation is unused.

.. math::
\mathbf{x}_{i, k}^{(t+1)} = \mathbf{x}_{i, k}^{(t)} - \eta \left(
\nabla f_i(\mathbf{x}_{i, k}^{(t)}) + \mu (\mathbf{x}_{i, k}^{(t)} - \mathbf{w}^t) \right)

.. math::
\mathbf{x}_{k+1} = \frac{1}{|S_k|} \sum_{i \in S_k} \mathbf{x}_{i, k}^{(E)}

where :math:`\mathbf{w}^t` is the server model broadcast at the start of round :math:`k`, held fixed
throughout each selected client's local epochs, :math:`\mu \geq 0` is the proximal coefficient,
:math:`\eta` is the step size, and :math:`S_k` is the set of participating clients. Setting ``mu=0.0``
recovers FedAvg exactly. The default ``mu=0.01`` is a sensible benchmark starting point, but users are
strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``.
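For reference, a restatement consistent with the definitions above: since :math:`\nabla h_k(\mathbf{w}; \mathbf{w}^t) = \nabla F_k(\mathbf{w}) + \mu (\mathbf{w} - \mathbf{w}^t)`, the local step is one gradient step on the regularized cost,

.. math::
    \mathbf{x}_{i, k}^{(t+1)} = \mathbf{x}_{i, k}^{(t)} - \eta \, \nabla h_k(\mathbf{x}_{i, k}^{(t)}; \mathbf{w}^t)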
> **Member:** I would remove the sentence "The default ``mu=0.01`` is a sensible benchmark starting point, but users are strongly encouraged to tune ``mu`` for each problem; a practical grid is ``[0.001, 0.01, 0.1, 0.5, 1.0]``." This is because that grid is potentially not the best starting point for some problems (the best value of ``mu`` depends on the problem's features), and we never mention tuning the hyperparameters, as it is implicit, so I don't think it needs to be mentioned here.

> **Member:** I would also explicitly link FedAvg in the sentence "Setting ``mu=0.0`` recovers FedAvg exactly."

As in FedAvg, aggregation uses client weights, defaulting to data-size weights when ``client_weights`` is not
provided, and client selection defaults to uniform sampling with fraction 1.0. For
:class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size
:attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs,
local updates use full-batch gradients.

> **Member:** I would not reference FedAvg in the docstring (except to say that it's the same as FedProx when ``mu=0``). I think the docstring should be self-contained.
"""

iterations: int = 100
step_size: float = 0.001
num_local_epochs: int = 1
mu: float = 0.01 # Sensible benchmark default; tune over {0.001, 0.01, 0.1, 0.5, 1} for new settings.
> **Member:** I would remove the comment, given the discussion above.
client_weights: ClientWeights | None = None
selection_scheme: ClientSelectionScheme | None = field(
default_factory=lambda: UniformClientSelection(client_fraction=1.0)
)
x0: InitialStates = None
name: str = "FedProx"

def __post_init__(self) -> None:
"""
Validate hyperparameters.

Raises:
ValueError: if hyperparameters are invalid.

"""
if self.step_size <= 0:
raise ValueError("`step_size` must be positive")
if self.num_local_epochs <= 0:
raise ValueError("`num_local_epochs` must be positive")
if self.mu < 0:
raise ValueError("`mu` must be non-negative")

def initialize(self, network: FedNetwork) -> None: # noqa: D102
self.x0 = alg_helpers.initial_states(self.x0, network)
network.server().initialize(x=self.x0[network.server()])
for client in network.clients():
client.initialize(x=self.x0[client])

def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102
selected_clients = self._selected_clients_for_round(network, iteration)
if not selected_clients:
return

self._sync_server_to_clients(network, selected_clients)
self._run_local_updates(network, selected_clients)
self.aggregate(network, selected_clients)

def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x)

def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
for client in selected_clients:
client.x = self._compute_local_update(client, network.server())
network.send(sender=client, receiver=network.server(), msg=client.x)

def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
reference_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
> **Contributor:** You don't need to copy ``reference_x``, since you're not modifying it.
local_x = iop.copy(reference_x)
> **Member:** If I'm not mistaken, this line should not be here, because it effectively removes the regularization term (``local_x = reference_x``). The result is that we always run FedAvg (I tested, and that's indeed the case). Removing this line should be enough to fix it.

> **Member:** I've looked at the paper, and it's not very clear what they do to initialize ``local_x``. I think maybe it should be left at the previous value? I tried the current implementation also with ``num_local_epochs > 1`` and I get the same results for both algorithms. I don't quite understand.
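To isolate the observation above, here is a minimal NumPy sketch with a hypothetical quadratic cost (not the decent_bench API): when the local iterate starts exactly at the reference point, the proximal term is zero on the first gradient step, so that step matches plain SGD for any ``mu``. Subsequent steps do differ, so this alone does not explain identical multi-epoch results.

```python
import numpy as np

rng = np.random.default_rng(1)
A, b = rng.standard_normal((6, 2)), rng.standard_normal(6)

def grad(w):
    return A.T @ (A @ w - b)  # gradient of F(w) = 0.5 * ||A w - b||^2

reference_x = rng.standard_normal(2)  # broadcast server model, held fixed
local_x = reference_x.copy()          # local iterate initialized at the server model
eta, mu = 0.05, 10.0

# Step 1: mu * (local_x - reference_x) == 0, so mu has no effect yet.
prox_1 = local_x - eta * (grad(local_x) + mu * (local_x - reference_x))
sgd_1 = local_x - eta * grad(local_x)

# Step 2: prox_1 != reference_x, so the proximal term now pulls back toward it.
prox_2 = prox_1 - eta * (grad(prox_1) + mu * (prox_1 - reference_x))
sgd_2 = sgd_1 - eta * grad(sgd_1)
```

The first steps coincide exactly; the second steps differ by ``eta * mu * (prox_1 - reference_x)``.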
if isinstance(client.cost, EmpiricalRiskCost):
cost = client.cost
n_samples = cost.n_samples
return self._epoch_minibatch_update(cost, local_x, reference_x, cost.batch_size, n_samples)

for _ in range(self.num_local_epochs):
grad = client.cost.gradient(local_x) + self.mu * (local_x - reference_x)
local_x -= self.step_size * grad
return local_x

def _epoch_minibatch_update(
self,
cost: EmpiricalRiskCost,
local_x: "Array",
reference_x: "Array",
per_client_batch: int,
n_samples: int,
) -> "Array":
for _ in range(self.num_local_epochs):
indices = list(range(n_samples))
random.shuffle(indices)
for start in range(0, n_samples, per_client_batch):
batch_indices = indices[start : start + per_client_batch]
grad = cost.gradient(local_x, indices=batch_indices) + self.mu * (local_x - reference_x)
local_x -= self.step_size * grad
> **Contributor** (on lines +472 to +475): If you want, you could pass ``indices="all"`` here, since the cost functions themselves should handle batching. But I'm not sure if that is what you want here, so it's just a comment.
return local_x

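The round structure described in the docstring above can be sketched standalone in NumPy, using illustrative quadratic local costs ``F_i(w) = 0.5 * ||A_i w - b_i||**2`` and uniform aggregation (names here are hypothetical, not the decent_bench API). With ``mu=0.0`` the local loop reduces to plain FedAvg-style local SGD:

```python
import numpy as np

def fedprox_round(server_w, clients, step_size=0.01, mu=0.1, num_local_epochs=5):
    """One round: broadcast server_w, run local proximal gradient epochs, average."""
    updated = []
    for A, b in clients:
        local_w = server_w.copy()  # local training starts from the broadcast model
        for _ in range(num_local_epochs):
            g = A.T @ (A @ local_w - b)        # gradient of the local quadratic cost
            g += mu * (local_w - server_w)     # proximal term; server_w stays fixed
            local_w -= step_size * g
        updated.append(local_w)
    return np.mean(updated, axis=0)  # uniform aggregation over participating clients

rng = np.random.default_rng(0)
clients = [(rng.standard_normal((8, 3)), rng.standard_normal(8)) for _ in range(4)]
w = np.zeros(3)
for _ in range(20):
    w = fedprox_round(w, clients)
```

The average local loss decreases over the rounds, and a run with ``mu=0.0`` matches manually averaged plain local SGD.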

@tags("peer-to-peer", "gradient-based")
@dataclass(eq=False)
class DGD(P2PAlgorithm):
4 changes: 4 additions & 0 deletions docs/source/user.rst
@@ -108,6 +108,10 @@ Federated
:tag: federated
:module: decent_bench.distributed_algorithms

FedProx extends FedAvg with a proximal coefficient ``mu``. Setting ``mu=0`` reduces
FedProx to FedAvg. In practice, ``mu`` should typically be tuned for each problem;
``[0.001, 0.01, 0.1, 0.5, 1.0]`` is a reasonable starting grid.

> **Member:** Also here I would remove the discussion about tuning.


Available costs
---------------
2 changes: 2 additions & 0 deletions test/test_distributed_algorithms.py
@@ -10,6 +10,7 @@
ED,
EXTRA,
FedAvg,
FedProx,
NIDS,
AugDGM,
SimpleGT,
@@ -21,6 +22,7 @@
("algorithm_cls", "kwargs"),
[
(FedAvg, {"iterations": 10, "step_size": 0.1}),
(FedProx, {"iterations": 10, "step_size": 0.1}),
(DGD, {"iterations": 10, "step_size": 0.1}),
(ATC, {"iterations": 10, "step_size": 0.1}),
(SimpleGT, {"iterations": 10, "step_size": 0.1}),