fix: enforce backend consistency and expand tests #282
adrianardv wants to merge 10 commits into team-decent:main from
Conversation
Route FedAvg local updates through the preserved cost abstraction so empirical costs and empirical wrappers keep mini-batch behavior while generic costs, regularizers, and zero costs continue to use full gradients. Add FedAvg-specific batching tests into a dedicated module.
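A minimal sketch of the routing described above, using hypothetical toy cost classes (the real `Cost` API may differ): costs that expose a `batch_size` keep mini-batch behavior over one shuffled epoch, while everything else takes a single full-gradient step.

```python
import numpy as np

class ToyEmpiricalCost:
    """Toy empirical cost, mean of 0.5 * (x - a_i)^2, to illustrate batching."""
    def __init__(self, data, batch_size):
        self.data = np.asarray(data, dtype=float)
        self.batch_size = batch_size
        self.num_samples = len(self.data)
    def gradient(self, x, indices=None):
        a = self.data if indices is None else self.data[indices]
        return np.mean(x - a)

class ToyGenericCost:
    """Toy generic cost 0.5 * x^2 with no batching support."""
    def gradient(self, x):
        return x

def local_update(cost, x, step_size, rng):
    # Empirical costs (those exposing batch_size) keep mini-batch behavior;
    # generic costs, regularizers, and zero costs use the full gradient.
    if getattr(cost, "batch_size", None):
        indices = rng.permutation(cost.num_samples)  # one full pass, random order
        for start in range(0, cost.num_samples, cost.batch_size):
            batch = indices[start : start + cost.batch_size]
            x = x - step_size * cost.gradient(x, indices=batch)
    else:
        x = x - step_size * cost.gradient(x)
    return x

rng = np.random.default_rng(0)
x_emp = local_update(ToyEmpiricalCost([1.0, 2.0, 3.0, 4.0], batch_size=2), 0.0, 0.1, rng)
x_gen = local_update(ToyGenericCost(), 1.0, 0.1, rng)
```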
Document how to combine PyTorchCost with built-in regularizers using matching framework and device settings. Add tests that verify empirical behavior is preserved for compatible PyTorch objectives and that mismatched NumPy regularizers raise an error.
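A minimal sketch of the composition invariant, with hypothetical toy classes standing in for `PyTorchCost`, a built-in regularizer, and a `SumCost`-like wrapper: gradients add, and construction fails when framework or device settings do not match.

```python
class ToyQuadratic:
    framework, device = "numpy", "cpu"
    def gradient(self, x):
        return x  # gradient of 0.5 * x^2

class ToyL2Regularizer:
    framework, device = "numpy", "cpu"
    def __init__(self, weight):
        self.weight = weight
    def gradient(self, x):
        return self.weight * x  # gradient of 0.5 * weight * x^2

class ToyTorchCost:
    framework, device = "pytorch", "cuda"
    def gradient(self, x):
        return x

class ToySum:
    """Sum of costs; rejects mixed framework or device at construction."""
    def __init__(self, *costs):
        frameworks = {c.framework for c in costs}
        devices = {c.device for c in costs}
        if len(frameworks) > 1 or len(devices) > 1:
            raise ValueError(
                f"mixed backends in composition: {sorted(frameworks)}, {sorted(devices)}"
            )
        self.costs = costs
    def gradient(self, x):
        return sum(c.gradient(x) for c in self.costs)

composed = ToySum(ToyQuadratic(), ToyL2Regularizer(0.1))
```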
Move server-to-client synchronization into FedAlgorithm and rename FedAvg local update helpers to reflect their batching and full-gradient roles more clearly. Add focused federated routing and aggregation tests to cover the preserved FedAvg behavior with composed costs.
Add focused tests for concrete cost implementations and network communication behavior. Cover invalid graphs, inactive receivers, message buffer lifecycle, and scheme integration, and add direct checks for linear, logistic, quadratic, zero, and PyTorch costs.
…sitions Validate each agent's cost shape, framework, and device when constructing networks so mixed-backend configurations fail early. Validate cost composition on framework and device by default, and enforce the same invariant in SumCost construction so mixed-backend composite costs are rejected consistently. Add regression tests covering mixed framework/device rejection for networks and generic cost composition.
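The network-construction check above can be sketched as follows; `validate_network_costs` and `ToyCost` are hypothetical names, not the library's actual API.

```python
from dataclasses import dataclass

@dataclass
class ToyCost:
    """Stand-in for an agent's cost, exposing the attributes being validated."""
    shape: tuple
    framework: str
    device: str

def validate_network_costs(costs):
    # Fail early when any agent's cost disagrees with the first agent's
    # on shape, framework, or device (mixed-backend configuration).
    ref = costs[0]
    for i, c in enumerate(costs[1:], start=1):
        for attr in ("shape", "framework", "device"):
            if getattr(c, attr) != getattr(ref, attr):
                raise ValueError(
                    f"agent {i}: {attr} {getattr(c, attr)!r} != {getattr(ref, attr)!r}"
                )

ok = [ToyCost((3,), "numpy", "cpu"), ToyCost((3,), "numpy", "cpu")]
bad = [ToyCost((3,), "numpy", "cpu"), ToyCost((3,), "pytorch", "cuda")]
```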
looks good, thank you! I'll review the details soon. One idea I had: currently the checks are made using Cost's method. I would say you can set this PR to close #275, and we address the …
nicola-bastianello
left a comment
looks good, thanks! just one comment
```python
def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]:
    return [network.server(), *network.clients()]


def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
```
If we want to make this a method in FedAlgorithm, maybe it should be public, so that users implementing new algorithms can benefit from this util. If we want to keep it private, then it should probably be in the subclass.
If made public, it could be renamed to something like server_broadcast (for a shorter name).
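The suggested public helper might look roughly like this; `ToyAgent` and the `state` attribute are placeholders for illustration, not the library's actual Agent API.

```python
class ToyAgent:
    """Minimal stand-in for an Agent holding a model iterate."""
    def __init__(self, state=None):
        self.state = state

def server_broadcast(server, selected_clients):
    # Push the server's current iterate to each selected client
    # before their local updates begin.
    for client in selected_clients:
        client.state = server.state

server = ToyAgent(state=[1.0, 2.0])
clients = [ToyAgent(), ToyAgent()]
server_broadcast(server, clients)
```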
Simpag
left a comment
Two minor comments; you don't have to do anything if you don't want to / don't have the time, otherwise LGTM.
```python
batch_indices = indices[start : start + cost.batch_size]
grad = cost.gradient(local_x, indices=batch_indices)
local_x -= self.step_size * grad
return local_x
```
While this assures that a full epoch through the dataset is performed, I've updated empirical risk cost sampling so that it iterates through the entire dataset (in random order) before it reuses data points. Therefore you could simplify this a bit by removing the indices parameter, but it's not a big deal.
One side effect: if you're using PyTorchCost with a dataloader, the indices parameter bypasses the dataloader and gathers data manually. From my experience dataloaders are slower when running on the CPU, so it's not commonly used, but it might slow things down depending on the model and dataset size.
```diff
 def compress(self, msg: Array) -> Array:  # noqa: D102
-    res = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(iop.to_numpy(msg))  # noqa: RUF073
+    res = np.vectorize(lambda x: float(format(x, f".{self.n_significant_digits - 1}e")))(iop.to_numpy(msg))
     return iop.to_array_like(res, msg)
```
I feel like there has to be a more efficient way of performing quantization than doing to_numpy -> float -> string -> float -> back to framework. If you feel like you have the time, please check if there are any better ways of doing this; otherwise I'll just create an issue for this at some point, no problem.
Update: you don't have to worry about this. This is insanely inefficient; I have made an update to it and will include it in my bigger update within 1-2 weeks. Some simple math made this at least 10x more efficient.
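One vectorized alternative to the string round-trip (the "simple math" is not shown in this thread, so this is only a plausible sketch): round each entry to `n_digits` significant figures using a log10-based scale instead of formatting through strings.

```python
import numpy as np

def round_significant(a, n_digits):
    """Round each entry of a to n_digits significant figures, vectorized."""
    a = np.asarray(a, dtype=float)
    out = np.zeros_like(a)
    nonzero = a != 0  # log10 is undefined at 0; zeros stay zero
    exp = np.floor(np.log10(np.abs(a[nonzero])))  # decimal exponent per entry
    scale = 10.0 ** (n_digits - 1 - exp)
    out[nonzero] = np.round(a[nonzero] * scale) / scale
    return out

out = round_significant([123.456, 0.00123456, -0.5, 0.0], 3)
```

This avoids `np.vectorize` (which is essentially a Python-level loop) as well as the float-to-string-to-float conversion, at the cost of one `log10` pass over the array.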
Summary
This PR adds a set of small enhancements and test coverage improvements across costs, networks, and federated behavior.
The main functional change is to enforce backend consistency more strictly.
This addresses the main issue in #275. The only remaining point discussed there is globally setting framework/device through IOP.
What changed
- SumCost
- PyTorchCost + built-in regularizer

Closes #275