Skip to content

Commit 34442cf

Browse files
committed
Merge branch 'main' into feat/fedprox
2 parents 02554b0 + f327cba commit 34442cf

14 files changed

Lines changed: 996 additions & 57 deletions

decent_bench/benchmark/_utils.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,8 @@ def model_gen() -> torch.nn.Module:
7575
)
7676

7777
# Mypy cannot infer that cost_cls is PyTorchCost here
78-
costs = [
79-
cost_cls( # type: ignore[call-arg]
78+
pytorch_costs: list[PyTorchCost] = [
79+
PyTorchCost(
8080
dataset=p,
8181
model=model_gen(),
8282
loss_fn=torch.nn.CrossEntropyLoss(),
@@ -86,11 +86,15 @@ def model_gen() -> torch.nn.Module:
8686
)
8787
for p in dataset.get_partitions()
8888
]
89+
costs: Sequence[Cost] = pytorch_costs
8990
x_optimal = None
9091
elif cost_cls is LogisticRegressionCost:
91-
costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg]
92-
sum_cost = reduce(add, costs)
92+
classification_costs: list[LogisticRegressionCost] = [
93+
LogisticRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()
94+
]
95+
sum_cost = reduce(add, classification_costs)
9396
x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16)
97+
costs = classification_costs
9498
else:
9599
raise ValueError(f"Unsupported cost class: {cost_cls}")
96100

@@ -158,15 +162,19 @@ def model_gen() -> torch.nn.Module:
158162
output_size=1,
159163
)
160164

