
Commit 46b8402
fix unit test
1 parent 40cadf5 commit 46b8402

2 files changed: +89 -5 lines changed


_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 49 additions & 2 deletions
@@ -17,8 +17,8 @@ class TestTasksImageToVideo(ExtTestCase):
     @hide_stdout()
     @requires_diffusers("0.35")
     @requires_transformers("4.55")
-    @requires_torch("2.8.99")
-    def test_image_to_video(self):
+    @requires_torch("2.10.99")
+    def test_image_to_video_oblivious(self):
         kwargs = {
             "_diffusers_version": "0.34.0.dev0",
             "_class_name": "CosmosTransformer3DModel",
@@ -63,6 +63,53 @@ def test_image_to_video(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    @requires_diffusers("0.35")
+    @requires_transformers("4.55")
+    @requires_torch("2.8.99")
+    def test_image_to_video_not_oblivious(self):
+        kwargs = {
+            "_diffusers_version": "0.34.0.dev0",
+            "_class_name": "CosmosTransformer3DModel",
+            "max_size": [128, 240, 240],
+            "text_embed_dim": 128,
+            "use_cache": True,
+            "in_channels": 3,
+            "out_channels": 16,
+            "num_layers": 2,
+            "model_type": "dia",
+            "patch_size": [1, 2, 2],
+            "rope_scale": [1.0, 3.0, 3.0],
+            "attention_head_dim": 16,
+            "mlp_ratio": 0.4,
+            "initializer_range": 0.02,
+            "num_attention_heads": 16,
+            "is_encoder_decoder": True,
+            "adaln_lora_dim": 16,
+            "concat_padding_mask": True,
+            "extra_pos_embed_type": None,
+        }
+        config = transformers.DiaConfig(**kwargs)
+        mid = "nvidia/Cosmos-Predict2-2B-Video2World"
+        data = get_untrained_model_with_inputs(
+            mid,
+            verbose=1,
+            add_second_input=True,
+            subfolder="transformer",
+            config=config,
+            inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80),
+        )
+        self.assertEqual(data["task"], "image-to-video")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(
+            patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
+        ):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
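
For context, both tests in this file exercise the same flow: build a small untrained copy of the model with sample inputs and dynamic-shape specifications, then run torch.export under the patches provided by onnx_diagnostic. A minimal standalone sketch of that flow follows (not part of the commit; the import path of get_untrained_model_with_inputs is assumed, and the model id is reused from the test):

# Sketch of the export flow exercised by the tests above (illustrative only).
import torch
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs  # assumed module path
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

# Build an untrained copy of the model together with sample inputs
# and dynamic-shape specifications.
data = get_untrained_model_with_inputs(
    "nvidia/Cosmos-Predict2-2B-Video2World", subfolder="transformer"
)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

# Apply the transformers/diffusers patches only for the duration of the export.
with torch_export_patches(patch_transformers=True, patch_diffusers=True):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
    )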

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 40 additions & 3 deletions
@@ -8,7 +8,10 @@
 )
 from onnx_diagnostic.torch_models.llms import get_phi2
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches import (
+    torch_export_patches,
+    register_additional_serialization_functions,
+)
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
@@ -21,8 +24,8 @@ def test_get_phi2(self):
 
     @ignore_warnings(UserWarning)
     @requires_transformers("4.54")
-    @requires_torch("2.9.99")
-    def test_export_phi2_1_batch_size_1(self):
+    @requires_torch("2.10.99")
+    def test_export_phi2_1_batch_size_1_oblivious(self):
         # exporting vmap does not work
         data = get_phi2(num_hidden_layers=2, batch_size=1)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -38,6 +41,40 @@ def test_export_phi2_1_batch_size_1(self):
             )
         assert ep
 
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.54")
+    @requires_torch("2.9.99")
+    def test_export_phi2_1_batch_size_1_not_oblivious(self):
+        # exporting vmap does not work
+        data = get_phi2(num_hidden_layers=2, batch_size=1)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        self.assertEqual(inputs["input_ids"].shape[0], 1)
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+            )
+        assert ep
+
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.54")
+    @requires_torch("2.12")
+    def test_export_phi2_1_batch_size_1_no_patch(self):
+        # exporting vmap does not work
+        data = get_phi2(num_hidden_layers=2, batch_size=1)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        self.assertEqual(inputs["input_ids"].shape[0], 1)
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with register_additional_serialization_functions(patch_transformers=True):
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+            )
+        assert ep
+
     @ignore_warnings(UserWarning)
     @requires_transformers("4.54")
     @requires_torch("2.9.99")
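
The two new tests in this file differ only in the context manager the export runs under: torch_export_patches applies the full set of rewrites to transformers code, while register_additional_serialization_functions appears (judging by its name and by the newer torch 2.12 requirement on the no_patch test) to register only the pytree serialization needed for cache-like inputs, leaving the model code unpatched. A minimal sketch of the contrast, reusing only names that appear in the diff (illustrative, not part of the commit):

# Sketch contrasting the two context managers used by the new phi2 tests.
import torch
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.torch_export_patches import (
    torch_export_patches,
    register_additional_serialization_functions,
)
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

# Full patching: rewrites pieces of transformers so torch.export can trace them.
with torch_export_patches(patch_transformers=True):
    ep_patched = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )

# Lighter variant: presumably registers (de)serialization for cache classes only,
# relying on a recent torch to handle the rest natively.
with register_additional_serialization_functions(patch_transformers=True):
    ep_plain = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )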
