
[Bug/Hardware Incompatibility] RTX 5090 (Blackwell / sm_120) crashes with "_deps/repo-flash-attention-src/hopper/flash_fwd_launch_template.h:166) during inference #160

@fakirsu

Description


Thank you for your incredible work on this project!

Describe the bug

I am trying to run fishaudio/s2-pro using sglang_omni (in docker) on a brand new RTX 5090 (Blackwell architecture, compute capability 12.0). The server successfully starts, loads the model weights into VRAM, and passes the health check. However, as soon as it receives the first inference request (via Gradio or direct curl), it crashes with a CUDA error pointing to a Hopper-specific kernel:

CUDA error (_deps/repo-flash-attention-src/hopper/flash_fwd_launch_template.h:166): no kernel image is available for execution on the device
My hypothesis

It appears that sglang-omni (or the SGLang custom CUDA operations underneath it) ships pre-compiled C++ extensions whose kernels target the Hopper architecture (sm_90). Because the RTX 5090 is sm_120, the GPU has no compatible kernel image for these Hopper-specific binaries and execution fails.
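As a quick sanity check of this hypothesis (my own addition, not part of the failing run), the device's compute capability can be compared against the CUDA architectures the installed PyTorch wheel ships kernels for; a minimal sketch, run inside the container:

```python
import torch

# Compare the local GPU's compute capability against the
# architectures this PyTorch build was compiled for.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"Device compute capability: sm_{major}{minor}")  # sm_120 on an RTX 5090
    print("PyTorch was built for:", torch.cuda.get_arch_list())
else:
    print("No CUDA device visible to PyTorch")
```

Note this only covers PyTorch's own kernels; SGLang's separately compiled extensions can still lack an sm_120 image even when this check passes.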

Environment

  • GPU: NVIDIA RTX 5090 (32GB VRAM)
  • OS: Windows 11 (running via Docker inside WSL2)
  • Base Image: frankleeeee/sglang-omni:dev
  • Command used: TORCH_COMPILE_DISABLE=1 python3 -m sglang_omni.cli.cli serve --host 0.0.0.0 --port 30001 --model-path fishaudio/s2-pro --config examples/configs/s2pro_tts.yaml

What we have tried so far (workarounds that all failed): We spent several hours trying to force SGLang to use native PyTorch SDPA or to bypass the Hopper kernels, but nothing worked:

1. Recompiling Flash-Attention from source for sm_120: We ran TORCH_CUDA_ARCH_LIST="12.0" uv pip install flash-attn --no-build-isolation --no-binary flash-attn --reinstall. The compilation took 15 minutes and succeeded, but the runtime error remained the same, suggesting the Hopper dependency is baked into SGLang's own C++ extensions, not just the Python flash-attn package.
2. Uninstalling FlashAttention and FlashInfer entirely: We ran uv pip uninstall flash-attn flashinfer and exported TORCH_ENABLE_FLASH_ATTENTION=0, hoping SGLang would gracefully fall back to pure PyTorch math attention. The crash still occurred.
3. Patching server_args.py: We injected code to force self.attention_backend = "triton", bypassing flashinfer. The server booted with Triton logs, but still crashed on inference with the exact same Hopper kernel error.
4. Considering FlashAttention-4 (CuTeDSL): We noticed FA4 is optimized for Blackwell, but we understand we cannot simply pip install flash-attn-4 because SGLang's backend and API calls are hardcoded for FA2/FA3 structures (flash_attn_varlen_func, etc.).
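For context, the "pure PyTorch" fallback we were hoping for in step 2 would look roughly like the following. This is a hypothetical sketch built on torch.nn.functional.scaled_dot_product_attention, not SGLang's actual backend (which uses packed variable-length layouts such as flash_attn_varlen_func):

```python
import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, causal=True):
    """Attention via PyTorch's built-in SDPA: dispatches to whatever
    backend (math / mem-efficient / flash) the local GPU supports,
    with no hand-written Hopper kernels involved.
    Shapes: (batch, n_heads, seq_len, head_dim)."""
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

# Smoke test on CPU: no GPU or custom extensions required.
q = k = v = torch.randn(1, 2, 4, 8)
out = sdpa_fallback(q, k, v)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```

Something like this would be slow but architecture-agnostic, which is all we need to get the model running at all on sm_120.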

Questions for the maintainers:

  1. Can you confirm that the current C++ backend of SGLang/SGLang Omni is indeed intrinsically incompatible with Blackwell (sm_120) out of the box due to these Hopper kernels?
  2. Is there any "pure Python/PyTorch" fallback flag we missed that could completely disable the loading of these custom .so C++ extensions and allow us to run the model on the RTX 5090, even if it means running at a slower speed?
  3. Is Blackwell (RTX 50-series / B200) support on the roadmap for the near future?
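To help answer question 1 empirically, the architectures actually embedded in a compiled extension can be listed with cuobjdump from the CUDA toolkit. A sketch under stated assumptions: the search root below is hypothetical, so point it at the real site-packages path inside the container:

```python
import glob
import re
import subprocess

# Hypothetical search root: replace with the container's real site-packages path.
for so in glob.glob("/usr/lib/python3/dist-packages/**/*.so", recursive=True):
    try:
        result = subprocess.run(["cuobjdump", "--list-elf", so],
                                capture_output=True, text=True, timeout=60)
    except FileNotFoundError:
        print("cuobjdump not on PATH; install the CUDA toolkit binaries")
        break
    archs = sorted(set(re.findall(r"sm_\d+", result.stdout)))
    if archs:
        # An extension listing only sm_90 here cannot run on sm_120 hardware.
        print(so, archs)
```

If the SGLang extensions list only sm_90 cubins, that would confirm the incompatibility is baked into the shipped binaries rather than into the Python packages we reinstalled.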

