
Commit 9edddc1

add modeling tests for cogvideox 1.5
1 parent e481843 commit 9edddc1

1 file changed: +61 −0 lines changed


tests/models/transformers/test_models_transformer_cogvideox.py

Lines changed: 61 additions & 0 deletions
@@ -76,6 +76,7 @@ def prepare_init_args_and_inputs_for_common(self):
             "sample_height": 8,
             "sample_frames": 8,
             "patch_size": 2,
+            "patch_size_t": None,
             "temporal_compression_ratio": 4,
             "max_text_seq_length": 8,
         }
@@ -85,3 +86,63 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"CogVideoXTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = CogVideoXTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 4
+        num_frames = 1
+        height = 8
+        width = 8
+        embedding_dim = 8
+        sequence_length = 8
+
+        hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "timestep": timestep,
+        }
+
+    @property
+    def input_shape(self):
+        return (1, 4, 8, 8)
+
+    @property
+    def output_shape(self):
+        return (1, 4, 8, 8)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
+            "num_attention_heads": 2,
+            "attention_head_dim": 8,
+            "in_channels": 4,
+            "out_channels": 4,
+            "time_embed_dim": 2,
+            "text_embed_dim": 8,
+            "num_layers": 1,
+            "sample_width": 8,
+            "sample_height": 8,
+            "sample_frames": 8,
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "temporal_compression_ratio": 4,
+            "max_text_seq_length": 8,
+            "use_rotary_positional_embeddings": True,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogVideoXTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
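For a quick sense of what the new test class exercises, here is a minimal standalone sketch that builds a tiny transformer with the same init args as CogVideoX1_5TransformerTests and runs one forward pass on random latents. The config values come from the diff above; the two-frame latent (chosen so the frame count divides patch_size_t), the no_grad context, and the final shape check are assumptions for this illustration, not part of the committed test.

import torch

from diffusers import CogVideoXTransformer3DModel

# Tiny CogVideoX 1.5-style config: patch_size_t=2 and rotary positional
# embeddings are the pieces specific to the 1.5 variant being tested.
model = CogVideoXTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    time_embed_dim=2,
    text_embed_dim=8,
    num_layers=1,
    sample_width=8,
    sample_height=8,
    sample_frames=8,
    patch_size=2,
    patch_size_t=2,
    temporal_compression_ratio=4,
    max_text_seq_length=8,
    use_rotary_positional_embeddings=True,
)

batch_size = 2
# Assumption for this standalone run: use 2 latent frames so the frame count
# is divisible by patch_size_t when the latents are patchified.
num_frames = 2
hidden_states = torch.randn(batch_size, num_frames, 4, 8, 8)  # (B, F, C, H, W) latents
encoder_hidden_states = torch.randn(batch_size, 8, 8)         # (B, seq_len, text_embed_dim)
timestep = torch.randint(0, 1000, (batch_size,))

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
    )

print(out.sample.shape)  # expected to match the input latent shape (B, F, C, H, W)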
