Skip to content

Commit 598d5ea

Browse files
authored
add support for fx.Proxy (#374)
* add support for fx.Proxy * fix fake * fix * rewrite patch * fix * fix
1 parent 4b8160a commit 598d5ea

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,6 @@ def __init__(self, cache=None):
2828
]
2929
self.key_cache = [layer.keys for layer in layers]
3030
self.value_cache = [layer.values for layer in layers]
31-
if None in self.key_cache or None in self.value_cache:
32-
from .helper import string_type
33-
34-
raise AssertionError(
35-
f"issue with key_cache={string_type(self.key_cache)}, "
36-
f"or value_cache={string_type(self.value_cache)}, "
37-
f"cache.layers={string_type(cache.layers)}"
38-
)
3931
elif cache is not None and hasattr(cache, "key_cache"):
4032
self.key_cache = cache.key_cache
4133
self.value_cache = cache.value_cache

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def fake_reshape(
105105
reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
106106
axis=tuple(sorted(sh)), keepdim=True
107107
)
108+
if len(reduced_tensor.shape) == 0 == len(new_shape):
109+
return reduced_tensor
108110
return reduced_tensor.expand(*new_shape)
109111

110112
def make_fake(self, x: Any) -> Optional["FakeTensor"]: # noqa: F821
@@ -157,7 +159,9 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
157159
)
158160
if type(x) is dict:
159161
return {
160-
k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
162+
k: self.make_fake_with_dynamic_dimensions(
163+
v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
164+
)
161165
for k, v in x.items()
162166
}
163167
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
@@ -231,7 +235,7 @@ def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
231235

232236
x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
233237

234-
t = self.fake_reshape(x, dynamic_shapes) # type: ignore[arg-type]
238+
t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x # type: ignore[arg-type]
235239
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
236240
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
237241
return t

onnx_diagnostic/helpers/helper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F
801801
print(f"[string_type] TT8:{type(obj)}")
802802
return repr(obj).replace(" ", "").replace("\n", " ")
803803

804+
if isinstance(obj, torch.fx.proxy.Proxy):
805+
return repr(obj)
806+
804807
if ignore:
805808
if verbose:
806809
print(f"[string_type] CACHE4:{type(obj)}")

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,18 @@ class patched_DynamicLayer:
2424

2525
def lazy_initialization(self, key_states: torch.Tensor):
2626
self.dtype, self.device = key_states.dtype, key_states.device
27-
new_shape = list(key_states.shape)
28-
new_shape[-2] = 0
27+
assert (
28+
hasattr(key_states, "shape") and key_states is not None
29+
), f"Attribute 'shape' is wrong for type {type(key_states)}"
30+
like = torch.narrow(key_states, dim=-2, start=0, length=0)
2931
# PATCHED: use a tensor with an empty shape and not an empty list to initialize
30-
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
31-
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
32+
if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
33+
with key_states.fake_mode:
34+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
35+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
36+
else:
37+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
38+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
3239
if patch_is_initialized:
3340
self.is_initialized = True
3441

0 commit comments

Comments
 (0)