-
Notifications
You must be signed in to change notification settings - Fork 5
fix: enforce backend consistency and expand tests #282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cdd302d
b281df2
8cbb181
d69f76d
68795eb
c443385
33f8c02
11330fc
5fc560f
e25a5c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
|
|
||
| import decent_bench.utils.algorithm_helpers as alg_helpers | ||
| import decent_bench.utils.interoperability as iop | ||
| from decent_bench.costs import EmpiricalRiskCost | ||
| from decent_bench.costs import Cost, EmpiricalRiskCost | ||
| from decent_bench.networks import FedNetwork, Network, P2PNetwork | ||
| from decent_bench.schemes import ClientSelectionScheme, UniformClientSelection | ||
| from decent_bench.utils._tags import tags | ||
|
|
@@ -179,6 +179,10 @@ class FedAlgorithm(Algorithm[FedNetwork]): | |
| def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]: | ||
| return [network.server(), *network.clients()] | ||
|
|
||
| def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: | ||
| """Send the current server model to the selected clients.""" | ||
| network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) | ||
|
|
||
| def select_clients( | ||
| self, | ||
| clients: Sequence["Agent"], | ||
|
|
@@ -281,10 +285,10 @@ class FedAvg(FedAlgorithm): | |
| round :math:`k`. In FedAvg, each selected client performs ``num_local_epochs`` local SGD epochs, then the server | ||
| aggregates the final local models to form :math:`\mathbf{x}_{k+1}`. The aggregation uses client weights, defaulting | ||
| to data-size weights when ``client_weights`` is not provided. Client selection (subsampling) defaults to uniform | ||
| sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. For | ||
| :class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size | ||
| :attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs, local | ||
| updates use full-batch gradients. | ||
| sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. Costs that | ||
| preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use client-side mini-batches of size | ||
| :attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; generic cost wrappers | ||
| fall back to full-gradient local updates. | ||
| """ | ||
|
|
||
| # C=0.1; batch size= inf/10/50 (dataset sizes are bigger; normally 1/10 of the total dataset). | ||
|
|
@@ -327,38 +331,48 @@ def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102 | |
| self._run_local_updates(network, selected_clients) | ||
| self.aggregate(network, selected_clients) | ||
|
|
||
| def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: | ||
| network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) | ||
|
|
||
| def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: | ||
| for client in selected_clients: | ||
| client.x = self._compute_local_update(client, network.server()) | ||
| network.send(sender=client, receiver=network.server(), msg=client.x) | ||
|
|
||
| @staticmethod | ||
| def _empirical_cost_for_local_updates(cost: Cost) -> EmpiricalRiskCost | None: | ||
| """ | ||
| Return the empirical cost view to use for FedAvg local mini-batching. | ||
|
|
||
| The repository's batching semantics are carried by the ``EmpiricalRiskCost`` abstraction itself. Wrappers that | ||
| preserve empirical-risk metadata, such as empirical regularization/scaling and ``PyTorchCost``, inherit from | ||
| ``EmpiricalRiskCost`` and are therefore batch-capable. Generic wrappers like ``SumCost`` and ``ScaledCost`` | ||
| intentionally erase that abstraction and should use the full-gradient path. | ||
| """ | ||
| if isinstance(cost, EmpiricalRiskCost): | ||
| return cost | ||
| return None | ||
|
|
||
| def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array": | ||
| local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x) | ||
| if isinstance(client.cost, EmpiricalRiskCost): | ||
| cost = client.cost | ||
| n_samples = cost.n_samples | ||
| return self._epoch_minibatch_update(cost, local_x, cost.batch_size, n_samples) | ||
| empirical_cost = self._empirical_cost_for_local_updates(client.cost) | ||
| if empirical_cost is not None: | ||
| return self._run_minibatch_local_epochs(empirical_cost, local_x) | ||
| return self._run_full_gradient_local_epochs(client.cost, local_x) | ||
|
|
||
| def _run_full_gradient_local_epochs(self, cost: Cost, local_x: "Array") -> "Array": | ||
| for _ in range(self.num_local_epochs): | ||
| grad = client.cost.gradient(local_x) | ||
| grad = cost.gradient(local_x) | ||
| local_x -= self.step_size * grad | ||
| return local_x | ||
|
|
||
| def _epoch_minibatch_update( | ||
| def _run_minibatch_local_epochs( | ||
| self, | ||
| cost: EmpiricalRiskCost, | ||
| local_x: "Array", | ||
| per_client_batch: int, | ||
| n_samples: int, | ||
| ) -> "Array": | ||
| for _ in range(self.num_local_epochs): | ||
| indices = list(range(n_samples)) | ||
| indices = list(range(cost.n_samples)) | ||
| random.shuffle(indices) | ||
| for start in range(0, n_samples, per_client_batch): | ||
| batch_indices = indices[start : start + per_client_batch] | ||
| for start in range(0, cost.n_samples, cost.batch_size): | ||
| batch_indices = indices[start : start + cost.batch_size] | ||
| grad = cost.gradient(local_x, indices=batch_indices) | ||
| local_x -= self.step_size * grad | ||
| return local_x | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While this ensures that a full epoch through the dataset is performed, I've updated empirical risk cost sampling so that it will iterate through the entire dataset (in random order) before it re-uses datapoints. Therefore you could simplify this a bit by removing the indices parameter, but it's not a big deal. One side effect: if you're using PyTorchCost with a dataloader, the indices parameter bypasses the dataloader and gathers data manually. From my experience dataloaders are slower when running on the CPU, so it's not commonly used, but it might slow things down depending on the model and dataset size. |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -178,7 +178,7 @@ def __init__(self, n_significant_digits: int): | |
| self.n_significant_digits = n_significant_digits | ||
|
|
||
| def compress(self, msg: Array) -> Array: # noqa: D102 | ||
| res = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(iop.to_numpy(msg)) # noqa: RUF073 | ||
| res = np.vectorize(lambda x: float(format(x, f".{self.n_significant_digits - 1}e")))(iop.to_numpy(msg)) | ||
| return iop.to_array_like(res, msg) | ||
|
|
||
|
Comment on lines
180
to
183
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like there has to be a more efficient way of performing quantization than doing to_numpy -> float -> string -> float -> back to framework. If you feel like you have the time please check if there are any better ways of doing this, otherwise I'll just create an issue of this at some point no problem.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Update: You don't have to worry about this. This is insanely inefficient; I have made an update to this and will include it in my bigger update within 1-2 weeks. Some simple math made this at least 10x more efficient. |
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to make this a method in `FedAlgorithm`, maybe it should be public so that users implementing new algorithms can benefit from this util. If we want to keep it private, then it should probably be in the subclass.
If made public, it could be renamed to something like `server_broadcast` (for a shorter name).