Skip to content

Commit 78d3520

Browse files
authored
fix missing parameter in config for qwen (#365)
* fix config for qwen
* fixes
1 parent a52db71 commit 78d3520

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

onnx_diagnostic/ci_models/export_qwen25_vl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def main(
280280
).eval()
281281
data = dict(model=model)
282282
config = model.config
283+
if not hasattr(config, "bos_token_id") or not config.bos_token_id:
284+
config.bos_token_id = 151643
285+
if not hasattr(config, "eos_token_id") or not config.eos_token_id:
286+
config.eos_token_id = 151645
283287
else:
284288
print("-- random model")
285289
data = get_untrained_model(model_id, second_input=second_input, verbose=1)

onnx_diagnostic/helpers/log_helper.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,9 +1921,7 @@ def first_err(df: pandas.DataFrame) -> pandas.Series:
19211921
return lambdas[formula]
19221922

19231923
if formula == "onnx_n_nodes_no_cst":
1924-
return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
1925-
df, "op_onnx__Constant", 0
1926-
).fillna(0)
1924+
return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
19271925
if formula == "peak_gpu_torch":
19281926
return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
19291927
if formula == "peak_gpu_nvidia":

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,21 @@ def qwen_sdpa_attention(
256256
return attn_output
257257

258258
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
259+
import onnx_ir
260+
259261
first_float_tensor = next(
260262
a
261263
for a in args
262-
if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
264+
if a is not None
265+
and a.dtype
266+
in {
267+
torch.float16,
268+
torch.float32,
269+
torch.bfloat16,
270+
onnx_ir.DataType.BFLOAT16,
271+
onnx_ir.DataType.FLOAT16,
272+
onnx_ir.DataType.FLOAT,
273+
}
263274
)
264275
dtype = first_float_tensor.dtype
265276
strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()

0 commit comments

Comments (0)