Commit b34d94f

fix plugs

1 parent 2ab6859

3 files changed: +9 -23 lines

3 files changed

+9
-23
lines changed

_unittests/ut_tasks/try_export.py

Lines changed: 7 additions & 7 deletions
@@ -212,18 +212,18 @@ def _config_reduction(config, task):
         print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
         model = onnx.load(filename, load_external_data=False)
         if attention == "PACKED":
-            self.assertIn("PackedMultiHeadAttention", str(model))
+            self.assertIn('"PackedMultiHeadAttention"', str(model))
         elif attention == "BIGMASK":
-            self.assertNotIn("PackedMultiHeadAttention", str(model))
+            self.assertNotIn('"PackedMultiHeadAttention"', str(model))
             self.assertNotIn("MultiHeadAttention", str(model))
             self.assertNotIn("Loop", {n.op_type for n in model.graph.node})
         elif attention == "LOOPMHA":
-            self.assertNotIn("PackedMultiHeadAttention", str(model))
-            self.assertIn("MultiHeadAttention", str(model))
+            self.assertNotIn('"PackedMultiHeadAttention"', str(model))
+            self.assertIn('"MultiHeadAttention"', str(model))
             self.assertIn("Loop", {n.op_type for n in model.graph.node})
         elif attention == "LOOPA24":
-            self.assertNotIn("PackedMultiHeadAttention", str(model))
-            self.assertNotIn("MultiHeadAttention", str(model))
+            self.assertNotIn('"PackedMultiHeadAttention"', str(model))
+            self.assertNotIn('"MultiHeadAttention"', str(model))
             self.assertIn("Loop", {n.op_type for n in model.graph.node})
         else:
             raise AssertionError(f"attention={attention!r} not expected")
@@ -257,7 +257,7 @@ def _config_reduction(config, task):
                 else ["CPUExecutionProvider"]
             ),
             use_ort=True,
-            atol=0.02,
+            atol=0.05,
             rtol=10,
             # ep=pt2_file,
             expected=expected,
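
Note on the quoting (a minimal sketch, not part of the commit): str(model) renders the textual protobuf form, where op types appear as quoted strings such as op_type: "PackedMultiHeadAttention". Without the quotes, "MultiHeadAttention" is a plain substring of "PackedMultiHeadAttention", so an unquoted check can match the wrong op:

    # Minimal illustration of the substring pitfall the quotes avoid.
    text = 'node { op_type: "PackedMultiHeadAttention" }'
    assert "MultiHeadAttention" in text            # false positive on the packed op
    assert '"MultiHeadAttention"' not in text      # anchored by the quotes: correct
    assert '"PackedMultiHeadAttention"' in text    # exact op still found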

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 1 addition & 15 deletions
@@ -693,21 +693,7 @@ def forward(self, query, key, value, seq_lens):
         ks = key * mask
         vs = value * mask
         attn_output = qwen_sdpa_attention_loopmha_versatile(
-            qs,
-            ks,
-            vs,
-            seq_lens,
-            0.11,
-            16,
-            (
-                onnx.TensorProto.FLOAT
-                if query.dtype == torch.float32
-                else (
-                    onnx.TensorProto.FLOAT16
-                    if query.dtype == torch.float16
-                    else onnx.TensorProto.BFLOAT16
-                )
-            ),
+            qs, ks, vs, seq_lens, 0.11, 16
         )
         red = attn_output.mean(dim=-1, keepdim=True)
         return attn_output - red
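
The deleted argument inlined a torch-to-ONNX dtype mapping; presumably the plug now infers the element type from its inputs, so the call site can drop it. For reference, the mapping the test used to spell out, as a hypothetical helper (not part of the commit):

    import onnx
    import torch

    # Same selection the removed nested conditional performed.
    _TORCH_TO_ONNX = {
        torch.float32: onnx.TensorProto.FLOAT,
        torch.float16: onnx.TensorProto.FLOAT16,
        torch.bfloat16: onnx.TensorProto.BFLOAT16,
    }

    def onnx_elem_type(t: torch.Tensor) -> int:
        return _TORCH_TO_ONNX[t.dtype]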

onnx_diagnostic/export/onnx_plug.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def get_function_proto(self, *args) -> onnx.FunctionProto:
         ):
             return self._function_proto_versioned[args[0]]
         try:
-            key = self.version_selector(*args)
+            key = self.version_selector(*args)  # type: ignore[misc]
         except (ValueError, AttributeError) as e:
             raise AssertionError(
                 f"Unable to select a version, fails to get a key, available="
