Skip to content

Commit c90d5fc

Browse files
authored
support custom-tracing (#371)
* support custom-tracing * enable fake tensor more * fix fake tensors * fix * fix * fix documentation * documentation
1 parent 39977bd commit c90d5fc

File tree

6 files changed

+51
-6
lines changed

6 files changed

+51
-6
lines changed

.github/workflows/documentation.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ jobs:
104104

105105
- name: Check for errors and warnings
106106
run: |
107-
if [[ $(grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export') ]]; then
107+
if [[ $(grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Unexpected section title or transition.') ]]; then
108108
echo "Documentation produces errors."
109-
grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export'
109+
grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Unexpected section title or transition.'
110110
exit 1
111111
fi
112-
if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache') ]]; then
112+
if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache' ) ]]; then
113113
echo "Documentation produces warnings."
114114
grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' | grep -v 'Inline emphasis start-string' | grep -v 'Definition list ends without a blank line' | grep -v 'Unexpected section title or transition' | grep -v 'Inline strong start-string' | grep -v 'MambaCache'
115115
exit 1

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.8.8
55
+++++
66

7+
* :pr:`371`: fix make_fake_with_dynamic_dimensions
8+
79
0.8.7
810
+++++
911

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,22 @@ def test_n_validate_phi35_mini_instruct(self):
225225
self.assertIn("If", op_types)
226226
self.clean_dump()
227227

228+
@hide_stdout()
229+
@requires_transformers("4.57")
230+
def test_o_validate_model_export_fake(self):
231+
mid = "arnir0/Tiny-LLM"
232+
summary, data = validate_model(
233+
mid,
234+
do_run=True,
235+
verbose=10,
236+
exporter="custom-fake",
237+
dump_folder="dump_test/validate_model_export_fake",
238+
patch=True,
239+
)
240+
self.assertIsInstance(summary, dict)
241+
self.assertIsInstance(data, dict)
242+
self.clean_dump()
243+
228244

229245
if __name__ == "__main__":
230246
unittest.main(verbosity=2)

onnx_diagnostic/export/shape_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def make_fake_with_dynamic_dimensions(
210210
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
211211
Parameter ``existing`` is used to reused the same object when the dynamic
212212
dimension is given the same name as another one.
213+
This function works with caches only if ``transformers>=4.57``.
213214
214215
A simple tensor:
215216

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,18 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
144144
"""
145145
See
146146
:func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
147+
If caches are used, it requires ``transformers>=4.57``.
147148
"""
148149
if x is None:
149150
return None, None
150-
if isinstance(x, (list, tuple)):
151+
if type(x) in (list, tuple):
151152
return x.__class__(
152153
[
153154
self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
154155
for i, ds in zip(x, dynamic_shapes)
155156
]
156157
)
157-
if isinstance(x, dict):
158+
if type(x) is dict:
158159
return {
159160
k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
160161
for k, v in x.items()
@@ -187,6 +188,17 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
187188
x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
188189
)
189190
return x
191+
if x.__class__.__name__ == "BaseModelOutput":
192+
assert (
193+
list(x.keys()) == ["last_hidden_state"] and x.last_hidden_state is not None
194+
), (
195+
f"Field 'last_hidden_state' is empty for {type(x)} or other fields "
196+
f"{list(x.keys())} are used."
197+
)
198+
x.last_hidden_state = self.make_fake_with_dynamic_dimensions(
199+
x.last_hidden_state, dynamic_shapes=dynamic_shapes[0]
200+
)
201+
return x
190202
if hasattr(x, "shape"):
191203
assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
192204
f"dynamic_shapes must be a dictionary at this stage but "
@@ -197,9 +209,11 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
197209
for idim, dim in enumerate(x.shape):
198210
if dynamic_shapes is not None and idim in dynamic_shapes:
199211
s = dynamic_shapes[idim]
212+
if s.__class__.__name__ == "Dim":
213+
s = s.__name__
200214
assert isinstance(s, str), (
201215
f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
202-
f"at index {idim}"
216+
f"at index {idim}, self._mapping_str={self._mapping_str}"
203217
)
204218
if s in self._mapping_str:
205219
dim = self._mapping_str[s]
@@ -221,6 +235,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
221235
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
222236
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
223237
return t
238+
if isinstance(x, (int, bool, float)):
239+
# It is a constant, we don't change that.
240+
return x
224241
from ..helpers import string_type
225242

226243
raise TypeError(

onnx_diagnostic/torch_models/validate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,6 +2330,7 @@ def call_torch_export_custom(
23302330
"custom-dec",
23312331
"custom-decall",
23322332
"custom-fake",
2333+
"custom-tracing",
23332334
}
23342335
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
23352336
assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -2342,11 +2343,16 @@ def call_torch_export_custom(
23422343
f"Options strict cannot be specified in the exporter name {exporter!r} "
23432344
f"and in the options {exporter_options}"
23442345
)
2346+
assert ("-tracing" not in exporter) or ("tracing" not in exporter_options), (
2347+
f"Options tracing cannot be specified in the exporter name {exporter!r} "
2348+
f"and in the options {exporter_options}"
2349+
)
23452350
summary: Dict[str, Union[str, int, float]] = {}
23462351
strict = "-strict" in exporter or exporter_options.pop("strict", False)
23472352
args, kwargs = split_args_kwargs(data["inputs_export"])
23482353
ds = data.get("dynamic_shapes", None)
23492354
fake = "-fake" in exporter or exporter_options.pop("fake", False)
2355+
tracing = "-tracing" in exporter or exporter_options.pop("tracing", False)
23502356
if fake:
23512357
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
23522358

@@ -2370,6 +2376,7 @@ def call_torch_export_custom(
23702376
summary["export_exporter"] = exporter
23712377
summary["export_optimization"] = optimization or ""
23722378
summary["export_strict"] = strict
2379+
summary["export_tracing"] = tracing
23732380
summary["export_fake"] = fake
23742381
summary["export_args"] = string_type(args, with_shape=True)
23752382
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
@@ -2392,6 +2399,7 @@ def call_torch_export_custom(
23922399
)
23932400
)
23942401
large_model = bool(exporter_options.pop("large_model", True))
2402+
exporter_options.pop("tracing", False)
23952403
return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
23962404
export_modules_as_functions = bool(
23972405
exporter_options.pop("export_modules_as_functions", False)
@@ -2405,6 +2413,7 @@ def call_torch_export_custom(
24052413
summary["export_external_threshold"] = str(external_threshold)
24062414

24072415
export_options = ExportOptions(
2416+
tracing=tracing,
24082417
strict=strict,
24092418
decomposition_table=decomposition_table,
24102419
save_ep=(

0 commit comments

Comments
 (0)