161-
costs = [
162-
cost_cls(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device) # type: ignore[call-arg]
165+
pytorch_costs: list[PyTorchCost] = [
166+
PyTorchCost(dataset=p, model=model_gen(), loss_fn=torch.nn.MSELoss(), batch_size=batch_size, device=device)
163167
for p in dataset.get_partitions()
164168
]
169+
costs: Sequence[Cost] = pytorch_costs
165170
x_optimal = None
166171
elif cost_cls is LinearRegressionCost:
167-
costs = [cost_cls(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()] # type: ignore[call-arg]
168-
sum_cost = reduce(add, costs)
172+
regression_costs: list[LinearRegressionCost] = [
173+
LinearRegressionCost(dataset=p, batch_size=batch_size) for p in dataset.get_partitions()
174+
]
175+
sum_cost = reduce(add, regression_costs)
169176
x_optimal = ca.accelerated_gradient_descent(sum_cost, x0=None, max_iter=50000, stop_tol=1e-100, max_tol=1e-16)
177+
costs = regression_costs
170178
else:
171179
raise ValueError(f"Unsupported cost class: {cost_cls}")
172180

decent_bench/costs/_base/_cost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def _validate_cost_operation(
1515
self,
1616
other: object,
1717
*,
18-
check_framework: bool = False,
19-
check_device: bool = False,
18+
check_framework: bool = True,
19+
check_device: bool = True,
2020
) -> None:
2121
"""
2222
Validate that another object can participate in a binary cost operation.

decent_bench/costs/_base/_sum_cost.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,20 @@ class SumCost(Cost):
2525
"""
2626

2727
def __init__(self, costs: list[Cost]):
28-
if not all(costs[0].shape == cf.shape for cf in costs):
29-
raise ValueError("All cost functions must have the same domain shape")
28+
if len(costs) == 0:
29+
raise ValueError("SumCost must contain at least one cost function.")
30+
3031
self.costs: list[Cost] = []
3132
for cf in costs:
3233
if isinstance(cf, SumCost):
3334
self.costs.extend(cf.costs)
3435
else:
3536
self.costs.append(cf)
3637

38+
first = self.costs[0]
39+
for cf in self.costs[1:]:
40+
first._validate_cost_operation(cf) # noqa: SLF001
41+
3742
@property
3843
def shape(self) -> tuple[int, ...]:
3944
return self.costs[0].shape

decent_bench/datasets/_pytorch_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import random
44
from collections import defaultdict
5+
from collections.abc import Iterable
56
from typing import TYPE_CHECKING, Any, cast
67

78
import decent_bench.utils.interoperability as iop
@@ -161,9 +162,10 @@ def _heterogeneous_split(self) -> list[Dataset]:
161162
"""
162163
# Group indices by class in a single pass
163164
class_to_indices: dict[int, list[int]] = defaultdict(list)
164-
for idx, (_, label) in enumerate(self.torch_dataset): # type: ignore[misc, arg-type]
165-
if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition): # type: ignore[has-type]
166-
class_to_indices[label].append(idx) # type: ignore[has-type]
165+
for idx, sample in enumerate(cast("Iterable[Any]", self.torch_dataset)):
166+
_, label = cast("tuple[Any, int]", sample)
167+
if label in class_to_indices or len(class_to_indices) < (self.n_partitions * self.targets_per_partition):
168+
class_to_indices[label].append(idx)
167169

168170
# Create partitions from class-grouped indices
169171
idx_partitions = []

decent_bench/distributed_algorithms.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import random
21
from abc import ABC, abstractmethod
32
from collections.abc import Callable, Iterable, Sequence
43
from dataclasses import dataclass, field
54
from typing import TYPE_CHECKING, Any, Final, cast, final
65

76
import decent_bench.utils.algorithm_helpers as alg_helpers
87
import decent_bench.utils.interoperability as iop
9-
from decent_bench.costs import EmpiricalRiskCost
108
from decent_bench.networks import FedNetwork, Network, P2PNetwork
119
from decent_bench.schemes import ClientSelectionScheme, UniformClientSelection
1210
from decent_bench.utils._tags import tags
@@ -179,6 +177,10 @@ class FedAlgorithm(Algorithm[FedNetwork]):
179177
def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]:
180178
return [network.server(), *network.clients()]
181179

180+
def server_broadcast(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
181+
"""Send the current server model to the selected clients."""
182+
network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x)
183+
182184
def select_clients(
183185
self,
184186
clients: Sequence["Agent"],
@@ -281,10 +283,10 @@ class FedAvg(FedAlgorithm):
281283
round :math:`k`. In FedAvg, each selected client performs ``num_local_epochs`` local SGD epochs, then the server
282284
aggregates the final local models to form :math:`\mathbf{x}_{k+1}`. The aggregation uses client weights, defaulting
283285
to data-size weights when ``client_weights`` is not provided. Client selection (subsampling) defaults to uniform
284-
sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. For
285-
:class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size
286-
:attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs, local
287-
updates use full-batch gradients.
286+
sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. Costs that
287+
preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use client-side mini-batches of size
288+
:attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; generic cost wrappers
289+
fall back to full-gradient local updates.
288290
"""
289291

290292
# C=0.1; batch size= inf/10/50 (dataset sizes are bigger; normally 1/10 of the total dataset).
@@ -323,46 +325,28 @@ def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102
323325
if not selected_clients:
324326
return
325327

326-
self._sync_server_to_clients(network, selected_clients)
328+
self.server_broadcast(network, selected_clients)
327329
self._run_local_updates(network, selected_clients)
328330
self.aggregate(network, selected_clients)
329331

330-
def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
331-
network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x)
332-
333332
def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None:
334333
for client in selected_clients:
335334
client.x = self._compute_local_update(client, network.server())
336335
network.send(sender=client, receiver=network.server(), msg=client.x)
337336

338337
def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
339-
local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
340-
if isinstance(client.cost, EmpiricalRiskCost):
341-
cost = client.cost
342-
n_samples = cost.n_samples
343-
return self._epoch_minibatch_update(cost, local_x, cost.batch_size, n_samples)
338+
"""
339+
Run local gradient steps using the batching semantics of ``client.cost.gradient``.
344340
341+
Costs that preserve the empirical-risk abstraction default ``gradient`` to ``indices="batch"``, so FedAvg
342+
performs mini-batch local updates automatically. Generic costs keep their usual full-gradient behavior.
343+
"""
344+
local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x)
345345
for _ in range(self.num_local_epochs):
346346
grad = client.cost.gradient(local_x)
347347
local_x -= self.step_size * grad
348348
return local_x
349349

