
Commit 735b373: issues

1 parent: a4a5c0e

2 files changed: +16 -16 lines

_unittests/ut_torch_models/test_hghub_model.py
Lines changed: 2 additions & 0 deletions

@@ -62,6 +62,7 @@ def test_get_untrained_model_with_inputs_tiny_gpt_neo(self):
         self.assertEqual((316712, 79178), (data["size"], data["n_weights"]))
 
     @hide_stdout()
+    @ignore_errors(OSError)
     def test_get_untrained_model_with_inputs_phi_2(self):
         mid = "microsoft/phi-2"
         data = get_untrained_model_with_inputs(mid, verbose=1)

@@ -83,6 +84,7 @@ def test_get_untrained_model_with_inputs_beit(self):
         self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)])
 
     @hide_stdout()
+    @ignore_errors(OSError)
     def test_get_untrained_model_with_inputs_codellama(self):
         mid = "codellama/CodeLlama-7b-Python-hf"
         data = get_untrained_model_with_inputs(mid, verbose=1)
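
Both tests pull model configurations from the Hugging Face hub, so a transient network or cache problem can surface as an OSError; the added @ignore_errors(OSError) decorator lets such runs be ignored instead of failing the suite. Below is a minimal sketch of what an ignore_errors-style decorator could look like for a unittest-based test class; it is an illustrative reimplementation, not the helper actually shipped with onnx_diagnostic.

    import functools
    import unittest

    def ignore_errors(exc_type):
        # Hypothetical sketch: run the test and turn the given exception type
        # into a skip instead of a failure; the project's real helper may differ.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(self, *args, **kwargs):
                try:
                    return fn(self, *args, **kwargs)
                except exc_type as e:
                    raise unittest.SkipTest(f"ignored {exc_type.__name__}: {e}")
            return wrapper
        return decorator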

onnx_diagnostic/export/dynamic_shapes.py
Lines changed: 14 additions & 16 deletions

@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from ..helpers import string_type
@@ -488,7 +488,7 @@ def __str__(self) -> str:
             ]
         )
 
-    def invalid_paths(self) -> Any:
+    def invalid_paths(self):
         """
         Tells the inputs are valid based on the dynamic shapes definition.
         The method assumes that all custom classes can be serialized.
@@ -501,7 +501,7 @@ def invalid_paths(self) -> Any:
         return self._generic_walker(self._valid_shapes_tensor)
 
     @classmethod
-    def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable:
+    def _valid_shapes_tensor(cls, inputs, ds):
         assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
         assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
             f"Unexpected types, inputs is a Tensor but ds is {ds}, "
@@ -516,7 +516,7 @@ def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable:
                 issues[i] = f"d=[{d}]"
         return issues if issues else None
 
-    def _generic_walker(self, method_to_call: Callable) -> Any:
+    def _generic_walker(self, processor: Callable):
         """
         Generic deserializator walking through inputs and dynamic_shapes all along.
         The function returns a result with the same structure as the dynamic shapes.
@@ -526,14 +526,14 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes)
+            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
 
         if not self.kwargs:
             assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                 f"Type mismatch, args={string_type(self.args)} and "
                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
             )
-            return self._generic_walker_step(method_to_call, self.args, self.dynamic_shapes)
+            return self._generic_walker_step(processor, self.args, self.dynamic_shapes)
 
         assert isinstance(self.dynamic_shapes, dict), (
             f"Both positional and named arguments (args and kwargs) are filled. "
@@ -543,14 +543,12 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
                 self.dynamic_shapes
             ):
                 # No dynamic shapes for the positional arguments.
-                return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes)
+                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
 
         if isinstance(self.args_names, list):
             if not set(self.args_names) & set(self.dynamic_shapes):
                 # No dynamic shapes for the positional arguments.
-                return self._generic_walker_step(
-                    method_to_call, self.kwargs, self.dynamic_shapes
-                )
+                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
 
             assert self.args_names, (
                 "args and kwargs are filled, then args_names must be specified in "
@@ -563,17 +561,17 @@ def _generic_walker(self, method_to_call: Callable) -> Any:
             )
             kwargs = dict(zip(self.args_names, self.args))
             kwargs.update(self.kwargs)
-            return self._generic_walker_step(method_to_call, kwargs, self.dynamic_shapes)
+            return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
 
         raise NotImplementedError(
             f"Not yet implemented when args is filled, "
             f"kwargs as well but args_names is {type(self.args_names)}"
         )
 
     @classmethod
-    def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> Iterable:
+    def _generic_walker_step(cls, processor: Callable, inputs, ds):
         if isinstance(inputs, torch.Tensor):
-            return method_to_call(inputs, ds)
+            return processor(inputs, ds)
         if isinstance(inputs, (int, float, str)):
             return None
         if isinstance(inputs, (tuple, list, dict)):
@@ -588,7 +586,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
         if isinstance(inputs, (tuple, list)):
             value = []
             for i, d in zip(inputs, ds):
-                value.append(cls._generic_walker_step(method_to_call, i, d))
+                value.append(cls._generic_walker_step(processor, i, d))
             return (
                 (value if isinstance(ds, list) else tuple(value))
                 if any(v is not None for v in value)
@@ -599,7 +597,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
             ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
             dvalue = {}
             for k, v in inputs.items():
-                t = cls._generic_walker_step(method_to_call, v, ds[k])
+                t = cls._generic_walker_step(processor, v, ds[k])
                 if t is not None:
                     dvalue[k] = t
             return dvalue if dvalue else None
@@ -611,4 +609,4 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) ->
             f"map this class with the given dynamic shapes."
         )
         flat, _spec = torch.utils._pytree.tree_flatten(inputs)
-        return cls._generic_walker_step(method_to_call, flat, ds)
+        return cls._generic_walker_step(processor, flat, ds)
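
Renaming method_to_call to processor makes the walker's role clearer: _generic_walker lines up args/kwargs with the dynamic-shapes structure, and _generic_walker_step recurses through nested tuples, lists and dicts, calling the per-tensor callback and keeping only the non-None results. The following is a standalone sketch of that traversal pattern, assuming tensors paired with per-axis shape dictionaries; walk and report_dynamic_dims are illustrative names, not part of the library's API.

    from typing import Any, Callable, Optional

    import torch

    def walk(processor: Callable, inputs: Any, ds: Any) -> Optional[Any]:
        # Recurse through inputs and the matching dynamic-shapes structure,
        # call `processor` on every tensor, and keep only non-None results
        # so the output mirrors the shape of `ds`.
        if isinstance(inputs, torch.Tensor):
            return processor(inputs, ds)
        if isinstance(inputs, (int, float, str)):
            return None
        if isinstance(inputs, (tuple, list)):
            values = [walk(processor, i, d) for i, d in zip(inputs, ds)]
            if not any(v is not None for v in values):
                return None
            return values if isinstance(ds, list) else tuple(values)
        if isinstance(inputs, dict):
            out = {k: walk(processor, v, ds[k]) for k, v in inputs.items()}
            out = {k: v for k, v in out.items() if v is not None}
            return out or None
        return None

    # Example: report, for every tensor, the axes declared dynamic in ds.
    def report_dynamic_dims(tensor: torch.Tensor, ds: dict) -> Optional[dict]:
        return {axis: name for axis, name in ds.items() if axis < tensor.ndim} or None

    inputs = {"input_ids": torch.zeros((2, 8)), "attention_mask": torch.zeros((2, 8))}
    shapes = {"input_ids": {0: "batch", 1: "seq"}, "attention_mask": {0: "batch", 1: "seq"}}
    print(walk(report_dynamic_dims, inputs, shapes))
    # {'input_ids': {0: 'batch', 1: 'seq'}, 'attention_mask': {0: 'batch', 1: 'seq'}}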
