Skip to content
Open
24 changes: 16 additions & 8 deletions decent_bench/benchmark/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def model_gen() -> torch.nn.Module:
)

# Mypy cannot infer that cost_cls is PyTorchCost here
costs = [
cost_cls( # type: ignore[call-arg]
pytorch_costs: list[PyTorchCost] = [
PyTorchCost(
dataset=p,
model=model_gen(),
loss_fn=torch.nn.CrossEntropyLoss(),
Expand All @@ -86,11 +86,15 @@ def model_gen() -> torch.nn.Module:
)
for p in dataset.get_partitions()
]
costs: Sequence[Cost] = pytorch_costs
x_optimal = None
elif cost_cls is LogisticRegressionCost:
costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg]
sum_cost = reduce(add, costs)
classification_costs: list[LogisticRegressionCost] = [
LogisticRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()
]
sum_cost = reduce(add, classification_costs)
x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16)
costs = classification_costs
else:
raise ValueError(f"Unsupported cost class: {cost_cls}")

Expand Down Expand Up @@ -158,15 +162,19 @@ def model_gen() -> torch.nn.Module:
output_size=1,
)

costs = [
cost_cls(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device) # type: ignore[call-arg]
pytorch_costs: list[PyTorchCost] = [
PyTorchCost(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device)
for p in dataset.get_partitions()
]
costs: Sequence[Cost] = pytorch_costs
x_optimal = None
elif cost_cls is LinearRegressionCost:
costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg]
sum_cost = reduce(add, costs)
regression_costs: list[LinearRegressionCost] = [
LinearRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()
]
sum_cost = reduce(add, regression_costs)
x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16)
costs = regression_costs
else:
raise ValueError(f"Unsupported cost class: {cost_cls}")

Expand Down
4 changes: 2 additions & 2 deletions decent_bench/costs/_base/_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def _validate_cost_operation(
self,
other: object,
*,
check_framework: bool = False,
check_device: bool = False,
check_framework: bool = True,
check_device: bool = True,
) -> None:
"""
Validate that another object can participate in a binary cost operation.
Expand Down
9 changes: 7 additions & 2 deletions decent_bench/costs/_base/_sum_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ class SumCost(Cost):
"""

def __init__(self, costs: list[Cost]):
    """
    Build a sum of cost functions, flattening any nested ``SumCost`` members.

    Args:
        costs: Cost functions to sum; must be non-empty and mutually compatible.

    Raises:
        ValueError: If ``costs`` is empty, or if the member costs are not
            compatible with each other (as judged by ``_validate_cost_operation``;
            presumably mismatching framework/device — confirm whether shape is
            also covered there).

    """
    # Guard before any member access below (self.costs[0] would raise IndexError).
    if len(costs) == 0:
        raise ValueError("SumCost must contain at least one cost function.")

    self.costs: list[Cost] = []
    for cf in costs:
        if isinstance(cf, SumCost):
            # Flatten nested sums so self.costs holds only leaf costs.
            self.costs.extend(cf.costs)
        else:
            self.costs.append(cf)

    # Pairwise-validate every flattened member against the first one.
    first = self.costs[0]
    for cf in self.costs[1:]:
        first._validate_cost_operation(cf)  # noqa: SLF001

@property
def shape(self) -> tuple[int, ...]:
    """Domain shape of the summed cost, taken from the first member cost (all members share it)."""
    return self.costs[0].shape
Expand Down
8 changes: 5 additions & 3 deletions decent_bench/datasets/_pytorch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import random
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, cast

import decent_bench.utils.interoperability as iop
Expand Down Expand Up @@ -161,9 +162,10 @@ def _heterogeneous_split(self) -> list[Dataset]:
"""
# Group indices by class in a single pass
class_to_indices: dict[int, list[int]] = defaultdict(list)
for idx, (_, label) in enumerate(self.torch_dataset): # type: ignore[misc, arg-type]
if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition): # type: ignore[has-type]
class_to_indices[label].append(idx) # type: ignore[has-type]
for idx, sample in enumerate(cast("Iterable[Any]", self.torch_dataset)):
_, label = cast("tuple[Any, int]", sample)
if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition):
class_to_indices[label].append(idx)

# Create partitions from class-grouped indices
idx_partitions = []
Expand Down
52 changes: 33 additions & 19 deletions decent_bench/distributed_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import decent_bench.utils.algorithm_helpers as alg_helpers
import decent_bench.utils.interoperability as iop
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.costs import Cost, EmpiricalRiskCost
from decent_bench.networks import FedNetwork, Network, P2PNetwork
from decent_bench.schemes import ClientSelectionScheme, UniformClientSelection
from decent_bench.utils._tags import tags
Expand Down Expand Up @@ -179,6 +179,10 @@ class FedAlgorithm(Algorithm[FedNetwork]):
def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]:
return [network.server(), *network.clients()]

def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to make this a method in FedAlgorithm, it should probably be public so that users implementing new algorithms can benefit from this utility. If we want to keep it private, then it should probably live in the subclass.

If made public, it could be renamed to something shorter, such as server_broadcast.

"""Send the current server model to the selected clients."""
network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x)

def select_clients(
self,
clients: Sequence["Agent"],
Expand Down Expand Up @@ -281,10 +285,10 @@ class FedAvg(FedAlgorithm):
round :math:`k`. In FedAvg, each selected client performs ``num_local_epochs`` local SGD epochs, then the server
aggregates the final local models to form :math:`\mathbf{x}_{k+1}`. The aggregation uses client weights, defaulting
to data-size weights when ``client_weights`` is not provided. Client selection (subsampling) defaults to uniform
sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. For
:class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size
:attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs, local
updates use full-batch gradients.
sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. Costs that
preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use client-side mini-batches of size
:attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; generic cost wrappers
fall back to full-gradient local updates.
"""

