
Commit c8dd61a

Make the code more agnostic to custom classes
1 parent 00ad63e commit c8dd61a

5 files changed: 68 additions, 2 deletions

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -313,6 +313,24 @@ def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None)
                 shapes[i] = self.guess_dynamic_shape_object(*[o[i] for o in objs])
             return shapes
 
+        if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+            kcl = set(o.__class__ for o in objs)
+            assert len(kcl) == 1, (
+                f"All instances of argument {i} should have the same class "
+                f"but {kcl} were found."
+            )
+            col_args = [torch.utils._pytree.tree_flatten(o)[0] for o in objs]
+            kc = set(len(ca) for ca in col_args)
+            assert len(kc) == 1, (
+                f"All instances of type {kcl.pop()} should be serialized into "
+                f"the same number of arguments but {kc} were found."
+            )
+            values = []
+            for i in range(kc.pop()):
+                values.append(self.guess_dynamic_dimensions(*[ca[i] for ca in col_args]))
+            return values
+
+        # In case DynamicCache is not registered.
         if obj.__class__.__name__ == "DynamicCache":
             kc = set(len(o.key_cache) for o in objs)
             assert (
```
onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -898,7 +898,7 @@ def assertEqualAny(
             for e, g in zip(expected, value):
                 self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
         elif expected.__class__.__name__ == "DynamicCache":
-            atts = {"key_cache", "value_cache"}
+            atts = ["key_cache", "value_cache"]
            self.assertEqualAny(
                 {k: expected.__dict__.get(k, None) for k in atts},
                 {k: value.__dict__.get(k, None) for k in atts},
```
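A plausible motivation for swapping the set for a list (the commit message does not say): list iteration order is fixed, so the dictionaries built from `atts` always enumerate the two cache attributes in the same order, which keeps comparison failures reproducible. A tiny illustration:

```python
# Illustration only; the intent of the set-to-list change is assumed.
atts_set = {"key_cache", "value_cache"}   # iteration order is unspecified
atts_list = ["key_cache", "value_cache"]  # iteration order is guaranteed

d = {k: None for k in atts_list}
assert list(d) == ["key_cache", "value_cache"]  # stable across runs
```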

onnx_diagnostic/helpers/helper.py

Lines changed: 34 additions & 0 deletions

```diff
@@ -443,6 +443,17 @@ def string_type(
     if ignore:
         return f"{obj.__class__.__name__}(...)"
 
+    if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, _spec = torch.utils._pytree.tree_flatten(obj)
+        att = string_type(
+            args,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+        )
+        return f"{obj.__class__.__name__}({att})"
+
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
 
 
@@ -1125,6 +1136,29 @@ def max_diff(
             flatten=flatten,
         )
 
+    if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        if got.__class__ not in torch.utils._pytree.SUPPORTED_NODES:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        if verbose >= 6:
+            print(
+                f"[max_diff] {expected.__class__.__name__}: "
+                f"{string_type(expected)} ? {string_type(got)}"
+            )
+        expected_args, _spec = torch.utils._pytree.tree_flatten(expected)
+        got_args, _spec = torch.utils._pytree.tree_flatten(got)
+        return max_diff(
+            expected_args,
+            got_args,
+            level=level,
+            flatten=flatten,
+            debug_info=debug_info,
+            begin=begin,
+            end=end,
+            _index=_index,
+            verbose=verbose,
+        )
+
+    # Backup path in case pytorch does not know how to serialize.
     if expected.__class__.__name__ == "DynamicCache":
         if got.__class__.__name__ == "DynamicCache":
             if verbose >= 6:
```
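Both helper changes follow the same recipe: flatten the registered object into its tensor leaves, then recurse on the resulting list. Continuing the hypothetical `ToyCache` sketch above (class already registered), the leaves of two instances line up pairwise, which is why delegating to `max_diff` on the flattened arguments is sufficient:

```python
# Continues the ToyCache sketch above; ToyCache is already registered.
expected = ToyCache(torch.ones(2, 3), torch.zeros(2, 3))
got = ToyCache(torch.ones(2, 3), torch.zeros(2, 3) + 1e-5)

expected_args, _spec = torch.utils._pytree.tree_flatten(expected)
got_args, _spec = torch.utils._pytree.tree_flatten(got)

# The leaves are aligned pairwise, so comparing the two objects reduces to
# comparing plain tensors, here with the largest absolute difference.
max_abs = max((e - g).abs().max().item() for e, g in zip(expected_args, got_args))
print(max_abs)  # about 1e-5
```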

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -305,7 +305,7 @@ def forward(self, input_ids):
 
 def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
     """
-    Applies torch.to is applicables.
+    Applies torch.to if applicable.
     Goes recursively.
     """
     if isinstance(value, (torch.nn.Module, torch.Tensor)):
@@ -329,6 +329,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
                 )
             )
         )
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, spec = torch.utils._pytree.tree_flatten(value)
+        new_args = to_any(args, to_value)
+        return torch.utils._pytree.tree_unflatten(new_args, spec)
 
     assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
     return value
@@ -361,6 +365,11 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(value.self_attention_cache),
             torch_deepcopy(value.cross_attention_cache),
         )
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, spec = torch.utils._pytree.tree_flatten(value)
+        new_args = torch_deepcopy(args)
+        return torch.utils._pytree.tree_unflatten(new_args, spec)
+
     # We should have code using serialization and deserialization, assuming a
     # model cannot be exported without them.
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
```

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -39,6 +39,11 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
                 break
         new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]]
         return new_shape
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        raise NotImplementedError(
+            f"_make_shape not implemented for registered class={cls}, "
+            f"subset={subset}, value={string_type(value)}"
+        )
     raise NotImplementedError(
         f"_make_shape not implemented for cls={cls}, "
         f"subset={subset}, value={string_type(value)}"
```