350-
def _epoch_minibatch_update(
351-
self,
352-
cost: EmpiricalRiskCost,
353-
local_x: "Array",
354-
per_client_batch: int,
355-
n_samples: int,
356-
) -> "Array":
357-
for _ in range(self.num_local_epochs):
358-
indices = list(range(n_samples))
359-
random.shuffle(indices)
360-
for start in range(0, n_samples, per_client_batch):
361-
batch_indices = indices[start : start + per_client_batch]
362-
grad = cost.gradient(local_x, indices=batch_indices)
363-
local_x -= self.step_size * grad
364-
return local_x
365-
366350

367351
@tags("federated")
368352
@dataclass(eq=False)
@@ -1314,10 +1298,8 @@ def step(self, network: P2PNetwork, _: int) -> None: # noqa: D102
13141298
network.send(i, j, s)
13151299
for i in network.active_agents():
13161300
for j, msg in i.messages.items():
1317-
i.aux_vars["z_y"][j] = (1 - self.alpha) * i.aux_vars["z_y"][j] \
1318-
+ self.alpha * msg[0] # fmt: skip
1319-
i.aux_vars["z_s"][j] = (1 - self.alpha) * i.aux_vars["z_s"][j] \
1320-
+ self.alpha * msg[1] # fmt: skip
1301+
i.aux_vars["z_y"][j] = (1 - self.alpha) * i.aux_vars["z_y"][j] + self.alpha * msg[0]
1302+
i.aux_vars["z_s"][j] = (1 - self.alpha) * i.aux_vars["z_s"][j] + self.alpha * msg[1]
13211303

13221304

13231305
ADMMTracking = ATG # alias

decent_bench/networks.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
agent_ids = [agent.id for agent in graph.nodes()]
7373
if len(agent_ids) != len(set(agent_ids)):
7474
raise ValueError("Agent IDs must be unique")
75+
self._validate_agent_cost_compatibility(graph)
7576

