
Commit 1eeb807

committed: tiny changes
1 parent c28b9e5 commit 1eeb807

4 files changed: +7 −6 lines changed


onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 3 additions & 2 deletions
@@ -232,8 +232,9 @@ def _generic_walker(
         """
         if not self.args:
             assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
-                f"Type mismatch, args={string_type(self.args)} and "
-                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
+                f"Type mismatch, args={string_type(self.args)}, "
+                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                f"{string_type(self.dynamic_shapes)} should have the same type."
             )
             res = self._generic_walker_step(
                 processor,
onnx_diagnostic/helpers/helper.py

Lines changed: 2 additions & 2 deletions
@@ -397,7 +397,7 @@ def string_type(
             return "AUTO"
         if verbose:
             print(f"[string_type] Y7:{type(obj)}")
-        return str(obj)
+        return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")
 
     if isinstance(obj, bool):
         if with_min_max:
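A minimal sketch of the effect of the changed return: the textual `DimHint(...)` wrappers are stripped from the string form. The literal below stands in for `str(obj)` on a torch dynamic-dimension hint; the exact repr depends on the torch version:

# Stand-in for str(obj) on a dynamic-shapes spec containing torch DimHint objects.
raw = "{0: DimHint(DYNAMIC), 1: DimHint(AUTO)}"
cleaned = raw.replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")
print(cleaned)  # {0: DYNAMIC, 1: AUTO}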
@@ -939,7 +939,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
             return flatten_object(list(x.values()), drop_keys=drop_keys)
         return flatten_object(list(x.items()), drop_keys=drop_keys)
 
-    if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue
 
         kc = CacheKeyValue(x)
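A minimal sketch (not the library code) of the class-name dispatch this hunk extends: a `HybridCache` instance is now routed through the same key/value flattening path as `DynamicCache` and `StaticCache`. The `HybridCache` class below is a stand-in for the transformers cache class:

class HybridCache:  # stand-in for transformers' HybridCache
    pass

x = HybridCache()
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
    # the real code extracts the cache's key/value tensors via CacheKeyValue here
    print("flatten via CacheKeyValue")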

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -206,8 +206,8 @@ def _check_():
                     batch_size, 1, sequence_length, total_sequence_length
                 ),
             ),
-            cache_position=torch.arange(0, sequence_length).to(torch.int64),
             position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+            cache_position=torch.arange(0, sequence_length).to(torch.int64),
             past_key_values=make_hybrid_cache(
                 [
                     (

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 1 addition & 1 deletion
@@ -861,7 +861,7 @@ def forward(self, x):
         y = torch.ones((x.shape[0], dy1))
         return y
 
-    _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
+    _inputs = [(torch.rand((4, 4)),)]
     _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
867867
