Skip to content

Commit 2216d83

Browse files
committed
add missing BartModel
1 parent fc664f8 commit 2216d83

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
4+
5+
6+
class TestPatchRewrite(ExtTestCase):
7+
def test_code_needing_rewriting(self):
8+
res = code_needing_rewriting("BartModel")
9+
self.assertEqual(len(res), 2)
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main(verbosity=2)

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import os
2-
from typing import Dict, Optional, Tuple
2+
from typing import Dict, List, Optional, Tuple
33
import onnx
44
import onnx.helper as oh
55
import torch
66
from ..reference.torch_ops import OpRunKernel, OpRunTensor
77
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
88
from .ort_session import InferenceSessionForTorch
99

10-
_SAVED = []
10+
_SAVED: List[str] = []
1111
_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
1212

1313

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
8080
"AutoformerModel": "AutoformerEncoderLayer",
8181
"BartEncoderLayer": "BartEncoderLayer",
8282
"BartForConditionalGeneration": "BartEncoderLayer",
83+
"BartModel": "BartEncoderLayer",
8384
"BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
8485
"BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
8586
"BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def validate_model(
387387
if model_options:
388388
print(f"[validate_model] model_options={model_options!r}")
389389
print(f"[validate_model] get dummy inputs with input_options={input_options}...")
390+
print(
391+
f"[validate_model] rewrite={rewrite}, patch={patch}, "
392+
f"stop_if_static={stop_if_static}"
393+
)
394+
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
395+
print(f"[validate_model] dump_folder={dump_folder!r}")
390396
summary["model_id"] = model_id
391397
summary["model_subfolder"] = subfolder or ""
392398

@@ -446,6 +452,8 @@ def validate_model(
446452
print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
447453
else:
448454
del data["rewrite"]
455+
if verbose:
456+
print("[validate_model] no rewrite")
449457
if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
450458
print("[validate_model] -- PRINT CONFIG")
451459
print("-- type(config)", type(data["configuration"]))

0 commit comments

Comments
 (0)