Skip to content

Commit dc74137

Browse files
authored
Fix patched_vmap (#275)
* Fix patched_vmap * documentation
1 parent 0d8811b commit dc74137

File tree

4 files changed

+93
-5
lines changed

4 files changed

+93
-5
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.17
55
++++++
66

7+
* :pr:`275`: fixes function ``patched_vmap``
8+
79
0.7.16
810
++++++
911

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@ def test_vmap(self):
2828
got = patched_vmap(f)(x, y)
2929
self.assertEqualArray(expected, got)
3030

31+
@requires_transformers("4.52")
32+
def test_export_patched_vmap_dynamic_shapes(self):
33+
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
34+
35+
class Model(torch.nn.Module):
36+
def forward(self, x, y):
37+
f = lambda x, y: x * y + 1 # noqa: E731
38+
return patched_vmap(f)(x, y)
39+
40+
x = torch.tensor([1.0, 2.0, 3.0])
41+
y = torch.tensor([0.1, 0.2, 0.3])
42+
expected = Model()(x, y)
43+
DYN = torch.export.Dim.DYNAMIC
44+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=({0: DYN}, {0: DYN}))
45+
got = ep.module()(x, y)
46+
self.assertEqualArray(expected, got)
47+
3148
@requires_torch("2.10")
3249
def test_export_vmap(self):
3350
class Model(torch.nn.Module):
@@ -56,6 +73,55 @@ def forward(self, x, y):
5673
ep = torch.export.export(Model(), (x, y))
5774
self.assertEqualArray(Model()(x, y), ep.module()(x, y))
5875

76+
@requires_torch("2.8")
77+
@requires_transformers("4.52")
78+
def test_export_patched_vmap_scan(self):
79+
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
80+
81+
x = torch.tensor([1.0, 2.0, 3.0])
82+
y = torch.tensor([0.1, 0.2, 0.3])
83+
res = torch.ops.higher_order.scan(lambda x, y: x + y, [], [x, y], [])
84+
self.assertEqualArray(x + y, res[0])
85+
86+
class ModelVmap(torch.nn.Module):
87+
def forward(self, x, y):
88+
f = lambda x, y: x * y + 1 # noqa: E731
89+
return torch.vmap(f)(x, y)
90+
91+
expected = ModelVmap()(x, y)
92+
93+
class ModelNoScan(torch.nn.Module):
94+
def forward(self, x, y):
95+
f = lambda x, y: x * y + 1 # noqa: E731
96+
return patched_vmap(f, use_scan=False)(x, y)
97+
98+
expected2 = ModelNoScan()(x, y)
99+
self.assertEqualArray(expected, expected2)
100+
101+
class ModelScan(torch.nn.Module):
102+
def forward(self, x, y):
103+
f = lambda x, y: [x * y + 1] # noqa: E731
104+
return torch.ops.higher_order.scan(f, [], [x, y], [])[0]
105+
106+
expected2 = ModelNoScan()(x, y)
107+
self.assertEqualArray(expected, expected2)
108+
ep = torch.export.export(ModelScan(), (x, y))
109+
self.assertEqualArray(expected, ep.module()(x, y))
110+
111+
DYN = torch.export.Dim.DYNAMIC
112+
ep = torch.export.export(ModelScan(), (x, y), dynamic_shapes=({0: DYN}, {0: DYN}))
113+
self.assertEqualArray(expected, ep.module()(x, y))
114+
115+
class Model(torch.nn.Module):
116+
def forward(self, x, y):
117+
f = lambda x, y: x * y + 1 # noqa: E731
118+
return patched_vmap(f, use_scan=True)(x, y)
119+
120+
expected2 = Model()(x, y)
121+
self.assertEqualArray(expected, expected2)
122+
ep = torch.export.export(Model(), (x, y), dynamic_shapes=({0: DYN}, {0: DYN}))
123+
self.assertEqualArray(expected, ep.module()(x, y))
124+
59125
@requires_transformers("4.52")
60126
def test_vmap_outdim(self):
61127
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

onnx_diagnostic/export/shape_helper.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,23 @@ def make_fake_with_dynamic_dimensions(
212212
constraints as the following dynamic shapes.
213213
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
214214
215+
A simple tensor:
216+
217+
.. runpython::
218+
:showcode:
219+
220+
import torch
221+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
222+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
223+
224+
inputs, _ = make_fake_with_dynamic_dimensions(
225+
torch.rand((2, 3, 4, 5), dtype=dtype=torch.float32),
226+
{0: "batch", 2: "cache_length"},
227+
)
228+
print(inputs)
229+
230+
With a cache:
231+
215232
.. runpython::
216233
:showcode:
217234

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def compute_concrete_val() -> sympy.Basic:
683683
return concrete_val
684684

685685

686-
def patched_vmap(func, in_dims=0, out_dims=0):
686+
def patched_vmap(func, in_dims=0, out_dims=0, use_scan: bool = False):
687687
"""
688688
Python implementation of :func:`torch.vmap`.
689689
The implementation raises an issue when it is being exported with
@@ -724,8 +724,9 @@ def wrapped(*args):
724724
arg = arg.movedim(in_dim, 0)
725725
batched_args.append(arg)
726726

727-
if all(isinstance(a, torch.Tensor) for a in args) and isinstance(
728-
batch_size, torch.SymInt
727+
if use_scan or (
728+
all(isinstance(a, torch.Tensor) for a in args)
729+
and isinstance(batch_size, torch.SymInt)
729730
):
730731
batched_tensors = [
731732
(
@@ -735,7 +736,9 @@ def wrapped(*args):
735736
)
736737
for arg, in_dim in zip(batched_args, in_dims_)
737738
]
738-
results = torch.ops.higher_order.scan(func, [], batched_tensors, [])
739+
results = torch.ops.higher_order.scan(
740+
lambda *args, **kwargs: [func(*args, **kwargs)], [], batched_tensors, []
741+
)
739742
stacked = results[0]
740743
if out_dims != 0:
741744
return stacked.movedim(0, out_dims)
@@ -745,7 +748,7 @@ def wrapped(*args):
745748
torch._check(
746749
not isinstance(batch_size, torch.SymInt),
747750
lambda: (
748-
f"patched_vmap supports dynamic batch_size only if all argument "
751+
f"patched_vmap supports dynamic batch_size only if all arguments "
749752
f"are tensors but types are {[type(a) for a in args]}"
750753
),
751754
)

0 commit comments

Comments
 (0)