Skip to content

Commit 7b8eb85

Browse files
committed
fix issues
1 parent 6137993 commit 7b8eb85

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def assertEqualAny(
910910
elif hasattr(expected, "shape"):
911911
self.assertEqual(type(expected), type(value), msg=msg)
912912
self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
913-
elif expected.__class__.__name__ in ("Dim", "_Dim"):
913+
elif expected.__class__.__name__ in ("Dim", "_Dim", "_DimHintType"):
914914
self.assertEqual(type(expected), type(value), msg=msg)
915915
self.assertEqual(expected.__name__, value.__name__, msg=msg)
916916
else:

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def string_type(
409409
f"dtype={obj.dtype}, shape={obj.shape})"
410410
)
411411

412-
if obj.__class__.__name__ == "_DimHint":
412+
if obj.__class__.__name__ in ("_DimHint", "_DimHintType"):
413413
return str(obj)
414414

415415
if isinstance(obj, torch.nn.Module):

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def validate_model(
226226
"""
227227
assert not trained, f"trained={trained} not supported yet"
228228
summary = version_summary()
229+
folder_name = None
229230
if dump_folder:
230231
folder_name = _make_folder_name(
231232
model_id, exporter, optimization, dtype=dtype, device=device
@@ -237,12 +238,12 @@ def validate_model(
237238
summary["dump_folder_name"] = folder_name
238239
if verbose:
239240
print(f"[validate_model] dump into {folder_name!r}")
240-
else:
241-
folder_name = None
241+
242242
if verbose:
243243
print(f"[validate_model] validate model id {model_id!r}")
244244
print("[validate_model] get dummy inputs...")
245245
summary["model_id"] = model_id
246+
246247
begin = time.perf_counter()
247248
if quiet:
248249
try:
@@ -344,7 +345,10 @@ def validate_model(
344345
)
345346
if patch:
346347
if verbose:
347-
print("[validate_model] applies patches before exporting")
348+
print(
349+
f"[validate_model] applies patches before exporting "
350+
f"stop_if_static={stop_if_static}"
351+
)
348352
with bypass_export_some_errors( # type: ignore
349353
patch_transformers=True,
350354
stop_if_static=stop_if_static,
@@ -527,6 +531,7 @@ def call_torch_export_export(
527531
"export-strict",
528532
"export-nostrict",
529533
}, f"Unexpected value for exporter={exporter!r}"
534+
assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
530535
assert "model" in data, f"model is missing from data: {sorted(data)}"
531536
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
532537
summary: Dict[str, Union[str, int, float]] = {}

0 commit comments

Comments
 (0)