miles Docker image does not support B300 (sm_103a) #533

@harvenstar

Description

Problem

Tried to reproduce the Qwen3-14B training run from #530 on 8x B300 SXM6 (275GB). The Docker image radixark/miles:latest doesn't work on B300 because the ptxas bundled with Triton doesn't recognize sm_103a.

What works

  • Docker pull, git pull, pip install — all fine
  • Weight conversion (convert_hf_to_torch_dist.py) — completes successfully, 14.77B params loaded and saved
  • Ray cluster startup + Megatron model loading — fine

What breaks

SGLang engine startup fails. I hit three issues in sequence:

1. TP memory imbalance check (workaround exists)

In colocate mode, Megatron loads first and takes GPU memory. SGLang then sees uneven memory across GPUs and refuses to start:

RuntimeError: The memory capacity is unbalanced.
min_per_gpu_memory=101.79, local_gpu_memory=265.55, local_gpu_memory * 0.9=238.99

miles sets SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=true, but the current SGLang version checks SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK instead (defaults to true). The old env var is deprecated and no longer controls the behavior.

Workaround: add "SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK": "false" to the runtime env.
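For reference, a minimal sketch of how the override could be injected through Ray's runtime env. Where exactly miles builds this dict depends on the launch script, so treat the ray.init call below as an assumption, not miles's actual code path:

# Sketch only: inject the env vars via Ray's runtime_env so SGLang workers see them.
import ray

ray.init(
    runtime_env={
        "env_vars": {
            # keep the deprecated name for older SGLang builds, add the current one
            "SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
            "SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK": "false",
        }
    }
)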

2. CUDA graph capture fails (workaround exists)

After bypassing (1), Triton fails to compile kernels during CUDA graph capture:

NoTritonConfigsError: No valid triton configs. PTXASError: PTXAS error: Internal Triton PTX codegen error

Workaround: launch with --sglang-disable-cuda-graph.

3. ptxas does not support sm_103a (no workaround)

After bypassing (1) and (2), SGLang still crashes because core Triton kernels (e.g. alloc_extend_kernel in the memory allocator) also fail to compile:

ptxas fatal   : Value 'sm_103a' is not defined for option 'gpu-name'

This is the bundled ptxas inside the Triton package, not the system one. There's no script-level workaround for this.
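For completeness, the failure reproduces outside SGLang with any Triton JIT kernel, which points at the bundled toolchain rather than specific kernels. A minimal sketch (kernel and sizes are arbitrary; assumes the image's torch/Triton and a visible B300):

# Sketch: compiling any Triton kernel on sm_103a should abort with the same
# "ptxas fatal : Value 'sm_103a' is not defined for option 'gpu-name'" error.
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

x = torch.arange(1024, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
# JIT compilation (and the ptxas invocation) happens on first launch.
copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)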

Environment

  • 8x NVIDIA B300 SXM6 AC, 275GB each
  • radixark/miles:latest (sha256:6e467519505da8407f8c33e3c5ec622c0cc4e6e0929102efe34f8415940caf06)
  • Triton ptxas path: /usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas
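A quick way to confirm which ptxas Triton picks up and what CUDA release it ships (sketch; the bundled path is the one listed above, and a system ptxas may not exist in the image):

# Sketch: print the version of the Triton-bundled ptxas vs. any system ptxas.
import shutil
import subprocess

bundled = "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/bin/ptxas"
system = shutil.which("ptxas")  # None if no CUDA toolkit ptxas is on PATH

for label, path in (("bundled", bundled), ("system", system)):
    if not path:
        print(f"{label}: not found")
        continue
    out = subprocess.run([path, "--version"], capture_output=True, text=True)
    print(f"{label}: {path}")
    print(out.stdout.strip())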

Notes

  • This blocks all models on B300, not just 14B.
  • Issue (1) about the deprecated env var seems like a miles-side bug that affects all GPUs in colocate mode — worth fixing separately regardless of B300 support.
