Description
I used an LLM to help me write the issue below
Describe the bug
`burn-collective`'s global all-reduce strategies (Tree, Centralized, Ring) use hardcoded or locally-scoped transfer IDs for the `TensorDataService`. When `all_reduce` is called multiple times in sequence (e.g., via `GradientsParams::all_reduce`, which iterates over each parameter tensor), the transfer IDs collide between calls, causing the data service to return the wrong tensor. This results in shape mismatches and panics.
Specifically:
- Tree (`global/node/tree.rs`): hardcodes transfer IDs `0` and `1` at L45, L66, L71, L87
- Centralized (`global/node/centralized.rs`): hardcodes transfer IDs `0` and `1` at L42, L62, L68, L73
- Ring (`global/node/ring.rs`): uses a local `transfer_counter` starting at `0` per call at L83, which resets on each invocation
Since the `TensorDataService` maintains a persistent WebSocket connection across calls, the second `all_reduce` call's `expose(..., 0.into())` collides with stale state from the first call. This causes nodes to download the wrong tensor, leading to `PeerSentIncoherentTensor` errors or panics in `broadcast_shape()` / `can_mut_broadcast()` due to shape mismatches.
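To illustrate the failure mode, here is a minimal, self-contained sketch. It is not burn's actual code: `FakeDataService`, its methods, and the shapes are invented for illustration, and it assumes (as one plausible manifestation of the bug) that an exposed entry persists under its transfer ID across calls, so a reused ID serves the stale tensor.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for TensorDataService: exposed tensors are keyed
// only by their transfer ID, and entries persist across all_reduce calls.
struct FakeDataService {
    exposed: HashMap<u64, Vec<usize>>, // transfer ID -> tensor shape
}

impl FakeDataService {
    fn new() -> Self {
        Self { exposed: HashMap::new() }
    }

    // Exposing under an already-used ID leaves the stale entry in place.
    fn expose(&mut self, shape: Vec<usize>, id: u64) {
        self.exposed.entry(id).or_insert(shape);
    }

    fn download(&self, id: u64) -> &Vec<usize> {
        &self.exposed[&id]
    }
}

fn main() {
    let mut service = FakeDataService::new();

    // First all_reduce: exposes a [4, 8] tensor under hardcoded ID 0.
    service.expose(vec![4, 8], 0);
    assert_eq!(service.download(0), &vec![4, 8]);

    // Second all_reduce: reuses ID 0 for a [16] tensor, but the stale
    // [4, 8] entry wins, so the peer downloads the wrong-shaped tensor.
    service.expose(vec![16], 0);
    assert_eq!(service.download(0), &vec![4, 8]);
}
```

In the real system the receiving peer then fails shape validation (`PeerSentIncoherentTensor`) or panics in `broadcast_shape()`.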
To Reproduce
1. Set up a multi-node configuration with 2 nodes and a global orchestrator
2. Call `all_reduce` more than once in sequence on different tensors (this is what `GradientsParams::all_reduce` does: it iterates over all parameter gradients and calls `burn_collective::all_reduce` for each one)
3. The second `all_reduce` call panics with one of:
   - `PeerSentIncoherentTensor`: shape validation catches the wrong tensor
   - `index out of bounds: the len is 1 but the index is 1`: in `broadcast_shape()` when the downloaded tensor has fewer dimensions than expected
Minimal reproduction: modify the multinode-tests `node.rs` to call `all_reduce` twice with tensors of different shapes (e.g., first `[4, 8]`, then `[16]`). The second call will fail.
Note: the existing multinode-tests only call `all_reduce` once per session, which is why this wasn't caught.
Expected behavior
Sequential `all_reduce` calls should work correctly, each operating on its own tensor independently. This is required for `GradientsParams::all_reduce` to function: it iterates over all parameter gradients and calls `burn_collective::all_reduce` for each one, and it is the primary integration point for DDP training with custom training loops.
Suggested fix
Use a monotonically increasing `AtomicU64` counter (stored on the `Node` struct or passed through the all-reduce functions) for transfer IDs, so each `expose`/`download_tensor` pair gets a globally unique ID across calls. The counter must persist across all `all_reduce` invocations for the lifetime of the collective session, and all peers must advance it in lockstep so that matching `expose` and `download_tensor` calls agree on the ID.
For `tree.rs`, instead of:

```rust
data_service.expose(result.clone(), 1, 0.into()).await;
// ...
.download_tensor(child_addr.clone(), 0.into())
```

use something like:

```rust
let id = node.next_transfer_id(); // AtomicU64::fetch_add(1, Ordering::SeqCst)
data_service.expose(result.clone(), 1, id.into()).await;
// ...
.download_tensor(child_addr.clone(), id.into())
```

Desktop:
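A minimal, self-contained sketch of the counter itself (the `Node` struct and the `next_transfer_id` method name are assumptions for illustration; only the `AtomicU64` pattern is the point):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Hypothetical Node holding a session-lifetime transfer ID counter.
struct Node {
    transfer_counter: AtomicU64,
}

impl Node {
    fn new() -> Self {
        Self {
            transfer_counter: AtomicU64::new(0),
        }
    }

    // Each call returns a fresh ID; fetch_add is atomic, so concurrent
    // callers on this node can never observe the same value.
    fn next_transfer_id(&self) -> u64 {
        self.transfer_counter.fetch_add(1, Ordering::SeqCst)
    }
}

fn main() {
    let node = Node::new();
    // Sequential all_reduce calls now get distinct IDs: 0, 1, 2, ...
    assert_eq!(node.next_transfer_id(), 0);
    assert_eq!(node.next_transfer_id(), 1);
    assert_eq!(node.next_transfer_id(), 2);
}
```

Since every node starts at `0` and performs the same sequence of collective calls, the counters stay in sync across peers without any extra coordination.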
- OS: Linux (Ubuntu 24.04)
- Burn version: 0.20.1
- Tested with both NdArray and CUDA backends (same behavior on both)
Additional context
The local collective (intra-process, thread-based) handles sequential all_reduce calls correctly. The bug is specific to the global collective (inter-process, WebSocket-based) used for multi-node training.
This blocks any use of GradientsParams::all_reduce with the global collective, which means DDP training via the documented pattern (custom training loop + GradientsParams::all_reduce) cannot work in multi-node setups.