# C=0.1; batch size= inf/10/50 (dataset sizes are bigger; normally 1/10 of the total dataset).
Expand Down Expand Up @@ -327,38 +331,48 @@ def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102
self._run_local_updates(network, selected_clients)
self.aggregate(network, selected_clients)

def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x)

def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
for client in selected_clients:
client.x = self._compute_local_update(client, network.server())
network.send(sender=client, receiver=network.server(), msg=client.x)

@staticmethod
def _empirical_cost_for_local_updates(cost: Cost) -> EmpiricalRiskCost | None:
    """
    Return ``cost`` as an ``EmpiricalRiskCost`` when it supports mini-batching, else ``None``.

    Batching semantics live on the ``EmpiricalRiskCost`` abstraction itself: wrappers that
    preserve empirical-risk metadata (e.g. empirical regularization/scaling and
    ``PyTorchCost``) inherit from it and are batch-capable, whereas generic wrappers such
    as ``SumCost`` and ``ScaledCost`` intentionally erase it and must take the
    full-gradient path.
    """
    return cost if isinstance(cost, EmpiricalRiskCost) else None

def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
    """
    Compute one client's local model update.

    Starts from the latest server broadcast when available (falling back to the
    client's own model), then runs mini-batch local epochs for batch-capable
    empirical costs and full-gradient local epochs otherwise.
    """
    local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
    empirical_cost = self._empirical_cost_for_local_updates(client.cost)
    if empirical_cost is not None:
        return self._run_minibatch_local_epochs(empirical_cost, local_x)
    return self._run_full_gradient_local_epochs(client.cost, local_x)

def _run_full_gradient_local_epochs(self, cost: Cost, local_x: "Array") -> "Array":
for _ in range(self.num_local_epochs):
grad = client.cost.gradient(local_x)
grad = cost.gradient(local_x)
local_x -= self.step_size * grad
return local_x

