
Commit 0e796f7 ("python_vmap")

1 parent ffd492a

File tree

5 files changed, +145 -0 lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ onnx_diagnostic.torch_export_patches
     onnx_export_serialization
     patches/index
     patch_expressions
+    patch_helper
     patch_inputs
     patch_module
     patch_module_helper
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.patch_helper
+=================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.patch_helper
+    :members:
+    :no-undoc-members:
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.torch_export_patches.patch_helper import py_vmap
+
+
+class TestPatchHelper(ExtTestCase):
+    def test_vmap(self):
+        f = lambda x, y: x * y + 1  # noqa: E731
+        x = torch.tensor([1.0, 2.0, 3.0])
+        y = torch.tensor([0.1, 0.2, 0.3])
+        expected = torch.vmap(f)(x, y)
+        got = py_vmap(f)(x, y)
+        self.assertEqualArray(expected, got)
+
+    @requires_torch("2.9")
+    def test_export_vmap(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                f = lambda x, y: x * y + 1  # noqa: E731
+                return torch.vmap(f)(x, y)
+
+        x = torch.tensor([1.0, 2.0, 3.0])
+        y = torch.tensor([0.1, 0.2, 0.3])
+        DYN = torch.export.Dim.DYNAMIC
+        torch.export.export(Model(), (x, y), dynamic_shapes=({0: DYN}, {0: DYN}))
+
+    def test_export_py_vmap(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                f = lambda x, y: x * y + 1  # noqa: E731
+                return py_vmap(f)(x, y)
+
+        x = torch.tensor([1.0, 2.0, 3.0])
+        y = torch.tensor([0.1, 0.2, 0.3])
+        torch.export.export(Model(), (x, y))
+
+    def test_vmap_outdim(self):
+        f = lambda x: x**2  # noqa: E731
+        x = torch.randn(2, 5)
+        expected = torch.vmap(f, out_dims=1)(x)
+        got = py_vmap(f, out_dims=1)(x)
+        self.assertEqualArray(expected, got)
+
+    def test_vmap_dict(self):
+        f = lambda d: torch.dot(d["x"], d["y"])  # noqa: E731
+        x, y = torch.randn(2, 5), torch.randn(5)
+        input = {"x": x, "y": y}
+        _expected = torch.vmap(f, in_dims=({"x": 0, "y": None},))(input)
+        self.assertRaise(
+            lambda: py_vmap(f, in_dims=({"x": 0, "y": None},))(input), AssertionError
+        )
+        # self.assertEqualArray(_expected, got)
+
+    def test_vmap_tuple(self):
+        x, y = torch.randn(2, 5), torch.randn(5)
+        expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y)
+        got = py_vmap(torch.dot, in_dims=(0, None))(x, y)
+        self.assertEqualArray(expected, got)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
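
For orientation, a minimal sketch (not part of the commit) of the semantics the tests above check: py_vmap unrolls the batch dimension into an ordinary Python loop and stacks the per-slice results, so torch.export only sees plain tensor ops.

import torch

# minimal sketch of the unrolling py_vmap performs for the toy case above:
# slice the batched inputs along dim 0, apply f per slice, stack the results
f = lambda x, y: x * y + 1  # noqa: E731
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.1, 0.2, 0.3])
unrolled = torch.stack([f(x[i], y[i]) for i in range(x.shape[0])])
assert torch.allclose(unrolled, torch.vmap(f)(x, y))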

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+from ..patch_helper import py_vmap
 
 DIM = torch.export.Dim
 DYN = torch.export.Dim.DYNAMIC
@@ -875,3 +876,21 @@ def forward(self, x):
 
     _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
     _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
+
+
+class Vmap(torch.nn.Module):
+    def forward(self, x, y):
+        f = lambda x, y: x * y + 1  # noqa: E731
+        return torch.vmap(f)(x, y)
+
+    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
+    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
+
+
+class VmapPython(torch.nn.Module):
+    def forward(self, x, y):
+        f = lambda x, y: x * y + 1  # noqa: E731
+        return py_vmap(f)(x, y)
+
+    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
+    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+from ..helpers import string_type
+
+
+def py_vmap(func, in_dims=0, out_dims=0):
+    """
+    Python implementation of :func:`torch.vmap`.
+    """
+
+    def wrapped(*args):
+        assert all(not isinstance(a, dict) for a in args), (
+            f"dictionaries are not implemented in "
+            f"args={string_type(args, with_shape=True)}"
+        )
+
+        in_dims_ = (
+            ([in_dims] * len(args))
+            if not isinstance(in_dims, (list, tuple))
+            else list(in_dims)
+        )
+        assert len(in_dims_) == len(args)
+
+        batch_size = None
+        batched_args = []
+        for arg, in_dim in zip(args, in_dims_):
+            if in_dim is None:
+                batched_args.append(arg)
+                continue
+
+            assert batch_size is None or batch_size == arg.size(in_dim), (
+                f"Unable to continue, batch_size={batch_size}, in_dim={in_dim}, "
+                f"arg.size(in_dim)={arg.size(in_dim)}"
+            )
+            if batch_size is None:
+                batch_size = arg.size(in_dim)
+            arg = arg.movedim(in_dim, 0)
+            batched_args.append(arg)
+
+        results = []
+        for i in range(batch_size):
+            input_slice = [
+                (arg[i] if isinstance(arg, torch.Tensor) and in_dim is not None else arg)
+                for arg, in_dim in zip(batched_args, in_dims_)
+            ]
+            result = func(*input_slice)
+            results.append(result)
+
+        if isinstance(results[0], torch.Tensor):
+            stacked = torch.stack(results)
+            if out_dims != 0:
+                return stacked.movedim(0, out_dims)
+            return stacked
+        return results
+
+    return wrapped
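
For reference, a short usage sketch (not part of the commit) mirroring the unit tests above: in_dims=(0, None) slices only the first argument, and out_dims moves the stacked batch dimension in the output.

import torch
from onnx_diagnostic.torch_export_patches.patch_helper import py_vmap

# batched matrix-vector dot product: x is sliced along dim 0, y is reused as-is
x, y = torch.randn(2, 5), torch.randn(5)
got = py_vmap(torch.dot, in_dims=(0, None))(x, y)
assert torch.allclose(got, torch.vmap(torch.dot, in_dims=(0, None))(x, y))

# out_dims=1 moves the stacked batch dimension to position 1 in the result
square = lambda t: t**2  # noqa: E731
assert py_vmap(square, out_dims=1)(torch.randn(2, 5)).shape == (5, 2)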
