Skip to content

Commit 067f898

Browse files
committed
better
1 parent 7558413 commit 067f898

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def get_untrained_model_with_inputs(
9696
assert (
9797
type(config) is not dict
9898
), f"Unable to set dynamic_rope if the configuration is a dictionary\n{config}"
99+
assert hasattr(config, "rope_scaling"), f"Missing 'rope_scaling' in\n{config}"
99100
config.rope_scaling = (
100101
{"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
101102
)

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,15 +1069,16 @@ def call_torch_export_custom(
10691069
assert (
10701070
optimization in available
10711071
), f"unexpected value for optimization={optimization}, available={available}"
1072-
assert exporter in {
1072+
available = {
10731073
"custom",
10741074
"custom-strict",
1075-
"custom-strict-dec",
1075+
"custom-strict-default",
10761076
"custom-strict-all",
10771077
"custom-nostrict",
1078-
"custom-nostrict-dec",
1078+
"custom-nostrict-default",
10791079
"custom-nostrict-all",
1080-
}, f"Unexpected value for exporter={exporter!r}"
1080+
}
1081+
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
10811082
assert "model" in data, f"model is missing from data: {sorted(data)}"
10821083
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
10831084
summary: Dict[str, Union[str, int, float]] = {}
@@ -1109,7 +1110,7 @@ def call_torch_export_custom(
11091110
export_options = ExportOptions(
11101111
strict=strict,
11111112
decomposition_table=(
1112-
"dec" if "-dec" in exporter else ("all" if "-all" in exporter else None)
1113+
"default" if "-default" in exporter else ("all" if "-all" in exporter else None)
11131114
),
11141115
)
11151116
options = OptimizationOptions(patterns=optimization) if optimization else None

0 commit comments

Comments
 (0)