def _epoch_minibatch_update(
def _run_minibatch_local_epochs(
self,
cost: EmpiricalRiskCost,
local_x: "Array",
per_client_batch: int,
n_samples: int,
) -> "Array":
for _ in range(self.num_local_epochs):
indices = list(range(n_samples))
indices = list(range(cost.n_samples))
random.shuffle(indices)
for start in range(0, n_samples, per_client_batch):
batch_indices = indices[start : start + per_client_batch]
for start in range(0, cost.n_samples, cost.batch_size):
batch_indices = indices[start : start + cost.batch_size]
grad = cost.gradient(local_x, indices=batch_indices)
local_x -= self.step_size * grad
return local_x
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this ensures that a full epoch through the dataset is performed, I've updated empirical risk cost sampling so that it iterates through the entire dataset (in random order) before it re-uses data points. Therefore you could simplify this a bit by removing the indices parameter, but it's not a big deal.

One side effect: if you're using PyTorchCost with a dataloader, the indices parameter bypasses the dataloader and gathers data manually. In my experience, dataloaders are slower when running on the CPU, so they are not commonly used, but this might slow things down depending on the model and dataset size.

Expand Down
32 changes: 32 additions & 0 deletions decent_bench/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
agent_ids = [agent.id for agent in graph.nodes()]
if len(agent_ids) != len(set(agent_ids)):
raise ValueError("Agent IDs must be unique")
self._validate_agent_cost_compatibility(graph)

self._graph = graph
self._message_noise = self._initialize_message_schemes(message_noise, "noise", NoiseScheme, NoNoise)
Expand All @@ -84,6 +85,37 @@ def __init__(
self._buffer_messages = buffer_messages
self._iteration = 0 # Current iteration, updated by the algorithm

@staticmethod
def _validate_agent_cost_compatibility(graph: AgentGraph) -> None:
"""
Validate that all agents' costs share the same shape, framework, and device.

Raises:
ValueError: If agents in the graph have mismatching cost shape, framework, or device.

"""
agents = list(graph.nodes())
if len(agents) <= 1:
return

first_cost = agents[0].cost
first_signature = (first_cost.shape, first_cost.framework, first_cost.device)
mismatches: list[str] = []
for agent in agents[1:]:
signature = (agent.cost.shape, agent.cost.framework, agent.cost.device)
if signature != first_signature:
mismatches.append(
f"agent {agent.id}: shape={agent.cost.shape}, framework={agent.cost.framework}, "
f"device={agent.cost.device}"
)

if mismatches:
raise ValueError(
"All agents in a network must have costs with the same shape, framework, and device. "
f"Expected shape={first_cost.shape}, framework={first_cost.framework}, "
f"device={first_cost.device}; mismatches: {'; '.join(mismatches)}"
)

def _initialize_message_schemes(
self,
scheme: object,
Expand Down
2 changes: 1 addition & 1 deletion decent_bench/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, n_significant_digits: int):
self.n_significant_digits = n_significant_digits

def compress(self, msg: Array) -> Array:  # noqa: D102
    # Round each element to n_significant_digits significant digits by
    # formatting it in scientific notation and parsing it back to float.
    # NOTE(review): element-wise string round-tripping is very slow; a purely
    # numeric implementation is planned to replace this.
    res = np.vectorize(lambda x: float(format(x, f".{self.n_significant_digits - 1}e")))(iop.to_numpy(msg))
    return iop.to_array_like(res, msg)

Comment on lines 180 to 183
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there has to be a more efficient way of performing quantization than doing to_numpy -> float -> string -> float -> back to framework. If you feel like you have the time please check if there are any better ways of doing this, otherwise I'll just create an issue of this at some point no problem.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: You dont have to worry about this. This is insanely inefficient, I have made an update to this and will include it in my bigger update within 1-2 weeks. Some simple math made this at least 10x more efficient


Expand Down
24 changes: 24 additions & 0 deletions docs/source/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,30 @@ Classification
:module: decent_bench.costs


PyTorchCost regularization
~~~~~~~~~~~~~~~~~~~~~~~~~~
When combining :class:`~decent_bench.costs.PyTorchCost` with one of the
built-in regularizers, instantiate the regularizer with the same framework
and device as the empirical cost:

.. code-block:: python

from decent_bench.costs import L2RegularizerCost
from decent_bench.utils.types import SupportedFrameworks

reg = L2RegularizerCost(
shape=cost.shape,
framework=SupportedFrameworks.PYTORCH,
device=cost.device,
)
objective = cost + reg

This preserves compatibility with the PyTorch empirical objective and keeps
the resulting objective in the empirical, batch-compatible abstraction.
It is convenient for composition, but it is not necessarily the most
efficient option compared with native framework-specific regularization.


Execution settings
------------------
Configure settings for metrics, trials, statistical confidence level, logging, and multiprocessing.
Expand Down
51 changes: 46 additions & 5 deletions test/test_cost_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import decent_bench.utils.interoperability as iop
from decent_bench.costs import Cost, L1RegularizerCost, L2RegularizerCost, QuadraticCost, SumCost
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks


def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> QuadraticCost:
Expand All @@ -12,20 +13,28 @@ def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> Quadrat


class _SimpleCost(Cost):
def __init__(self, scale: float):
def __init__(
self,
scale: float,
*,
framework: SupportedFrameworks = SupportedFrameworks.NUMPY,
device: SupportedDevices = SupportedDevices.CPU,
):
self.scale = scale
self._framework = framework
self._device = device

@property
def shape(self) -> tuple[int, ...]:
return (2,)

@property
def framework(self) -> str:
return "numpy"
def framework(self) -> SupportedFrameworks:
return self._framework

@property
def device(self) -> str | None:
return "cpu"
def device(self) -> SupportedDevices:
return self._device

@property
def m_smooth(self) -> float:
Expand Down Expand Up @@ -182,3 +191,35 @@ def test_cost_scalar_ops_reject_invalid_inputs() -> None:
_ = cost / 0.0
with pytest.raises(TypeError):
_ = 0.0 / cost


def test_cost_addition_rejects_mismatched_frameworks() -> None:
    """Adding two costs built on different frameworks must raise a ValueError."""
    numpy_cost = _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY)
    torch_cost = _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH)
    with pytest.raises(ValueError, match="Mismatching frameworks"):
        _ = numpy_cost + torch_cost


def test_cost_addition_rejects_mismatched_devices() -> None:
    """Adding two costs that live on different devices must raise a ValueError."""
    cpu_cost = _SimpleCost(scale=1.0, device=SupportedDevices.CPU)
    gpu_cost = _SimpleCost(scale=2.0, device=SupportedDevices.GPU)
    with pytest.raises(ValueError, match="Mismatching devices"):
        _ = cpu_cost + gpu_cost


def test_sum_cost_rejects_mismatched_frameworks() -> None:
    """Constructing a SumCost from different-framework members must raise a ValueError."""
    members = [
        _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY),
        _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH),
    ]
    with pytest.raises(ValueError, match="Mismatching frameworks"):
        SumCost(members)


def test_sum_cost_rejects_mismatched_devices() -> None:
    """Constructing a SumCost from different-device members must raise a ValueError."""
    members = [
        _SimpleCost(scale=1.0, device=SupportedDevices.CPU),
        _SimpleCost(scale=2.0, device=SupportedDevices.GPU),
    ]
    with pytest.raises(ValueError, match="Mismatching devices"):
        SumCost(members)
Loading
Loading