Skip to content

Commit c4f123b

Browse files
committed
fix fake export
1 parent c47979b commit c4f123b

File tree

4 files changed

+17
-3
lines changed

4 files changed

+17
-3
lines changed

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _valid_shapes_tensor(cls, inputs, ds):
226226
for i, d in enumerate(inputs.shape):
227227
if i in ds and not isinstance(ds[i], int):
228228
# dynamic then
229-
if d in {0, 1}:
229+
if isinstance(d, int) and d in {0, 1}:
230230
# export issues for sure
231231
issues[i] = f"d=[{d}]"
232232
return issues if issues else None

onnx_diagnostic/export/shape_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ def make_fake_with_dynamic_dimensions(
306306
return x, fake_mode
307307
if hasattr(x, "shape"):
308308
t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
309+
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
310+
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
309311
return t, fake_mode
310312
from ..helpers import string_type
311313

onnx_diagnostic/helpers/fake_tensor_helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def fake_reshape(
3434
if true_tensor.shape[i] <= 1:
3535
expanded_shape = list(true_tensor.shape)
3636
expanded_shape[i] = _unique()
37-
true_tensor = torch.empty(tuple(expanded_shape), dtype=true_tensor.dtype)
37+
true_tensor = torch.empty(
38+
tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
39+
)
3840

3941
# deal with equivalent dimension
4042
new_shape = list(true_tensor.shape)
@@ -47,7 +49,9 @@ def fake_reshape(
4749
d = _unique()
4850
mapping[d] = s
4951
new_shape[i] = d
50-
true_tensor = torch.empty(tuple(new_shape), dtype=true_tensor.dtype)
52+
true_tensor = torch.empty(
53+
tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
54+
)
5155

5256
# now switch to FakeTensor
5357
if fake_mode is None:

onnx_diagnostic/torch_models/validate.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,6 +1909,14 @@ def call_torch_export_custom(
19091909
strict = "-strict" in exporter
19101910
args, kwargs = split_args_kwargs(data["inputs_export"])
19111911
ds = data.get("dynamic_shapes", None)
1912+
if "-fake" in exporter:
1913+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
1914+
1915+
if verbose:
1916+
print("[call_torch_export_custom] switching to FakeTensor")
1917+
assert not args, f"Exporter {exporter!r} not implemented with fake tensors."
1918+
kwargs = torch_deepcopy(kwargs)
1919+
kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
19121920
opset = data.get("model_opset", None)
19131921
if opset:
19141922
summary["export_opset"] = opset

0 commit comments

Comments
 (0)