|
1 | 1 | import unittest |
2 | 2 | import torch |
3 | 3 | from transformers.modeling_outputs import BaseModelOutput |
4 | | -from onnx_diagnostic.ext_test_case import ( |
5 | | - ExtTestCase, |
6 | | - ignore_warnings, |
7 | | - requires_torch, |
8 | | - requires_diffusers, |
9 | | -) |
| 4 | +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch |
10 | 5 | from onnx_diagnostic.helpers.cache_helper import ( |
11 | 6 | make_encoder_decoder_cache, |
12 | 7 | make_dynamic_cache, |
@@ -159,7 +154,7 @@ def forward(self, cache): |
159 | 154 | DYN = torch.export.Dim.DYNAMIC |
160 | 155 | ds = [{0: DYN}] |
161 | 156 |
|
162 | | - with torch_export_patches(): |
| 157 | + with torch_export_patches(patch_transformers=True): |
163 | 158 | torch.export.export(model, (bo,), dynamic_shapes=(ds,)) |
164 | 159 |
|
165 | 160 | @ignore_warnings(UserWarning) |
@@ -218,63 +213,6 @@ def test_sliding_window_cache_flatten(self): |
218 | 213 | self.string_type(cache2, with_shape=True, with_min_max=True), |
219 | 214 | ) |
220 | 215 |
|
221 | | - @ignore_warnings(UserWarning) |
222 | | - @requires_diffusers("0.30") |
223 | | - def test_unet_2d_condition_output(self): |
224 | | - from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput |
225 | | - |
226 | | - bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) |
227 | | - self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput") |
228 | | - bo2 = torch_deepcopy([bo]) |
229 | | - self.assertIsInstance(bo2, list) |
230 | | - self.assertEqual( |
231 | | - "UNet2DConditionOutput(sample:T1s4x4x4)", |
232 | | - self.string_type(bo, with_shape=True), |
233 | | - ) |
234 | | - |
235 | | - with torch_export_patches(): |
236 | | - # internal function |
237 | | - bo2 = torch_deepcopy([bo]) |
238 | | - self.assertIsInstance(bo2, list) |
239 | | - self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput") |
240 | | - self.assertEqualAny([bo], bo2) |
241 | | - self.assertEqual( |
242 | | - "UNet2DConditionOutput(sample:T1s4x4x4)", |
243 | | - self.string_type(bo, with_shape=True), |
244 | | - ) |
245 | | - |
246 | | - # serialization |
247 | | - flat, _spec = torch.utils._pytree.tree_flatten(bo) |
248 | | - self.assertEqual( |
249 | | - "#1[T1s4x4x4]", |
250 | | - self.string_type(flat, with_shape=True), |
251 | | - ) |
252 | | - bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) |
253 | | - self.assertEqual( |
254 | | - self.string_type(bo, with_shape=True, with_min_max=True), |
255 | | - self.string_type(bo2, with_shape=True, with_min_max=True), |
256 | | - ) |
257 | | - |
258 | | - # flatten_unflatten |
259 | | - flat, _spec = torch.utils._pytree.tree_flatten(bo) |
260 | | - unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) |
261 | | - self.assertIsInstance(unflat, dict) |
262 | | - self.assertEqual(list(unflat), ["sample"]) |
263 | | - |
264 | | - # export |
265 | | - class Model(torch.nn.Module): |
266 | | - def forward(self, cache): |
267 | | - return cache.sample[0] |
268 | | - |
269 | | - bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) |
270 | | - model = Model() |
271 | | - model(bo) |
272 | | - DYN = torch.export.Dim.DYNAMIC |
273 | | - ds = [{0: DYN}] |
274 | | - |
275 | | - with torch_export_patches(): |
276 | | - torch.export.export(model, (bo,), dynamic_shapes=(ds,)) |
277 | | - |
278 | 216 | @ignore_warnings(UserWarning) |
279 | 217 | @requires_torch("2.7.99") |
280 | 218 | def test_static_cache(self): |
|
0 commit comments