@@ -28,6 +28,23 @@ def test_vmap(self):
2828 got = patched_vmap (f )(x , y )
2929 self .assertEqualArray (expected , got )
3030
31+ @requires_transformers ("4.52" )
32+ def test_export_patched_vmap_dynamic_shapes (self ):
33+ from onnx_diagnostic .torch_export_patches .patches .patch_torch import patched_vmap
34+
35+ class Model (torch .nn .Module ):
36+ def forward (self , x , y ):
37+ f = lambda x , y : x * y + 1 # noqa: E731
38+ return patched_vmap (f )(x , y )
39+
40+ x = torch .tensor ([1.0 , 2.0 , 3.0 ])
41+ y = torch .tensor ([0.1 , 0.2 , 0.3 ])
42+ expected = Model ()(x , y )
43+ DYN = torch .export .Dim .DYNAMIC
44+ ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ({0 : DYN }, {0 : DYN }))
45+ got = ep .module ()(x , y )
46+ self .assertEqualArray (expected , got )
47+
3148 @requires_torch ("2.10" )
3249 def test_export_vmap (self ):
3350 class Model (torch .nn .Module ):
@@ -56,6 +73,55 @@ def forward(self, x, y):
5673 ep = torch .export .export (Model (), (x , y ))
5774 self .assertEqualArray (Model ()(x , y ), ep .module ()(x , y ))
5875
76+ @requires_torch ("2.8" )
77+ @requires_transformers ("4.52" )
78+ def test_export_patched_vmap_scan (self ):
79+ from onnx_diagnostic .torch_export_patches .patches .patch_torch import patched_vmap
80+
81+ x = torch .tensor ([1.0 , 2.0 , 3.0 ])
82+ y = torch .tensor ([0.1 , 0.2 , 0.3 ])
83+ res = torch .ops .higher_order .scan (lambda x , y : x + y , [], [x , y ], [])
84+ self .assertEqualArray (x + y , res [0 ])
85+
86+ class ModelVmap (torch .nn .Module ):
87+ def forward (self , x , y ):
88+ f = lambda x , y : x * y + 1 # noqa: E731
89+ return torch .vmap (f )(x , y )
90+
91+ expected = ModelVmap ()(x , y )
92+
93+ class ModelNoScan (torch .nn .Module ):
94+ def forward (self , x , y ):
95+ f = lambda x , y : x * y + 1 # noqa: E731
96+ return patched_vmap (f , use_scan = False )(x , y )
97+
98+ expected2 = ModelNoScan ()(x , y )
99+ self .assertEqualArray (expected , expected2 )
100+
101+ class ModelScan (torch .nn .Module ):
102+ def forward (self , x , y ):
103+ f = lambda x , y : [x * y + 1 ] # noqa: E731
104+ return torch .ops .higher_order .scan (f , [], [x , y ], [])[0 ]
105+
106+ expected2 = ModelNoScan ()(x , y )
107+ self .assertEqualArray (expected , expected2 )
108+ ep = torch .export .export (ModelScan (), (x , y ))
109+ self .assertEqualArray (expected , ep .module ()(x , y ))
110+
111+ DYN = torch .export .Dim .DYNAMIC
112+ ep = torch .export .export (ModelScan (), (x , y ), dynamic_shapes = ({0 : DYN }, {0 : DYN }))
113+ self .assertEqualArray (expected , ep .module ()(x , y ))
114+
115+ class Model (torch .nn .Module ):
116+ def forward (self , x , y ):
117+ f = lambda x , y : x * y + 1 # noqa: E731
118+ return patched_vmap (f , use_scan = True )(x , y )
119+
120+ expected2 = Model ()(x , y )
121+ self .assertEqualArray (expected , expected2 )
122+ ep = torch .export .export (Model (), (x , y ), dynamic_shapes = ({0 : DYN }, {0 : DYN }))
123+ self .assertEqualArray (expected , ep .module ()(x , y ))
124+
59125 @requires_transformers ("4.52" )
60126 def test_vmap_outdim (self ):
61127 from onnx_diagnostic .torch_export_patches .patches .patch_torch import patched_vmap
0 commit comments