diff --git a/decent_bench/benchmark/_utils.py b/decent_bench/benchmark/_utils.py index d55e802..7f3dd20 100644 --- a/decent_bench/benchmark/_utils.py +++ b/decent_bench/benchmark/_utils.py @@ -75,8 +75,8 @@ def model_gen() -> torch.nn.Module: ) # Mypy cannot infer that cost_cls is PyTorchCost here - costs = [ - cost_cls( # type: ignore[call-arg] + pytorch_costs: list[PyTorchCost] = [ + PyTorchCost( dataset=p, model=model_gen(), loss_fn=torch.nn.CrossEntropyLoss(), @@ -86,11 +86,15 @@ def model_gen() -> torch.nn.Module: ) for p in dataset.get_partitions() ] + costs: Sequence[Cost] = pytorch_costs x_optimal = None elif cost_cls is LogisticRegressionCost: - costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg] - sum_cost = reduce(add, costs) + classification_costs: list[LogisticRegressionCost] = [ + LogisticRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions() + ] + sum_cost = reduce(add, classification_costs) x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16) + costs = classification_costs else: raise ValueError(f"Unsupported cost class: {cost_cls}") @@ -158,15 +162,19 @@ def model_gen() -> torch.nn.Module: output_size=1, ) - costs = [ - cost_cls(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device) # type: ignore[call-arg] + pytorch_costs: list[PyTorchCost] = [ + PyTorchCost(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device) for p in dataset.get_partitions() ] + costs: Sequence[Cost] = pytorch_costs x_optimal = None elif cost_cls is LinearRegressionCost: - costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg] - sum_cost = reduce(add, costs) + regression_costs: list[LinearRegressionCost] = [ + LinearRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions() + ] + sum_cost = reduce(add, regression_costs) x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16) + costs = regression_costs else: raise ValueError(f"Unsupported cost class: {cost_cls}") diff --git a/decent_bench/costs/_base/_cost.py b/decent_bench/costs/_base/_cost.py index 2366cf2..d44a488 100644 --- a/decent_bench/costs/_base/_cost.py +++ b/decent_bench/costs/_base/_cost.py @@ -15,8 +15,8 @@ def _validate_cost_operation( self, other: object, *, - check_framework: bool = False, - check_device: bool = False, + check_framework: bool = True, + check_device: bool = True, ) -> None: """ Validate that another object can participate in a binary cost operation. diff --git a/decent_bench/costs/_base/_sum_cost.py b/decent_bench/costs/_base/_sum_cost.py index a4f2ae1..acb801c 100644 --- a/decent_bench/costs/_base/_sum_cost.py +++ b/decent_bench/costs/_base/_sum_cost.py @@ -25,8 +25,9 @@ class SumCost(Cost): """ def __init__(self, costs: list[Cost]): - if not all(costs[0].shape == cf.shape for cf in costs): - raise ValueError("All cost functions must have the same domain shape") + if len(costs) == 0: + raise ValueError("SumCost must contain at least one cost function.") + self.costs: list[Cost] = [] for cf in costs: if isinstance(cf, SumCost): @@ -34,6 +35,10 @@ def __init__(self, costs: list[Cost]): else: self.costs.append(cf) + first = self.costs[0] + for cf in self.costs[1:]: + first._validate_cost_operation(cf) # noqa: SLF001 + @property def shape(self) -> tuple[int, ...]: return self.costs[0].shape diff --git a/decent_bench/datasets/_pytorch_handler.py b/decent_bench/datasets/_pytorch_handler.py index fd1e1d4..e75c00e 100644 --- a/decent_bench/datasets/_pytorch_handler.py +++ b/decent_bench/datasets/_pytorch_handler.py @@ -2,6 +2,7 @@ import random from collections import defaultdict +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, cast import decent_bench.utils.interoperability as iop @@ -161,9 +162,10 @@ def _heterogeneous_split(self) -> list[Dataset]: """ # Group indices by class in a single pass class_to_indices: dict[int, list[int]] = defaultdict(list) - for idx, (_, label) in enumerate(self.torch_dataset): # type: ignore[misc, arg-type] - if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition): # type: ignore[has-type] - class_to_indices[label].append(idx) # type: ignore[has-type] + for idx, sample in enumerate(cast("Iterable[Any]", self.torch_dataset)): + _, label = cast("tuple[Any, int]", sample) + if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition): + class_to_indices[label].append(idx) # Create partitions from class-grouped indices idx_partitions = [] diff --git a/decent_bench/distributed_algorithms.py b/decent_bench/distributed_algorithms.py index 6aa070c..3c75268 100644 --- a/decent_bench/distributed_algorithms.py +++ b/decent_bench/distributed_algorithms.py @@ -6,7 +6,7 @@ import decent_bench.utils.algorithm_helpers as alg_helpers import decent_bench.utils.interoperability as iop -from decent_bench.costs import EmpiricalRiskCost +from decent_bench.costs import Cost, EmpiricalRiskCost from decent_bench.networks import FedNetwork, Network, P2PNetwork from decent_bench.schemes import ClientSelectionScheme, UniformClientSelection from decent_bench.utils._tags import tags @@ -179,6 +179,10 @@ class FedAlgorithm(Algorithm[FedNetwork]): def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]: return [network.server(), *network.clients()] + def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: + """Send the current server model to the selected clients.""" + network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) + def select_clients( self, clients: Sequence["Agent"], @@ -281,10 +285,10 @@ class FedAvg(FedAlgorithm): round :math:`k`. In FedAvg, each selected client performs ``num_local_epochs`` local SGD epochs, then the server aggregates the final local models to form :math:`\mathbf{x}_{k+1}`. The aggregation uses client weights, defaulting to data-size weights when ``client_weights`` is not provided. Client selection (subsampling) defaults to uniform - sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. For - :class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size - :attr:`EmpiricalRiskCost.batch_size `; for generic costs, local - updates use full-batch gradients. + sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. Costs that + preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use client-side mini-batches of size + :attr:`EmpiricalRiskCost.batch_size `; generic cost wrappers + fall back to full-gradient local updates. """ # C=0.1; batch size= inf/10/50 (dataset sizes are bigger; normally 1/10 of the total dataset). @@ -327,38 +331,48 @@ def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102 self._run_local_updates(network, selected_clients) self.aggregate(network, selected_clients) - def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: - network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) - def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: for client in selected_clients: client.x = self._compute_local_update(client, network.server()) network.send(sender=client, receiver=network.server(), msg=client.x) + @staticmethod + def _empirical_cost_for_local_updates(cost: Cost) -> EmpiricalRiskCost | None: + """ + Return the empirical cost view to use for FedAvg local mini-batching. + + The repository's batching semantics are carried by the ``EmpiricalRiskCost`` abstraction itself. Wrappers that + preserve empirical-risk metadata, such as empirical regularization/scaling and ``PyTorchCost``, inherit from + ``EmpiricalRiskCost`` and are therefore batch-capable. Generic wrappers like ``SumCost`` and ``ScaledCost`` + intentionally erase that abstraction and should use the full-gradient path. + """ + if isinstance(cost, EmpiricalRiskCost): + return cost + return None + def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array": local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x) - if isinstance(client.cost, EmpiricalRiskCost): - cost = client.cost - n_samples = cost.n_samples - return self._epoch_minibatch_update(cost, local_x, cost.batch_size, n_samples) + empirical_cost = self._empirical_cost_for_local_updates(client.cost) + if empirical_cost is not None: + return self._run_minibatch_local_epochs(empirical_cost, local_x) + return self._run_full_gradient_local_epochs(client.cost, local_x) + def _run_full_gradient_local_epochs(self, cost: Cost, local_x: "Array") -> "Array": for _ in range(self.num_local_epochs): - grad = client.cost.gradient(local_x) + grad = cost.gradient(local_x) local_x -= self.step_size * grad return local_x - def _epoch_minibatch_update( + def _run_minibatch_local_epochs( self, cost: EmpiricalRiskCost, local_x: "Array", - per_client_batch: int, - n_samples: int, ) -> "Array": for _ in range(self.num_local_epochs): - indices = list(range(n_samples)) + indices = list(range(cost.n_samples)) random.shuffle(indices) - for start in range(0, n_samples, per_client_batch): - batch_indices = indices[start : start + per_client_batch] + for start in range(0, cost.n_samples, cost.batch_size): + batch_indices = indices[start : start + cost.batch_size] grad = cost.gradient(local_x, indices=batch_indices) local_x -= self.step_size * grad return local_x diff --git a/decent_bench/networks.py b/decent_bench/networks.py index 3c9d0cc..8a18fb9 100644 --- a/decent_bench/networks.py +++ b/decent_bench/networks.py @@ -72,6 +72,7 @@ def __init__( agent_ids = [agent.id for agent in graph.nodes()] if len(agent_ids) != len(set(agent_ids)): raise ValueError("Agent IDs must be unique") + self._validate_agent_cost_compatibility(graph) self._graph = graph self._message_noise = self._initialize_message_schemes(message_noise, "noise", NoiseScheme, NoNoise) @@ -84,6 +85,37 @@ def __init__( self._buffer_messages = buffer_messages self._iteration = 0 # Current iteration, updated by the algorithm + @staticmethod + def _validate_agent_cost_compatibility(graph: AgentGraph) -> None: + """ + Validate that all agents' costs share the same shape, framework, and device. + + Raises: + ValueError: If agents in the graph have mismatching cost shape, framework, or device. + + """ + agents = list(graph.nodes()) + if len(agents) <= 1: + return + + first_cost = agents[0].cost + first_signature = (first_cost.shape, first_cost.framework, first_cost.device) + mismatches: list[str] = [] + for agent in agents[1:]: + signature = (agent.cost.shape, agent.cost.framework, agent.cost.device) + if signature != first_signature: + mismatches.append( + f"agent {agent.id}: shape={agent.cost.shape}, framework={agent.cost.framework}, " + f"device={agent.cost.device}" + ) + + if mismatches: + raise ValueError( + "All agents in a network must have costs with the same shape, framework, and device. " + f"Expected shape={first_cost.shape}, framework={first_cost.framework}, " + f"device={first_cost.device}; mismatches: {'; '.join(mismatches)}" + ) + def _initialize_message_schemes( self, scheme: object, diff --git a/decent_bench/schemes.py b/decent_bench/schemes.py index f587fe3..6e5d01d 100644 --- a/decent_bench/schemes.py +++ b/decent_bench/schemes.py @@ -178,7 +178,7 @@ def __init__(self, n_significant_digits: int): self.n_significant_digits = n_significant_digits def compress(self, msg: Array) -> Array: # noqa: D102 - res = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(iop.to_numpy(msg)) # noqa: RUF073 + res = np.vectorize(lambda x: float(format(x, f".{self.n_significant_digits - 1}e")))(iop.to_numpy(msg)) return iop.to_array_like(res, msg) diff --git a/docs/source/user.rst b/docs/source/user.rst index 85182de..bc44726 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -125,6 +125,30 @@ Classification :module: decent_bench.costs +PyTorchCost regularization +~~~~~~~~~~~~~~~~~~~~~~~~~~ +When combining :class:`~decent_bench.costs.PyTorchCost` with one of the +built-in regularizers, instantiate the regularizer with the same framework +and device as the empirical cost: + +.. code-block:: python + + from decent_bench.costs import L2RegularizerCost + from decent_bench.utils.types import SupportedFrameworks + + reg = L2RegularizerCost( + shape=cost.shape, + framework=SupportedFrameworks.PYTORCH, + device=cost.device, + ) + objective = cost + reg + +This preserves compatibility with the PyTorch empirical objective and keeps +the resulting objective in the empirical, batch-compatible abstraction. +It is convenient for composition, but it is not necessarily the most +efficient option compared with native framework-specific regularization. + + Execution settings ------------------ Configure settings for metrics, trials, statistical confidence level, logging, and multiprocessing. diff --git a/test/test_cost_operators.py b/test/test_cost_operators.py index 8b217d7..5fa494d 100644 --- a/test/test_cost_operators.py +++ b/test/test_cost_operators.py @@ -3,6 +3,7 @@ import decent_bench.utils.interoperability as iop from decent_bench.costs import Cost, L1RegularizerCost, L2RegularizerCost, QuadraticCost, SumCost +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> QuadraticCost: @@ -12,20 +13,28 @@ def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> Quadrat class _SimpleCost(Cost): - def __init__(self, scale: float): + def __init__( + self, + scale: float, + *, + framework: SupportedFrameworks = SupportedFrameworks.NUMPY, + device: SupportedDevices = SupportedDevices.CPU, + ): self.scale = scale + self._framework = framework + self._device = device @property def shape(self) -> tuple[int, ...]: return (2,) @property - def framework(self) -> str: - return "numpy" + def framework(self) -> SupportedFrameworks: + return self._framework @property - def device(self) -> str | None: - return "cpu" + def device(self) -> SupportedDevices: + return self._device @property def m_smooth(self) -> float: @@ -182,3 +191,35 @@ def test_cost_scalar_ops_reject_invalid_inputs() -> None: _ = cost / 0.0 with pytest.raises(TypeError): _ = 0.0 / cost + + +def test_cost_addition_rejects_mismatched_frameworks() -> None: + cost_a = _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY) + cost_b = _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH) + + with pytest.raises(ValueError, match="Mismatching frameworks"): + _ = cost_a + cost_b + + +def test_cost_addition_rejects_mismatched_devices() -> None: + cost_a = _SimpleCost(scale=1.0, device=SupportedDevices.CPU) + cost_b = _SimpleCost(scale=2.0, device=SupportedDevices.GPU) + + with pytest.raises(ValueError, match="Mismatching devices"): + _ = cost_a + cost_b + + +def test_sum_cost_rejects_mismatched_frameworks() -> None: + cost_a = _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY) + cost_b = _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH) + + with pytest.raises(ValueError, match="Mismatching frameworks"): + SumCost([cost_a, cost_b]) + + +def test_sum_cost_rejects_mismatched_devices() -> None: + cost_a = _SimpleCost(scale=1.0, device=SupportedDevices.CPU) + cost_b = _SimpleCost(scale=2.0, device=SupportedDevices.GPU) + + with pytest.raises(ValueError, match="Mismatching devices"): + SumCost([cost_a, cost_b]) diff --git a/test/test_cost_type_preservation.py b/test/test_cost_type_preservation.py index 1d10a19..b9c63e3 100644 --- a/test/test_cost_type_preservation.py +++ b/test/test_cost_type_preservation.py @@ -12,9 +12,18 @@ L2RegularizerCost, LinearRegressionCost, LogisticRegressionCost, + PyTorchCost, QuadraticCost, SumCost, ) +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks + +try: + import torch + + TORCH_AVAILABLE = True +except ModuleNotFoundError: + TORCH_AVAILABLE = False def _simple_regularizers() -> tuple[L1RegularizerCost, L2RegularizerCost]: @@ -49,6 +58,26 @@ def _simple_logistic_regression_cost() -> LogisticRegressionCost: return LogisticRegressionCost(dataset=dataset, batch_size="all") +def _simple_pytorch_cost(batch_size: int = 2) -> PyTorchCost: + if not TORCH_AVAILABLE: + raise RuntimeError("PyTorch is not available.") + + dataset = [ + (torch.tensor([1.0, 0.0], dtype=torch.float32), torch.tensor([1.0], dtype=torch.float32)), + (torch.tensor([0.0, 1.0], dtype=torch.float32), torch.tensor([-1.0], dtype=torch.float32)), + (torch.tensor([1.0, 1.0], dtype=torch.float32), torch.tensor([0.5], dtype=torch.float32)), + ] + model = torch.nn.Linear(2, 1, bias=False) + loss_fn = torch.nn.MSELoss() + return PyTorchCost( + dataset=dataset, + model=model, + loss_fn=loss_fn, + batch_size=batch_size, + device=SupportedDevices.CPU, + ) + + def _assert_cost_matches_expression( actual_function: float, expected_function: float, @@ -685,3 +714,43 @@ def test_empirical_regularized_cost_proximal_is_explicitly_unsupported() -> None with pytest.raises(NotImplementedError, match="EmpiricalRegularizedCost does not implement a generic proximal"): (risk + reg_l2).proximal(x, rho=0.5) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") +def test_pytorch_cost_plus_builtin_l2_regularizer_preserves_empirical_behavior() -> None: + cost = _simple_pytorch_cost(batch_size=2) + reg = L2RegularizerCost( + shape=cost.shape, + framework=SupportedFrameworks.PYTORCH, + device=cost.device, + ) + objective = cost + reg + x = torch.tensor([0.25, -0.75], dtype=torch.float32) + + assert isinstance(objective, EmpiricalRegularizedCost) + assert isinstance(objective, EmpiricalRiskCost) + assert objective.framework == SupportedFrameworks.PYTORCH + assert objective.device == cost.device + assert objective.batch_size == cost.batch_size + assert objective.dataset is cost.dataset + + objective_gradient = objective.gradient(x, indices="all") + expected_gradient = cost.gradient(x, indices="all") + reg.gradient(x) + + assert isinstance(objective_gradient, torch.Tensor) + assert isinstance(objective.gradient(x, indices="batch"), torch.Tensor) + torch.testing.assert_close(objective_gradient, expected_gradient) + + batched_per_sample_gradient = objective.gradient(x, indices="batch", reduction=None) + assert isinstance(batched_per_sample_gradient, torch.Tensor) + assert batched_per_sample_gradient.shape[0] == objective.batch_size + assert len(objective.batch_used) == objective.batch_size + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") +def test_pytorch_cost_plus_numpy_regularizer_raises_framework_mismatch() -> None: + cost = _simple_pytorch_cost(batch_size=2) + reg = L2RegularizerCost(shape=cost.shape) + + with pytest.raises(ValueError, match="Mismatching frameworks"): + _ = cost + reg diff --git a/test/test_costs.py b/test/test_costs.py new file mode 100644 index 0000000..dfa0b4a --- /dev/null +++ b/test/test_costs.py @@ -0,0 +1,222 @@ +import math + +import numpy as np +import pytest + +from decent_bench.costs import ( + LinearRegressionCost, + LogisticRegressionCost, + PyTorchCost, + QuadraticCost, + ZeroCost, +) +from decent_bench.utils import interoperability as iop +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks + +try: + import torch + + TORCH_AVAILABLE = True +except ModuleNotFoundError: + TORCH_AVAILABLE = False + + +def test_linear_regression_cost_matches_closed_form_values() -> None: + dataset = [ + (np.array([1.0, 0.0]), np.array([1.0])), + (np.array([0.0, 1.0]), np.array([2.0])), + ] + cost = LinearRegressionCost(dataset=dataset, batch_size="all") + x = np.array([3.0, 4.0]) + + assert cost.n_samples == 2 + assert cost.batch_size == 2 + assert cost.function(x, indices="all") == pytest.approx(2.0) + np.testing.assert_allclose(cost.gradient(x, indices="all"), np.array([1.0, 1.0])) + np.testing.assert_allclose(cost.gradient(x, indices="all", reduction=None), np.array([[2.0, 0.0], [0.0, 2.0]])) + np.testing.assert_allclose(cost.hessian(x, indices="all"), 0.5 * np.eye(2)) + np.testing.assert_allclose(cost.proximal(x, rho=2.0), np.array([2.0, 3.0])) + assert cost.batch_used == [0, 1] + + +def test_linear_regression_cost_validates_constructor_and_indices() -> None: + with pytest.raises(ValueError, match="Dataset features must be vectors"): + LinearRegressionCost(dataset=[(np.array([[1.0, 2.0]]), np.array([1.0]))]) + + with pytest.raises(TypeError, match="Dataset targets must be single dimensional values"): + LinearRegressionCost(dataset=[(np.array([1.0, 2.0]), np.array([[1.0]]))]) + + with pytest.raises(ValueError, match="Batch size must be positive"): + LinearRegressionCost(dataset=[(np.array([1.0]), np.array([1.0]))], batch_size=0) + + cost = LinearRegressionCost(dataset=[(np.array([1.0]), np.array([1.0]))], batch_size="all") + with pytest.raises(ValueError, match="Invalid indices string"): + cost.function(np.array([0.0]), indices="invalid") + + +def test_logistic_regression_cost_matches_closed_form_values() -> None: + dataset = [ + (np.array([1.0, 0.0]), np.array([0.0])), + (np.array([0.0, 1.0]), np.array([1.0])), + ] + cost = LogisticRegressionCost(dataset=dataset, batch_size="all") + x = np.zeros(2) + + assert cost.function(x, indices="all") == pytest.approx(math.log(2.0)) + np.testing.assert_allclose(cost.gradient(x, indices="all"), np.array([0.25, -0.25])) + np.testing.assert_allclose(cost.hessian(x, indices="all"), 0.125 * np.eye(2)) + + +def test_logistic_regression_proximal_preserves_batch_size_and_returns_finite_result() -> None: + dataset = [ + (np.array([1.0, 0.0]), np.array([0.0])), + (np.array([0.0, 1.0]), np.array([1.0])), + (np.array([1.0, 1.0]), np.array([1.0])), + ] + cost = LogisticRegressionCost(dataset=dataset, batch_size=2) + x = np.array([0.5, -0.25]) + + prox = cost.proximal(x, rho=0.5) + + assert prox.shape == x.shape + assert np.all(np.isfinite(prox)) + assert cost.batch_size == 2 + + +def test_logistic_regression_validates_labels_and_indices() -> None: + with pytest.raises(ValueError, match="exactly two classes"): + LogisticRegressionCost( + dataset=[ + (np.array([1.0]), np.array([0.0])), + (np.array([2.0]), np.array([1.0])), + (np.array([3.0]), np.array([2.0])), + ] + ) + + cost = LogisticRegressionCost( + dataset=[ + (np.array([1.0, 0.0]), np.array([0.0])), + (np.array([0.0, 1.0]), np.array([1.0])), + ], + batch_size="all", + ) + with pytest.raises(ValueError, match="Invalid indices string"): + cost.gradient(np.zeros(2), indices="invalid") + + +def test_quadratic_cost_matches_direct_formula_and_symmetrized_derivatives() -> None: + A = np.array([[2.0, 1.0], [3.0, 4.0]]) + b = np.array([1.0, -2.0]) + x = np.array([1.0, -1.0]) + cost = QuadraticCost(A=A, b=b, c=3.0) + A_sym = 0.5 * (A + A.T) + + assert cost.function(x) == pytest.approx(7.0) + np.testing.assert_allclose(cost.gradient(x), A_sym @ x + b) + np.testing.assert_allclose(cost.hessian(x), A_sym) + np.testing.assert_allclose(cost.proximal(x, rho=0.5), np.array([0.3, -0.1])) + eigvals = np.linalg.eigvalsh(A_sym) + assert cost.m_smooth == pytest.approx(float(np.max(np.abs(eigvals)))) + assert cost.m_cvx == pytest.approx(float(np.min(eigvals))) + + +@pytest.mark.parametrize( + ("A", "b", "match"), + [ + (np.array([1.0, 2.0]), np.array([1.0, 2.0]), "Matrix A must be 2D"), + (np.array([[1.0, 2.0, 3.0]]), np.array([1.0]), "Matrix A must be square"), + (np.eye(2), np.array([[1.0], [2.0]]), "Vector b must be 1D"), + (np.eye(2), np.array([1.0, 2.0, 3.0]), "Dimension mismatch"), + ], +) +def test_quadratic_cost_validates_constructor_inputs(A: np.ndarray, b: np.ndarray, match: str) -> None: + with pytest.raises(ValueError, match=match): + QuadraticCost(A=A, b=b) + + +def test_zero_cost_returns_zero_values_and_preserves_framework_metadata() -> None: + cost = ZeroCost(shape=(2,), framework=SupportedFrameworks.NUMPY, device=SupportedDevices.CPU) + x = np.array([1.5, -2.5]) + + assert cost.framework == SupportedFrameworks.NUMPY + assert cost.device == SupportedDevices.CPU + assert cost.function(x) == 0.0 + np.testing.assert_allclose(iop.to_numpy(cost.gradient(x)), np.zeros(2)) + np.testing.assert_allclose(iop.to_numpy(cost.hessian(x)), np.zeros((2, 2))) + np.testing.assert_allclose(iop.to_numpy(cost.proximal(x, rho=1.0)), x) + assert cost.m_smooth == 0.0 + assert cost.m_cvx == 0.0 + + +def test_zero_cost_validates_shape_and_penalty() -> None: + with pytest.raises(ValueError, match="non-negative integers"): + ZeroCost(shape=(2, -1)) + + cost = ZeroCost(shape=(2,)) + with pytest.raises(ValueError, match="Mismatching domain shapes"): + cost.gradient(np.array([1.0, 2.0, 3.0])) + + with pytest.raises(ValueError, match="penalty parameter rho must be positive"): + cost.proximal(np.array([1.0, 2.0]), rho=0.0) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") +def _make_pytorch_cost() -> PyTorchCost: + dataset = [ + (torch.tensor([1.0, 0.0], dtype=torch.float32), torch.tensor([1.0], dtype=torch.float32)), + (torch.tensor([0.0, 1.0], dtype=torch.float32), torch.tensor([-1.0], dtype=torch.float32)), + ] + model = torch.nn.Linear(2, 1, bias=False) + return PyTorchCost( + dataset=dataset, + model=model, + loss_fn=torch.nn.MSELoss(), + batch_size="all", + device=SupportedDevices.CPU, + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") +def test_pytorch_cost_function_and_gradient_match_direct_torch_computation() -> None: + cost = _make_pytorch_cost() + x = torch.tensor([0.5, -1.0], dtype=torch.float32) + + expected_model = torch.nn.Linear(2, 1, bias=False) + with torch.no_grad(): + expected_model.weight.copy_(x.reshape(1, -1)) + inputs = torch.stack([sample_x for sample_x, _ in cost.dataset]) + targets = torch.stack([sample_y for _, sample_y in cost.dataset]) + loss = torch.nn.MSELoss()(expected_model(inputs), targets) + loss.backward() + expected_gradient = expected_model.weight.grad.flatten() + + assert cost.framework == SupportedFrameworks.PYTORCH + assert cost.device == SupportedDevices.CPU + assert cost.function(x, indices="all") == pytest.approx(float(loss.item())) + gradient = cost.gradient(x, indices="all") + assert isinstance(gradient, torch.Tensor) + assert gradient.device.type == "cpu" + torch.testing.assert_close(gradient, expected_gradient) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") +def test_pytorch_cost_per_sample_gradient_and_error_paths() -> None: + cost = _make_pytorch_cost() + x = torch.tensor([0.5, -1.0], dtype=torch.float32) + + gradient = cost.gradient(x, indices="all") + per_sample_gradient = cost.gradient(x, indices="all", reduction=None) + assert per_sample_gradient.shape == (2, 2) + torch.testing.assert_close(per_sample_gradient.mean(dim=0), gradient) + + with pytest.raises(ValueError, match="does not match total model parameters"): + cost.function(torch.tensor([1.0], dtype=torch.float32), indices="all") + + with pytest.raises(ValueError, match="Invalid indices string"): + cost.gradient(x, indices="invalid") + + with pytest.raises(NotImplementedError, match="Hessian computation is not implemented"): + cost.hessian(x) + + with pytest.raises(NotImplementedError, match="Proximal operator is not implemented"): + cost.proximal(x, rho=1.0) diff --git a/test/test_federated_aggregation.py b/test/test_federated_aggregation.py new file mode 100644 index 0000000..8da8974 --- /dev/null +++ b/test/test_federated_aggregation.py @@ -0,0 +1,80 @@ +from typing import Any + +import numpy as np + +from decent_bench.agents import Agent +from decent_bench.costs import Cost +from decent_bench.distributed_algorithms import FedAvg +from decent_bench.networks import FedNetwork +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks + + +class TrackingCost(Cost): + def __init__(self, gradient_value: float = 1.0): + self._gradient = np.array([gradient_value], dtype=float) + + @property + def shape(self) -> tuple[int, ...]: + return (1,) + + @property + def framework(self) -> SupportedFrameworks: + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: + return SupportedDevices.CPU + + @property + def m_smooth(self) -> float: + return 0.0 + + @property + def m_cvx(self) -> float: + return 0.0 + + def function(self, x: np.ndarray, **kwargs: Any) -> float: + del x, kwargs + return 0.0 + + def gradient(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x, kwargs + return self._gradient.copy() + + def hessian(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x, kwargs + return np.zeros((1, 1), dtype=float) + + def proximal(self, x: np.ndarray, rho: float, **kwargs: Any) -> np.ndarray: + del rho, kwargs + return x + + +def _make_fed_network(*costs: Cost) -> tuple[FedNetwork, list[Agent]]: + clients = [Agent(i, cost) for i, cost in enumerate(costs)] + network = FedNetwork(clients=clients) + for client in clients: + client.initialize(x=np.zeros(client.cost.shape, dtype=float)) + network.server().initialize(x=np.zeros(clients[0].cost.shape, dtype=float)) + return network, clients + + +def test_aggregation_uses_only_received_client_updates() -> None: + algorithm = FedAvg(iterations=1, step_size=1.0) + network, clients = _make_fed_network(TrackingCost(1.0), TrackingCost(2.0)) + + network.send(sender=clients[0], receiver=network.server(), msg=np.array([3.0])) + + algorithm.aggregate(network, clients, client_weights={0: 1.0, 1: 10.0}) + + np.testing.assert_allclose(network.server().x, np.array([3.0])) + + +def test_aggregation_keeps_server_state_when_no_updates_are_received() -> None: + algorithm = FedAvg(iterations=1, step_size=1.0) + network, clients = _make_fed_network(TrackingCost(1.0), TrackingCost(2.0)) + network.server().x = np.array([7.0]) + + algorithm.aggregate(network, clients) + + np.testing.assert_allclose(network.server().x, np.array([7.0])) diff --git a/test/test_federated_cost_routing.py b/test/test_federated_cost_routing.py new file mode 100644 index 0000000..2107b9e --- /dev/null +++ b/test/test_federated_cost_routing.py @@ -0,0 +1,267 @@ +from typing import Any + +import numpy as np + +from decent_bench.agents import Agent +from decent_bench.costs import BaseRegularizerCost, Cost, EmpiricalRiskCost, ZeroCost +from decent_bench.distributed_algorithms import FedAvg +from decent_bench.utils.types import ( + Dataset, + EmpiricalRiskIndices, + EmpiricalRiskReduction, + SupportedDevices, + SupportedFrameworks, +) + + +class TrackingCost(Cost): + def __init__(self, gradient_value: float = 1.0): + self.gradient_kwargs: list[dict[str, Any]] = [] + self._gradient = np.array([gradient_value], dtype=float) + + @property + def shape(self) -> tuple[int, ...]: + return (1,) + + @property + def framework(self) -> SupportedFrameworks: + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: + return SupportedDevices.CPU + + @property + def m_smooth(self) -> float: + return 0.0 + + @property + def m_cvx(self) -> float: + return 0.0 + + def function(self, x: np.ndarray, **kwargs: Any) -> float: + del x, kwargs + return 0.0 + + def gradient(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x + self.gradient_kwargs.append(dict(kwargs)) + return self._gradient.copy() + + def hessian(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x, kwargs + return np.zeros((1, 1), dtype=float) + + def proximal(self, x: np.ndarray, rho: float, **kwargs: Any) -> np.ndarray: + del rho, kwargs + return x + + +class TrackingRegularizerCost(BaseRegularizerCost): + def __init__(self, gradient_value: float = 0.0): + super().__init__(shape=(1,)) + self.gradient_kwargs: list[dict[str, Any]] = [] + self._gradient = np.array([gradient_value], dtype=float) + + @property + def m_smooth(self) -> float: + return 0.0 + + @property + def m_cvx(self) -> float: + return 0.0 + + def function(self, x: np.ndarray, **kwargs: Any) -> float: + del x, kwargs + return 0.0 + + def gradient(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x + self.gradient_kwargs.append(dict(kwargs)) + return self._gradient.copy() + + def hessian(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + del x, kwargs + return np.zeros((1, 1), dtype=float) + + def proximal(self, x: np.ndarray, rho: float, **kwargs: Any) -> np.ndarray: + del rho, kwargs + return x + + +class TrackingZeroCost(ZeroCost): + def __init__(self): + super().__init__(shape=(1,)) + self.gradient_kwargs: list[dict[str, Any]] = [] + + def gradient(self, x: np.ndarray, **kwargs: Any) -> np.ndarray: + self.gradient_kwargs.append(dict(kwargs)) + return super().gradient(x, **kwargs) + + +class TrackingEmpiricalCost(EmpiricalRiskCost): + def __init__(self, n_samples: int = 5, batch_size: int = 2, gradient_value: float = 1.0): + self._dataset: Dataset = [(np.array([float(i)]), np.array([0.0])) for i in range(n_samples)] + self._batch_size = batch_size + self.gradient_indices: list[list[int]] = [] + self._gradient = np.array([gradient_value], dtype=float) + + @property + def shape(self) -> tuple[int, ...]: + return (1,) + + @property + def framework(self) -> SupportedFrameworks: + return SupportedFrameworks.NUMPY + + @property + def device(self) -> SupportedDevices: + return SupportedDevices.CPU + + @property + def m_smooth(self) -> float: + return 0.0 + + @property + def m_cvx(self) -> float: + return 0.0 + + @property + def n_samples(self) -> int: + return len(self._dataset) + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def dataset(self) -> Dataset: + return self._dataset + + def predict(self, x: np.ndarray, data: list[np.ndarray]) -> np.ndarray: + del x + return np.asarray(data) + + def function(self, x: np.ndarray, indices: EmpiricalRiskIndices = "batch", **kwargs: Any) -> float: + del x, kwargs + self._sample_batch_indices(indices) + return 0.0 + + def gradient( + self, + x: np.ndarray, + indices: EmpiricalRiskIndices = "batch", + reduction: EmpiricalRiskReduction = "mean", + **kwargs: Any, + ) -> np.ndarray: + del x, kwargs + sampled_indices = self._sample_batch_indices(indices) + self.gradient_indices.append(list(sampled_indices)) + if reduction is None: + return np.repeat(self._gradient[np.newaxis, :], len(sampled_indices), axis=0) + return self._gradient.copy() + + def hessian(self, x: np.ndarray, indices: EmpiricalRiskIndices = "batch", **kwargs: Any) -> np.ndarray: + del x, kwargs + self._sample_batch_indices(indices) + return np.zeros((1, 1), dtype=float) + + def _get_batch_data(self, indices: EmpiricalRiskIndices = "batch") -> list[tuple[np.ndarray, np.ndarray]]: + sampled_indices = self._sample_batch_indices(indices) + return [self._dataset[i] for i in sampled_indices] + + +def _run_fedavg_local_update(cost: Cost, *, step_size: float = 1.0, num_local_epochs: int = 1) -> np.ndarray: + algorithm = FedAvg(iterations=1, step_size=step_size, num_local_epochs=num_local_epochs) + client = Agent(0, cost) + server = Agent(1, ZeroCost(cost.shape)) + client.initialize(x=np.zeros(cost.shape, dtype=float)) + server.initialize(x=np.zeros(cost.shape, dtype=float)) + return algorithm._compute_local_update(client, server) + + +def test_empirical_costs_use_minibatch_local_updates() -> None: + cost = TrackingEmpiricalCost(n_samples=5, batch_size=2) + + updated = _run_fedavg_local_update(cost) + + np.testing.assert_allclose(updated, np.array([-3.0])) + assert sorted(len(indices) for indices in cost.gradient_indices) == [1, 2, 2] + + +def test_empirical_regularized_costs_keep_minibatch_local_updates() -> None: + empirical_cost = TrackingEmpiricalCost(n_samples=5, batch_size=2) + regularizer = TrackingRegularizerCost() + objective = empirical_cost + regularizer + + updated = _run_fedavg_local_update(objective) + + np.testing.assert_allclose(updated, np.array([-3.0])) + assert sorted(len(indices) for indices in empirical_cost.gradient_indices) == [1, 2, 2] + assert len(regularizer.gradient_kwargs) == 3 + assert all(kwargs == {} for kwargs in regularizer.gradient_kwargs) + + +def test_scaled_empirical_costs_keep_minibatch_local_updates() -> None: + empirical_cost = TrackingEmpiricalCost(n_samples=5, batch_size=2) + objective = 2.0 * empirical_cost + + updated = _run_fedavg_local_update(objective) + + np.testing.assert_allclose(updated, np.array([-6.0])) + assert sorted(len(indices) for indices in empirical_cost.gradient_indices) == [1, 2, 2] + + +def test_plain_costs_use_full_gradient_local_updates() -> None: + cost = TrackingCost(gradient_value=1.0) + + updated = _run_fedavg_local_update(cost, num_local_epochs=3) + + np.testing.assert_allclose(updated, np.array([-3.0])) + assert len(cost.gradient_kwargs) == 3 + assert all(kwargs == {} for kwargs in cost.gradient_kwargs) + + +def test_sum_costs_over_non_empirical_terms_use_full_gradient_local_updates() -> None: + cost_a = TrackingCost(gradient_value=1.0) + cost_b = TrackingCost(gradient_value=2.0) + objective = cost_a + cost_b + + updated = _run_fedavg_local_update(objective, num_local_epochs=2) + + np.testing.assert_allclose(updated, np.array([-6.0])) + assert len(cost_a.gradient_kwargs) == 2 + assert len(cost_b.gradient_kwargs) == 2 + assert all(kwargs == {} for kwargs in cost_a.gradient_kwargs) + assert all(kwargs == {} for kwargs in cost_b.gradient_kwargs) + + +def test_scaled_costs_over_non_empirical_terms_use_full_gradient_local_updates() -> None: + cost = TrackingCost(gradient_value=1.0) + objective = 2.0 * cost + + updated = _run_fedavg_local_update(objective, num_local_epochs=2) + + np.testing.assert_allclose(updated, np.array([-4.0])) + assert len(cost.gradient_kwargs) == 2 + assert all(kwargs == {} for kwargs in cost.gradient_kwargs) + + +def test_regularizers_follow_the_non_batched_local_update_path() -> None: + regularizer = TrackingRegularizerCost(gradient_value=1.0) + + updated = _run_fedavg_local_update(regularizer, num_local_epochs=2) + + np.testing.assert_allclose(updated, np.array([-2.0])) + assert len(regularizer.gradient_kwargs) == 2 + assert all(kwargs == {} for kwargs in regularizer.gradient_kwargs) + + +def test_zero_costs_do_not_need_special_local_update_handling() -> None: + cost = TrackingZeroCost() + + updated = _run_fedavg_local_update(cost, num_local_epochs=3) + + np.testing.assert_allclose(updated, np.array([0.0])) + assert len(cost.gradient_kwargs) == 3 + assert all(kwargs == {} for kwargs in cost.gradient_kwargs) diff --git a/test/test_networks.py b/test/test_networks.py index bee8278..6e6230f 100644 --- a/test/test_networks.py +++ b/test/test_networks.py @@ -1,14 +1,55 @@ import pytest import networkx as nx +import numpy as np from decent_bench.agents import Agent from decent_bench.networks import P2PNetwork, FedNetwork from decent_bench.costs import L2RegularizerCost from decent_bench.utils import interoperability as iop -from decent_bench.schemes import AlwaysActive, CompressionScheme, NoiseScheme, NoDrops, NoNoise, UniformActivationRate +from decent_bench.utils.types import SupportedDevices, SupportedFrameworks +from decent_bench.schemes import ( + AgentActivationScheme, + AlwaysActive, + CompressionScheme, + DropScheme, + NoiseScheme, + NoCompression, + NoDrops, + NoNoise, + UniformActivationRate, +) from unittest.mock import MagicMock +class NeverActive(AgentActivationScheme): + def is_active(self, iteration: int) -> bool: # noqa: D102, ARG002 + return False + + +class MultiplyCompression(CompressionScheme): + def __init__(self, factor: float): + self.factor = factor + + def compress(self, msg): # noqa: ANN001, D102 + return msg * self.factor + + +class AddNoise(NoiseScheme): + def __init__(self, offset: float): + self.offset = offset + + def make_noise(self, msg): # noqa: ANN001, D102 + return msg + self.offset + + +class FixedDrop(DropScheme): + def __init__(self, should_drop_message: bool): + self.should_drop_message = should_drop_message + + def should_drop(self) -> bool: # noqa: D102 + return self.should_drop_message + + def test_p2p_network(n_agents: int = 10) -> None: net = P2PNetwork( graph=nx.complete_graph(n_agents), @@ -83,6 +124,54 @@ def test_fed_network_default_server_is_always_active() -> None: assert isinstance(net.server()._activation, AlwaysActive) # noqa: SLF001 +def test_p2p_network_rejects_mixed_framework_costs() -> None: + agents = [ + Agent(0, L2RegularizerCost((2,), framework=SupportedFrameworks.NUMPY)), + Agent(1, L2RegularizerCost((2,), framework=SupportedFrameworks.PYTORCH)), + ] + + with pytest.raises(ValueError, match="same shape, framework, and device"): + P2PNetwork(graph=nx.complete_graph(2), agents=agents) + + +def test_p2p_network_rejects_mixed_device_costs() -> None: + agents = [ + Agent(0, L2RegularizerCost((2,), device=SupportedDevices.CPU)), + Agent(1, L2RegularizerCost((2,), device=SupportedDevices.GPU)), + ] + + with pytest.raises(ValueError, match="same shape, framework, and device"): + P2PNetwork(graph=nx.complete_graph(2), agents=agents) + + +def test_p2p_network_rejects_mismatched_cost_shapes() -> None: + agents = [ + Agent(0, L2RegularizerCost((2,))), + Agent(1, L2RegularizerCost((3,))), + ] + + with pytest.raises(ValueError, match="same shape, framework, and device"): + P2PNetwork(graph=nx.complete_graph(2), agents=agents) + + +def test_fed_network_rejects_mixed_framework_clients() -> None: + clients = [ + Agent(0, L2RegularizerCost((2,), framework=SupportedFrameworks.NUMPY)), + Agent(1, L2RegularizerCost((2,), framework=SupportedFrameworks.PYTORCH)), + ] + + with pytest.raises(ValueError, match="same shape, framework, and device"): + FedNetwork(clients=clients) + + +def test_fed_network_rejects_custom_server_with_mixed_framework() -> None: + clients = [Agent(i, L2RegularizerCost((2,), framework=SupportedFrameworks.NUMPY)) for i in range(2)] + server = Agent(99, L2RegularizerCost((2,), framework=SupportedFrameworks.PYTORCH), activation=AlwaysActive()) + + with pytest.raises(ValueError, match="same shape, framework, and device"): + FedNetwork(clients=clients, server=server) + + def test_initialize_message_schemes_with_dict_all_agents() -> None: """Test that per-agent scheme dicts work when all agents are provided.""" n_agents = 3 @@ -154,3 +243,111 @@ def test_initialize_message_schemes_dict_used_in_send() -> None: # verify agent 0's compression scheme was called mock_schemes[agents[0]].compress.assert_called_once() + + +def test_p2p_network_rejects_disconnected_graph() -> None: + agents = [Agent(i, L2RegularizerCost((2,))) for i in range(3)] + graph = nx.Graph() + graph.add_edges_from([(0, 1)]) + graph.add_node(2) + + with pytest.raises(ValueError, match="graph needs to be connected"): + P2PNetwork(graph=graph, agents=agents) + + +def test_p2p_network_rejects_directed_graph() -> None: + agents = [Agent(i, L2RegularizerCost((2,))) for i in range(2)] + graph = nx.DiGraph() + graph.add_edge(0, 1) + + with pytest.raises(ValueError, match="Directed graphs are not supported"): + P2PNetwork(graph=graph, agents=agents) + + +def test_p2p_network_rejects_multigraph() -> None: + agents = [Agent(i, L2RegularizerCost((2,))) for i in range(2)] + graph = nx.MultiGraph() + graph.add_edge(0, 1) + + with pytest.raises(NotImplementedError, match="multi-graphs"): + P2PNetwork(graph=graph, agents=agents) + + +def test_p2p_network_rejects_duplicate_agent_ids() -> None: + agent_a = Agent(0, L2RegularizerCost((2,))) + agent_b = Agent(0, L2RegularizerCost((2,))) + graph = nx.Graph() + graph.add_edge(agent_a, agent_b) + + with pytest.raises(ValueError, match="Agent IDs must be unique"): + P2PNetwork(graph=graph) + + +def test_send_rejects_inactive_receiver() -> None: + sender = Agent(0, L2RegularizerCost((2,))) + inactive_receiver = Agent(1, L2RegularizerCost((2,)), activation=NeverActive()) + net = P2PNetwork(graph=nx.Graph([(sender, inactive_receiver)])) + msg = iop.zeros(shape=(2,), framework=sender.cost.framework, device=sender.cost.device) + + with pytest.raises(ValueError, match="not active or not connected"): + net.send(sender=sender, receiver=inactive_receiver, msg=msg) + + +@pytest.mark.parametrize( + ("buffer_messages", "expect_message_after_step"), + [ + (False, False), + (True, True), + ], +) +def test_step_clears_or_preserves_messages_based_on_buffer_setting( + buffer_messages: bool, expect_message_after_step: bool +) -> None: + sender = Agent(0, L2RegularizerCost((2,))) + receiver = Agent(1, L2RegularizerCost((2,))) + net = P2PNetwork( + graph=nx.Graph([(sender, receiver)]), + buffer_messages=buffer_messages, + ) + msg = iop.to_array([1.0, -1.0], framework=sender.cost.framework, device=sender.cost.device) + + net.send(sender=sender, receiver=receiver, msg=msg) + assert sender in receiver.messages + + net._step(1) # noqa: SLF001 + + assert (sender in receiver.messages) is expect_message_after_step + + +def test_send_applies_drop_compression_and_noise_schemes() -> None: + sender = Agent(0, L2RegularizerCost((2,))) + dropped_sender = Agent(1, L2RegularizerCost((2,))) + receiver = Agent(2, L2RegularizerCost((2,))) + net = P2PNetwork( + graph=nx.complete_graph([sender, dropped_sender, receiver]), + message_compression={ + sender: MultiplyCompression(2.0), + dropped_sender: NoCompression(), + receiver: NoCompression(), + }, + message_noise={ + sender: AddNoise(3.0), + dropped_sender: NoNoise(), + receiver: NoNoise(), + }, + message_drop={ + sender: FixedDrop(False), + dropped_sender: FixedDrop(True), + receiver: NoDrops(), + }, + ) + msg = iop.to_array([1.0, 2.0], framework=sender.cost.framework, device=sender.cost.device) + + net.send(sender=sender, receiver=receiver, msg=msg) + np_received = iop.to_numpy(receiver.messages[sender]) + np_expected = np.array([5.0, 7.0]) + assert np_received.shape == np_expected.shape + assert np_received.tolist() == pytest.approx(np_expected.tolist()) + + net.send(sender=dropped_sender, receiver=receiver, msg=msg) + assert dropped_sender not in receiver.messages