Skip to content

Commit ab387bf

Browse files
committed
changelogs
1 parent ea8da0e commit ab387bf

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`88`: extends ``steal_forward`` to dump input, outputs in onnx models
78
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
89

910
0.4.4

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def to_onnx(self) -> ModelProto:
310310
return model
311311

312312

313-
def flatten_iterator(obj: Any, sep: str) -> Iterator:
313+
def _flatten_iterator(obj: Any, sep: str) -> Iterator:
314314
"""Iterates on all object."""
315315
if obj is not None:
316316
if isinstance(obj, np.ndarray):
@@ -329,21 +329,21 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
329329
else:
330330
for i, o in enumerate(obj):
331331
if i == len(obj) - 1:
332-
for p, oo in flatten_iterator(o, sep):
332+
for p, oo in _flatten_iterator(o, sep):
333333
yield f"tuple.{sep}{p}", oo
334334
else:
335-
for p, oo in flatten_iterator(o, sep):
335+
for p, oo in _flatten_iterator(o, sep):
336336
yield f"tuple{sep}{p}", oo
337337
elif isinstance(obj, list):
338338
if not obj:
339339
yield f"list.{sep}empty", None
340340
else:
341341
for i, o in enumerate(obj):
342342
if i == len(obj) - 1:
343-
for p, oo in flatten_iterator(o, sep):
343+
for p, oo in _flatten_iterator(o, sep):
344344
yield f"list.{sep}{p}", oo
345345
else:
346-
for p, oo in flatten_iterator(o, sep):
346+
for p, oo in _flatten_iterator(o, sep):
347347
yield f"list{sep}{p}", oo
348348
elif isinstance(obj, dict):
349349
if not obj:
@@ -352,13 +352,13 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
352352
for i, (k, v) in enumerate(obj.items()):
353353
assert sep not in k, (
354354
f"Key {k!r} cannot contain '{sep}'. "
355-
f"It would interfer with the serialization."
355+
f"It would interfere with the serialization."
356356
)
357357
if i == len(obj) - 1:
358-
for p, o in flatten_iterator(v, sep):
358+
for p, o in _flatten_iterator(v, sep):
359359
yield f"dict._{k}{sep}{p}", o
360360
else:
361-
for p, o in flatten_iterator(v, sep):
361+
for p, o in _flatten_iterator(v, sep):
362362
yield f"dict_{k}{sep}{p}", o
363363
elif obj.__class__.__name__ == "DynamicCache":
364364
# transformers
@@ -370,10 +370,10 @@ def flatten_iterator(obj: Any, sep: str) -> Iterator:
370370
atts = ["key_cache", "value_cache"]
371371
for i, att in enumerate(atts):
372372
if i == len(atts) - 1:
373-
for p, o in flatten_iterator(getattr(obj, att), sep):
373+
for p, o in _flatten_iterator(getattr(obj, att), sep):
374374
yield f"DynamicCache._{att}{sep}{p}", o
375375
else:
376-
for p, o in flatten_iterator(getattr(obj, att), sep):
376+
for p, o in _flatten_iterator(getattr(obj, att), sep):
377377
yield f"DynamicCache_{att}{sep}{p}", o
378378
else:
379379
raise NotImplementedError(f"Unexpected type {type(obj)}")
@@ -403,7 +403,7 @@ def create_onnx_model_from_input_tensors(
403403
switch_low_high = sys.byteorder != "big"
404404

405405
builder = MiniOnnxBuilder(sep=sep)
406-
for prefix, o in flatten_iterator(inputs, sep):
406+
for prefix, o in _flatten_iterator(inputs, sep):
407407
if o is None:
408408
builder.append_output_initializer(prefix, np.array([]))
409409
else:
@@ -413,17 +413,15 @@ def create_onnx_model_from_input_tensors(
413413
return model
414414

415415

416-
def unflatten(
416+
def _unflatten(
417417
sep: str,
418418
names: List[str],
419419
outputs: List[Any],
420420
pos: int = 0,
421421
level: int = 0,
422422
device: str = "cpu",
423423
) -> Tuple[int, Tuple[Any, ...]]:
424-
"""
425-
Unflattens a list of outputs flattened with :func:`flatten_iterator`.
426-
"""
424+
"""Unflattens a list of outputs flattened with :func:`flatten_iterator`."""
427425
name = names[pos]
428426
spl = name.split(sep)
429427
if len(spl) == level + 1:
@@ -448,7 +446,7 @@ def unflatten(
448446
name = names[pos]
449447
spl = name.split(sep)
450448
prefix = spl[level]
451-
next_pos, value = unflatten(
449+
next_pos, value = _unflatten(
452450
sep, names, outputs, pos=pos, level=level + 1, device=device
453451
)
454452

@@ -499,7 +497,7 @@ def create_input_tensors_from_onnx_model(
499497
device: str = "cpu",
500498
engine: str = "ExtendedReferenceEvaluator",
501499
sep: str = "___",
502-
) -> Union[Tuple[Any, ...], Dict[str, Any]]:
500+
) -> Any:
503501
"""
504502
Deserializes tensors stored with function
505503
:func:`create_onnx_model_from_input_tensors`.
@@ -511,7 +509,7 @@ def create_input_tensors_from_onnx_model(
511509
:param device: moves the tensor to this device
512510
:param engine: runtime to use, onnx, the default value, onnxruntime
513511
:param sep: separator
514-
:return: ModelProto
512+
:return: restored data
515513
"""
516514
if engine == "ExtendedReferenceEvaluator":
517515
from ..reference import ExtendedReferenceEvaluator
@@ -552,4 +550,4 @@ def create_input_tensors_from_onnx_model(
552550
return torch.from_numpy(output).to(device)
553551
raise AssertionError(f"Unexpected name {name!r} in {names}")
554552

555-
return unflatten(sep, names, got, device=device)[1]
553+
return _unflatten(sep, names, got, device=device)[1]

0 commit comments

Comments
 (0)