Commit 3bbbc65

Safeq qwen test (#1900)

* add safe flash attn check
* cropt test

Co-authored-by: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com>
1 parent: 7f10730


inference_models/inference_models/models/qwen3vl/qwen3vl_hf.py

Lines changed: 17 additions & 5 deletions
@@ -21,6 +21,22 @@
 )
 
 
+def _get_qwen3vl_attn_implementation(device: torch.device) -> str:
+    """Use flash_attention_2 if available, otherwise eager.
+
+    SDPA has dtype mismatch issues with some transformers versions.
+    """
+    if is_flash_attn_2_available() and device and "cuda" in str(device):
+        # Verify flash_attn can actually be imported (not just installed)
+        try:
+            import flash_attn  # noqa: F401
+
+            return "flash_attention_2"
+        except ImportError:
+            pass
+    return "eager"
+
+
 class Qwen3VLHF:
     default_dtype = torch.bfloat16
 
@@ -53,11 +69,7 @@ def from_pretrained(
 
         dtype = cls.default_dtype
 
-        attn_implementation = (
-            "flash_attention_2"
-            if (is_flash_attn_2_available() and device and "cuda" in str(device))
-            else "eager"
-        )
+        attn_implementation = _get_qwen3vl_attn_implementation(device)
 
         if os.path.exists(adapter_config_path):
            # Has adapter - load base model then apply LoRA
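
For context on the new check, a minimal, self-contained sketch of the same pattern follows. The helper name, model class, and checkpoint string are illustrative placeholders rather than the repository's actual API; only the availability check mirrors the code added in this commit.

    # Illustrative sketch: is_flash_attn_2_available() only checks that the
    # flash_attn package is installed; importing it can still fail (e.g. ABI
    # mismatch with the installed torch build), so the import is verified too.
    import torch
    from transformers.utils import is_flash_attn_2_available


    def pick_attn_implementation(device: torch.device) -> str:
        """Prefer flash_attention_2 on CUDA, but only if flash_attn imports cleanly."""
        if is_flash_attn_2_available() and device and "cuda" in str(device):
            try:
                import flash_attn  # noqa: F401  # installed *and* importable
                return "flash_attention_2"
            except ImportError:
                pass  # broken install: fall back to eager attention
        return "eager"


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(pick_attn_implementation(device))  # "flash_attention_2" or "eager"

    # The returned string is then forwarded to a Hugging Face from_pretrained
    # call via attn_implementation, e.g. (placeholder class and checkpoint):
    # model = AutoModelForImageTextToText.from_pretrained(
    #     "Qwen/Qwen3-VL-...",
    #     torch_dtype=torch.bfloat16,
    #     attn_implementation=pick_attn_implementation(device),
    # )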

0 commit comments
