Global collective all-reduce fails on sequential calls due to hardcoded/reused transfer IDs #4549

@hs-cengel

Description

I used an LLM to help me write the issue below

Describe the bug

burn-collective's global all-reduce strategies (Tree, Centralized, Ring) use hardcoded or locally-scoped transfer IDs for the TensorDataService. When all_reduce is called multiple times in sequence (e.g., via GradientsParams::all_reduce, which iterates over each parameter tensor), the transfer IDs collide between calls, causing the data service to return the wrong tensor. This results in shape mismatches and panics.

Specifically:

Since the TensorDataService maintains a persistent WebSocket connection across calls, the second all_reduce call's expose(..., 0.into()) collides with stale state from the first call. This causes nodes to download the wrong tensor, leading to PeerSentIncoherentTensor errors or panics in broadcast_shape() / can_mut_broadcast() due to shape mismatches.

To Reproduce

  1. Set up a multi-node configuration with 2 nodes and a global orchestrator
  2. Call all_reduce more than once in sequence on different tensors (this is what GradientsParams::all_reduce does — it iterates over all parameter gradients and calls burn_collective::all_reduce for each one)
  3. The second all_reduce call panics with one of:
    • PeerSentIncoherentTensor — shape validation catches the wrong tensor
    • index out of bounds: the len is 1 but the index is 1 — in broadcast_shape() when the downloaded tensor has fewer dimensions than expected

Minimal reproduction: modify the multinode-tests node.rs to call all_reduce twice with tensors of different shapes (e.g., first [4, 8], then [16]). The second call will fail.

Note: the existing multinode-tests only call all_reduce once per session, which is why this wasn't caught.

Expected behavior

Sequential all_reduce calls should work correctly, each operating on its own tensor independently. This is required for GradientsParams::all_reduce to function, since it is the primary integration point for DDP training with custom training loops.

Suggested fix

Use a monotonically increasing AtomicU64 counter (stored on the Node struct or passed through the all-reduce functions) for transfer IDs, so each expose/download_tensor pair gets a globally unique ID across calls. The counter should persist across all all_reduce invocations for the lifetime of the collective session.
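A minimal sketch of the proposed counter. The Node struct and next_transfer_id method here are hypothetical (the real Node lives in burn-collective); only AtomicU64 and fetch_add are from the standard library:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Hypothetical Node holding the per-session transfer ID counter. The
// counter lives for the lifetime of the collective session, so IDs never
// repeat across sequential all_reduce calls.
struct Node {
    transfer_id_counter: AtomicU64,
}

impl Node {
    fn new() -> Self {
        Self { transfer_id_counter: AtomicU64::new(0) }
    }

    // Returns a globally unique transfer ID for one expose/download pair.
    // fetch_add is atomic, so even concurrent strategies running on the
    // same node receive distinct IDs.
    fn next_transfer_id(&self) -> u64 {
        self.transfer_id_counter.fetch_add(1, Ordering::SeqCst)
    }
}

fn main() {
    let node = Node::new();
    // Two sequential all_reduce calls now use distinct IDs: 0, then 1.
    let first = node.next_transfer_id();
    let second = node.next_transfer_id();
    assert_ne!(first, second);
    println!("ids: {first}, {second}");
}
```

Because both the exposing node and the downloading node must agree on the ID, the counter only works if every node advances it in lockstep (same number of expose/download pairs per all_reduce), which holds for the deterministic Tree/Centralized/Ring schedules.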

For tree.rs, instead of:

data_service.expose(result.clone(), 1, 0.into()).await;
// ...
.download_tensor(child_addr.clone(), 0.into())

Use something like:

let id = node.next_transfer_id(); // AtomicU64::fetch_add(1, Ordering::SeqCst)
data_service.expose(result.clone(), 1, id.into()).await;
// ...
.download_tensor(child_addr.clone(), id.into())

Desktop:

  • OS: Linux (Ubuntu 24.04)
  • Burn version: 0.20.1
  • Tested with both NdArray and CUDA backends — same behavior on both

Additional context

The local collective (intra-process, thread-based) handles sequential all_reduce calls correctly. The bug is specific to the global collective (inter-process, WebSocket-based) used for multi-node training.

This blocks any use of GradientsParams::all_reduce with the global collective, which means DDP training via the documented pattern (custom training loop + GradientsParams::all_reduce) cannot work in multi-node setups.
