Skip to content

Commit 2d1e38a

Browse files
committed
last fix for modelbuilder
1 parent 5c56c9d commit 2d1e38a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,18 @@ def test_k_filter_inputs(self):
195195
@ignore_warnings(FutureWarning)
196196
@requires_transformers("4.51")
197197
def test_l_validate_model_modelbuilder(self):
198-
mid = "meta-llama/Llama-2-7b-hf"
198+
mid = "microsoft/phi-2"
199199
summary, data = validate_model(
200200
mid,
201201
do_run=True,
202202
verbose=10,
203203
exporter="modelbuilder",
204204
dump_folder="dump_test/validate_model_modelbuilder",
205+
patch=True,
205206
)
206207
self.assertIsInstance(summary, dict)
207208
self.assertIsInstance(data, dict)
208-
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-2)
209+
self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
209210
onnx_filename = data["onnx_filename"]
210211
self.assertExists(onnx_filename)
211212

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def make_feeds(
4141
"""
4242
# NOTE: position_ids is a special case because ModelBuilder does not usually use it,
4343
# because it's fued into rotary embedding in GQA.
44-
if isinstance(inputs, dict):
44+
if is_modelbuilder and isinstance(inputs, dict):
4545
inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing.
4646

4747
flat = flatten_object(inputs, drop_keys=True)

0 commit comments

Comments
 (0)