
Commit f483e37

Commit message: other fixes
Parent: c8dd61a

3 files changed: +71 -57 lines changed

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
             f"All instances of argument {i} are not of the same class but {kcl}, "
             f"types should be the same."
         )
-        col_args = [torch.utils._pytree.tree_flatten(o) for o in objs]
+        col_args = [torch.utils._pytree.tree_flatten(o)[0] for o in objs]
         kc = set(len(col_args) for o in objs)
         assert len(kc) == 1, (
             f"All instances of type {kcl.pop()} are not serialized into the same number "

onnx_diagnostic/helpers/helper.py

Lines changed: 57 additions & 54 deletions
@@ -343,65 +343,23 @@ def string_type(

     # others classes

-    if type(obj).__name__ == "MambaCache":
-        c = string_type(
-            obj.conv_states,
-            with_shape=with_shape,
-            with_min_max=with_min_max,
-            with_device=with_device,
-            limit=limit,
-        )
-        d = string_type(
-            obj.ssm_states,
+    if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, _spec = torch.utils._pytree.tree_flatten(obj)
+        att = string_type(
+            args,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
             limit=limit,
         )
-        return f"MambaCache(conv_states={c}, ssm_states={d})"
+        return f"{obj.__class__.__name__}[serialized]({att})"
+
     if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
         # torch.fx.node.Node
         return f"%{obj.target}"
     if type(obj).__name__ == "ValueInfoProto":
         return f"OT{obj.type.tensor_type.elem_type}"

-    if obj.__class__.__name__ == "DynamicCache":
-        kc = string_type(
-            obj.key_cache,
-            with_shape=with_shape,
-            with_min_max=with_min_max,
-            with_device=with_device,
-            limit=limit,
-        )
-        vc = string_type(
-            obj.value_cache,
-            with_shape=with_shape,
-            with_min_max=with_min_max,
-            with_device=with_device,
-            limit=limit,
-        )
-        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
-
-    if obj.__class__.__name__ == "EncoderDecoderCache":
-        att = string_type(
-            obj.self_attention_cache,
-            with_shape=with_shape,
-            with_min_max=with_min_max,
-            with_device=with_device,
-            limit=limit,
-        )
-        cross = string_type(
-            obj.cross_attention_cache,
-            with_shape=with_shape,
-            with_min_max=with_min_max,
-            with_device=with_device,
-            limit=limit,
-        )
-        return (
-            f"{obj.__class__.__name__}(self_attention_cache={att}, "
-            f"cross_attention_cache={cross})"
-        )
-
     if obj.__class__.__name__ == "BatchFeature":
         s = string_type(
             obj.data,
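
The new branch keys off torch.utils._pytree.SUPPORTED_NODES, the registry PyTorch keeps of every class registered as a pytree node, so any container with a registered flatten function now gets the generic "[serialized]" rendering instead of a hand-written one. A minimal sketch of how a class lands in that registry, using a hypothetical Pair container (assumes a recent PyTorch where register_pytree_node is available):

import torch

class Pair:
    # Hypothetical two-slot container, used only to illustrate registration.
    def __init__(self, left, right):
        self.left, self.right = left, right

# Registration adds Pair to torch.utils._pytree.SUPPORTED_NODES, which is
# exactly the membership test the new string_type branch performs.
torch.utils._pytree.register_pytree_node(
    Pair,
    lambda p: ([p.left, p.right], None),     # flatten: (children, context)
    lambda children, _ctx: Pair(*children),  # unflatten: rebuild from children
)

assert Pair in torch.utils._pytree.SUPPORTED_NODES
leaves, _spec = torch.utils._pytree.tree_flatten(Pair(torch.zeros(2), torch.ones(2)))
assert len(leaves) == 2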
@@ -440,19 +398,64 @@ def string_type(
     if isinstance(obj, torch.utils._pytree.TreeSpec):
         return repr(obj).replace(" ", "").replace("\n", " ")

-    if ignore:
-        return f"{obj.__class__.__name__}(...)"
+    # to avoid failures

-    if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
-        args, _spec = torch.utils._pytree.tree_flatten(obj)
+    if type(obj).__name__ == "MambaCache":
+        c = string_type(
+            obj.conv_states,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        d = string_type(
+            obj.ssm_states,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        return f"MambaCache(conv_states={c}, ssm_states={d})"
+
+    if obj.__class__.__name__ == "DynamicCache":
+        kc = string_type(
+            obj.key_cache,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        vc = string_type(
+            obj.value_cache,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
+
+    if obj.__class__.__name__ == "EncoderDecoderCache":
         att = string_type(
-            args,
+            obj.self_attention_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
             limit=limit,
         )
-        return f"{obj.__class__.__name__}({att})"
+        cross = string_type(
+            obj.cross_attention_cache,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        return (
+            f"{obj.__class__.__name__}(self_attention_cache={att}, "
+            f"cross_attention_cache={cross})"
+        )
+
+    if ignore:
+        return f"{obj.__class__.__name__}(...)"

     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
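
To make the resulting dispatch order concrete: the generic pytree branch (moved earlier in this commit) fires when the cache class is registered, and the class-name branches above now only run as fallbacks for unregistered classes, so string_type no longer raises on a plain DynamicCache either way. A hedged usage sketch, assuming the import path implied by this repository's layout and an installed transformers:

import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.helpers.helper import string_type

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8), layer_idx=0)

# If DynamicCache is registered in torch.utils._pytree.SUPPORTED_NODES (recent
# transformers, or after applying this package's export patches), the generic
# branch yields "DynamicCache[serialized](...)"; otherwise the fallback spells
# out key_cache and value_cache. Neither path hits the final AssertionError.
print(string_type(cache, with_shape=True))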

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 13 additions & 2 deletions
@@ -97,6 +97,10 @@ def flatten_dynamic_cache(
     dynamic_cache: transformers.cache_utils.DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
     flat = [
         (k, getattr(dynamic_cache, k))
         for k in ["key_cache", "value_cache"]
@@ -111,7 +115,10 @@ def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[
 ]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     import torch
+    import transformers.cache_utils

+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(d)
     values, context = flatten_dynamic_cache(d)
     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context

@@ -122,9 +129,13 @@ def unflatten_dynamic_cache(
     output_type=None,
 ) -> transformers.cache_utils.DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    from transformers.cache_utils import DynamicCache
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)

-    cache = DynamicCache()
+    cache = transformers.cache_utils.DynamicCache()
     values = dict(zip(context, values))
     for k, v in values.items():
         setattr(cache, k, v)
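
As a sanity check, the helpers compose into a round trip whichever branch executes. A hedged sketch, assuming this repository's module path and a transformers version compatible with these helpers:

import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

cache = DynamicCache()
cache.update(torch.ones(1, 2, 3, 8), torch.ones(1, 2, 3, 8), layer_idx=0)

values, context = flatten_dynamic_cache(cache)
restored = unflatten_dynamic_cache(values, context)

# Whether the native transformers helpers ran or the local fallback did,
# the rebuilt object is a DynamicCache holding equal tensors.
assert isinstance(restored, DynamicCache)
assert all(torch.equal(a, b) for a, b in zip(cache.key_cache, restored.key_cache))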