7677
self._graph = graph
7778
self._message_noise = self._initialize_message_schemes(message_noise, "noise", NoiseScheme, NoNoise)
@@ -84,6 +85,37 @@ def __init__(
8485
self._buffer_messages = buffer_messages
8586
self._iteration = 0 # Current iteration, updated by the algorithm
8687

88+
@staticmethod
89+
def _validate_agent_cost_compatibility(graph: AgentGraph) -> None:
90+
"""
91+
Validate that all agents' costs share the same shape, framework, and device.
92+
93+
Raises:
94+
ValueError: If agents in the graph have mismatching cost shape, framework, or device.
95+
96+
"""
97+
agents = list(graph.nodes())
98+
if len(agents) <= 1:
99+
return
100+
101+
first_cost = agents[0].cost
102+
first_signature = (first_cost.shape, first_cost.framework, first_cost.device)
103+
mismatches: list[str] = []
104+
for agent in agents[1:]:
105+
signature = (agent.cost.shape, agent.cost.framework, agent.cost.device)
106+
if signature != first_signature:
107+
mismatches.append(
108+
f"agent {agent.id}: shape={agent.cost.shape}, framework={agent.cost.framework}, "
109+
f"device={agent.cost.device}"
110+
)
111+
112+
if mismatches:
113+
raise ValueError(
114+
"All agents in a network must have costs with the same shape, framework, and device. "
115+
f"Expected shape={first_cost.shape}, framework={first_cost.framework}, "
116+
f"device={first_cost.device}; mismatches: {'; '.join(mismatches)}"
117+
)
118+
87119
def _initialize_message_schemes(
88120
self,
89121
scheme: object,

decent_bench/schemes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(self, n_significant_digits: int):
178178
self.n_significant_digits = n_significant_digits
179179

180180
def compress(self, msg: Array) -> Array: # noqa: D102
181-
res = np.vectorize(lambda x: float(f"%.{self.n_significant_digits - 1}e" % x))(iop.to_numpy(msg)) # noqa: RUF073
181+
res = np.vectorize(lambda x: float(format(x, f".{self.n_significant_digits - 1}e")))(iop.to_numpy(msg))
182182
return iop.to_array_like(res, msg)
183183

184184

docs/source/user.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,30 @@ Classification
128128
:module: decent_bench.costs
129129

130130

131+
PyTorchCost regularization
132+
~~~~~~~~~~~~~~~~~~~~~~~~~~
133+
When combining :class:`~decent_bench.costs.PyTorchCost` with one of the
134+
built-in regularizers, instantiate the regularizer with the same framework
135+
and device as the empirical cost:
136+
137+
.. code-block:: python
138+
139+
from decent_bench.costs import L2RegularizerCost
140+
from decent_bench.utils.types import SupportedFrameworks
141+
142+
reg = L2RegularizerCost(
143+
shape=cost.shape,
144+
framework=SupportedFrameworks.PYTORCH,
145+
device=cost.device,
146+
)
147+
objective = cost + reg
148+
149+
This preserves compatibility with the PyTorch empirical objective and keeps
150+
the resulting objective in the empirical, batch-compatible abstraction.
151+
It is convenient for composition, but it is not necessarily the most
152+
efficient option compared with native framework-specific regularization.
153+
154+
131155
Execution settings
132156
------------------
133157
Configure settings for metrics, trials, statistical confidence level, logging, and multiprocessing.

test/test_cost_operators.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import decent_bench.utils.interoperability as iop
55
from decent_bench.costs import Cost, L1RegularizerCost, L2RegularizerCost, QuadraticCost, SumCost
6+
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
67

78

89
def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> QuadraticCost:
@@ -12,20 +13,28 @@ def _simple_quadratic(A_scale: float, b_scale: float, c: float = 0.0) -> Quadrat
1213

1314

1415
class _SimpleCost(Cost):
15-
def __init__(self, scale: float):
16+
def __init__(
17+
self,
18+
scale: float,
19+
*,
20+
framework: SupportedFrameworks = SupportedFrameworks.NUMPY,
21+
device: SupportedDevices = SupportedDevices.CPU,
22+
):
1623
self.scale = scale
24+
self._framework = framework
25+
self._device = device
1726

1827
@property
1928
def shape(self) -> tuple[int, ...]:
2029
return (2,)
2130

2231
@property
23-
def framework(self) -> str:
24-
return "numpy"
32+
def framework(self) -> SupportedFrameworks:
33+
return self._framework
2534

2635
@property
27-
def device(self) -> str | None:
28-
return "cpu"
36+
def device(self) -> SupportedDevices:
37+
return self._device
2938

3039
@property
3140
def m_smooth(self) -> float:
@@ -182,3 +191,35 @@ def test_cost_scalar_ops_reject_invalid_inputs() -> None:
182191
_ = cost / 0.0
183192
with pytest.raises(TypeError):
184193
_ = 0.0 / cost
194+
195+
196+
def test_cost_addition_rejects_mismatched_frameworks() -> None:
197+
cost_a = _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY)
198+
cost_b = _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH)
199+
200+
with pytest.raises(ValueError, match="Mismatching frameworks"):
201+
_ = cost_a + cost_b
202+
203+
204+
def test_cost_addition_rejects_mismatched_devices() -> None:
205+
cost_a = _SimpleCost(scale=1.0, device=SupportedDevices.CPU)
206+
cost_b = _SimpleCost(scale=2.0, device=SupportedDevices.GPU)
207+
208+
with pytest.raises(ValueError, match="Mismatching devices"):
209+
_ = cost_a + cost_b
210+
211+
212+
def test_sum_cost_rejects_mismatched_frameworks() -> None:
213+
cost_a = _SimpleCost(scale=1.0, framework=SupportedFrameworks.NUMPY)
214+
cost_b = _SimpleCost(scale=2.0, framework=SupportedFrameworks.PYTORCH)
215+
216+
with pytest.raises(ValueError, match="Mismatching frameworks"):
217+
SumCost([cost_a, cost_b])
218+
219+
220+
def test_sum_cost_rejects_mismatched_devices() -> None:
221+
cost_a = _SimpleCost(scale=1.0, device=SupportedDevices.CPU)
222+
cost_b = _SimpleCost(scale=2.0, device=SupportedDevices.GPU)
223+
224+
with pytest.raises(ValueError, match="Mismatching devices"):
225+
SumCost([cost_a, cost_b])

0 commit comments

Comments
 (0)