@@ -49,7 +49,7 @@ def forward(self, cache):
4949 [[{0 : DYN }, {0 : DYN }, {0 : DYN }], [{0 : DYN }, {0 : DYN }, {0 : DYN }]],
5050 ]
5151
52- with bypass_export_some_errors ():
52+ with bypass_export_some_errors (patch_transformers = True ):
5353 torch .export .export (model , (cache ,), dynamic_shapes = (ds ,))
5454
5555 @ignore_warnings (UserWarning )
@@ -99,6 +99,21 @@ def test_base_model_output(self):
9999 self .string_type (bo2 , with_shape = True , with_min_max = True ),
100100 )
101101
102+ @ignore_warnings (UserWarning )
103+ def test_export_base_model_output (self ):
104+ class Model (torch .nn .Module ):
105+ def forward (self , cache ):
106+ return cache .last_hidden_state [0 ]
107+
108+ bo = BaseModelOutput (last_hidden_state = torch .rand ((4 , 4 , 4 )))
109+ model = Model ()
110+ model (bo )
111+ DYN = torch .export .Dim .DYNAMIC
112+ ds = [{0 : DYN }]
113+
114+ with bypass_export_some_errors ():
115+ torch .export .export (model , (bo ,), dynamic_shapes = (ds ,))
116+
102117
103118if __name__ == "__main__" :
104119 unittest .main (verbosity = 2 )
0 commit comments