|
1 | | -import random |
2 | 1 | from abc import ABC, abstractmethod |
3 | 2 | from collections.abc import Callable, Iterable, Sequence |
4 | 3 | from dataclasses import dataclass, field |
5 | 4 | from typing import TYPE_CHECKING, Any, Final, cast, final |
6 | 5 |
|
7 | 6 | import decent_bench.utils.algorithm_helpers as alg_helpers |
8 | 7 | import decent_bench.utils.interoperability as iop |
9 | | -from decent_bench.costs import EmpiricalRiskCost |
10 | 8 | from decent_bench.networks import FedNetwork, Network, P2PNetwork |
11 | 9 | from decent_bench.schemes import ClientSelectionScheme, UniformClientSelection |
12 | 10 | from decent_bench.utils._tags import tags |
@@ -179,6 +177,10 @@ class FedAlgorithm(Algorithm[FedNetwork]): |
179 | 177 | def _cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]: |
180 | 178 | return [network.server(), *network.clients()] |
181 | 179 |
|
| 180 | + def server_broadcast(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: |
| 181 | + """Send the current server model to the selected clients.""" |
| 182 | + network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) |
| 183 | + |
182 | 184 | def select_clients( |
183 | 185 | self, |
184 | 186 | clients: Sequence["Agent"], |
@@ -281,10 +283,10 @@ class FedAvg(FedAlgorithm): |
281 | 283 | round :math:`k`. In FedAvg, each selected client performs ``num_local_epochs`` local SGD epochs, then the server |
282 | 284 | aggregates the final local models to form :math:`\mathbf{x}_{k+1}`. The aggregation uses client weights, defaulting |
283 | 285 | to data-size weights when ``client_weights`` is not provided. Client selection (subsampling) defaults to uniform |
284 | | - sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. For |
285 | | - :class:`~decent_bench.costs.EmpiricalRiskCost`, local updates use mini-batches of size |
286 | | - :attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; for generic costs, local |
287 | | - updates use full-batch gradients. |
| 286 | + sampling with fraction 1.0 (all active clients) and can be customized via ``selection_scheme``. Costs that |
| 287 | + preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use client-side mini-batches of size |
| 288 | + :attr:`EmpiricalRiskCost.batch_size <decent_bench.costs.EmpiricalRiskCost.batch_size>`; generic cost wrappers |
| 289 | + fall back to full-gradient local updates. |
288 | 290 | """ |
289 | 291 |
|
290 | 292 | # C=0.1; batch size= inf/10/50 (dataset sizes are bigger; normally 1/10 of the total dataset). |
@@ -323,46 +325,28 @@ def step(self, network: FedNetwork, iteration: int) -> None: # noqa: D102 |
323 | 325 | if not selected_clients: |
324 | 326 | return |
325 | 327 |
|
326 | | - self._sync_server_to_clients(network, selected_clients) |
| 328 | + self.server_broadcast(network, selected_clients) |
327 | 329 | self._run_local_updates(network, selected_clients) |
328 | 330 | self.aggregate(network, selected_clients) |
329 | 331 |
|
330 | | - def _sync_server_to_clients(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: |
331 | | - network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x) |
332 | | - |
333 | 332 | def _run_local_updates(self, network: FedNetwork, selected_clients: Sequence["Agent"]) -> None: |
334 | 333 | for client in selected_clients: |
335 | 334 | client.x = self._compute_local_update(client, network.server()) |
336 | 335 | network.send(sender=client, receiver=network.server(), msg=client.x) |
337 | 336 |
|
338 | 337 | def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array": |
339 | | - local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x) |
340 | | - if isinstance(client.cost, EmpiricalRiskCost): |
341 | | - cost = client.cost |
342 | | - n_samples = cost.n_samples |
343 | | - return self._epoch_minibatch_update(cost, local_x, cost.batch_size, n_samples) |
| 338 | + """ |
| 339 | + Run local gradient steps using the batching semantics of ``client.cost.gradient``. |
344 | 340 |
|
| 341 | + Costs that preserve the empirical-risk abstraction default ``gradient`` to ``indices="batch"``, so FedAvg |
| 342 | + performs mini-batch local updates automatically. Generic costs keep their usual full-gradient behavior. |
| 343 | + """ |
| 344 | + local_x = iop.copy(client.messages[server]) if server in client.messages else iop.copy(client.x) |
345 | 345 | for _ in range(self.num_local_epochs): |
346 | 346 | grad = client.cost.gradient(local_x) |
347 | 347 | local_x -= self.step_size * grad |
348 | 348 | return local_x |
349 | 349 |
|
350 | | - def _epoch_minibatch_update( |
351 | | - self, |
352 | | - cost: EmpiricalRiskCost, |
353 | | - local_x: "Array", |
354 | | - per_client_batch: int, |
355 | | - n_samples: int, |
356 | | - ) -> "Array": |
357 | | - for _ in range(self.num_local_epochs): |
358 | | - indices = list(range(n_samples)) |
359 | | - random.shuffle(indices) |
360 | | - for start in range(0, n_samples, per_client_batch): |
361 | | - batch_indices = indices[start : start + per_client_batch] |
362 | | - grad = cost.gradient(local_x, indices=batch_indices) |
363 | | - local_x -= self.step_size * grad |
364 | | - return local_x |
365 | | - |
366 | 350 |
|
367 | 351 | @tags("federated") |
368 | 352 | @dataclass(eq=False) |
@@ -1314,10 +1298,8 @@ def step(self, network: P2PNetwork, _: int) -> None: # noqa: D102 |
1314 | 1298 | network.send(i, j, s) |
1315 | 1299 | for i in network.active_agents(): |
1316 | 1300 | for j, msg in i.messages.items(): |
1317 | | - i.aux_vars["z_y"][j] = (1 - self.alpha) * i.aux_vars["z_y"][j] \ |
1318 | | - + self.alpha * msg[0] # fmt: skip |
1319 | | - i.aux_vars["z_s"][j] = (1 - self.alpha) * i.aux_vars["z_s"][j] \ |
1320 | | - + self.alpha * msg[1] # fmt: skip |
| 1301 | + i.aux_vars["z_y"][j] = (1 - self.alpha) * i.aux_vars["z_y"][j] + self.alpha * msg[0] |
| 1302 | + i.aux_vars["z_s"][j] = (1 - self.alpha) * i.aux_vars["z_s"][j] + self.alpha * msg[1] |
1321 | 1303 |
|
1322 | 1304 |
|
1323 | 1305 | ADMMTracking = ATG # alias |
|
0 commit comments