Skip to content

Commit 2f41fb5

Browse files
committed
issue
1 parent 0e8b0c7 commit 2f41fb5

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def forward(self, cache):
4949
[[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
5050
]
5151

52-
with bypass_export_some_errors():
52+
with bypass_export_some_errors(patch_transformers=True):
5353
torch.export.export(model, (cache,), dynamic_shapes=(ds,))
5454

5555
@ignore_warnings(UserWarning)
@@ -99,6 +99,21 @@ def test_base_model_output(self):
9999
self.string_type(bo2, with_shape=True, with_min_max=True),
100100
)
101101

102+
@ignore_warnings(UserWarning)
103+
def test_export_base_model_output(self):
104+
class Model(torch.nn.Module):
105+
def forward(self, cache):
106+
return cache.last_hidden_state[0]
107+
108+
bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
109+
model = Model()
110+
model(bo)
111+
DYN = torch.export.Dim.DYNAMIC
112+
ds = [{0: DYN}]
113+
114+
with bypass_export_some_errors():
115+
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
116+
102117

103118
if __name__ == "__main__":
104119
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)