Skip to content

Commit 31b272f

Browse files
committed
fix fake
1 parent f07d2db commit 31b272f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def fake_reshape(
105105
reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
106106
axis=tuple(sorted(sh)), keepdim=True
107107
)
108+
if len(reduced_tensor.shape) == 0 == len(new_shape):
109+
return reduced_tensor
108110
return reduced_tensor.expand(*new_shape)
109111

110112
def make_fake(self, x: Any) -> Optional["FakeTensor"]: # noqa: F821
@@ -157,7 +159,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
157159
)
158160
if type(x) is dict:
159161
return {
160-
k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
162+
k: self.make_fake_with_dynamic_dimensions(
163+
v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
164+
)
161165
for k, v in x.items()
162166
}
163167
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
@@ -231,7 +235,7 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
231235

232236
x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
233237

234-
t = self.fake_reshape(x, dynamic_shapes) # type: ignore[arg-type]
238+
t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x # type: ignore[arg-type]
235239
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
236240
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
237241
return t

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ def lazy_initialization(self, key_states: torch.Tensor):
2727
new_shape = list(key_states.shape)
2828
new_shape[-2] = 0
2929
# PATCHED: used a tensor with an empty shape and not an empty list to initialize
30-
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
31-
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
30+
if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
31+
with key_states.fake_mode:
32+
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
33+
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
34+
else:
35+
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
36+
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
3237
if patch_is_initialized:
3338
self.is_initialized = True
3439

0 commit comments

Comments
 (0)