@@ -76,6 +76,7 @@ def prepare_init_args_and_inputs_for_common(self):
7676 "sample_height" : 8 ,
7777 "sample_frames" : 8 ,
7878 "patch_size" : 2 ,
79+ "patch_size_t" : None ,
7980 "temporal_compression_ratio" : 4 ,
8081 "max_text_seq_length" : 8 ,
8182 }
@@ -85,3 +86,63 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"CogVideoXTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = CogVideoXTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 4
+        num_frames = 1
+        height = 8
+        width = 8
+        embedding_dim = 8
+        sequence_length = 8
+
+        hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "timestep": timestep,
+        }
+
+    @property
+    def input_shape(self):
+        return (1, 4, 8, 8)
+
+    @property
+    def output_shape(self):
+        return (1, 4, 8, 8)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
+            "num_attention_heads": 2,
+            "attention_head_dim": 8,
+            "in_channels": 4,
+            "out_channels": 4,
+            "time_embed_dim": 2,
+            "text_embed_dim": 8,
+            "num_layers": 1,
+            "sample_width": 8,
+            "sample_height": 8,
+            "sample_frames": 8,
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "temporal_compression_ratio": 4,
+            "max_text_seq_length": 8,
+            "use_rotary_positional_embeddings": True,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogVideoXTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
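
Reviewer note: the new CogVideoX1_5TransformerTests class exercises the CogVideoX 1.5 configuration. The two differences from the 1.0 test above are patch_size_t=2, which (as I understand the 1.5 changes) patchifies the latents along the temporal axis in addition to the spatial patch_size=2, and use_rotary_positional_embeddings=True. The first hunk adds an explicit patch_size_t=None to the 1.0 test so that config keeps the original, purely spatial patching now that the argument exists. The inner dimension here is num_attention_heads * attention_head_dim = 2 * 8 = 16, which satisfies the divisible-by-16 constraint called out in the init_dict comment.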
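
For reference, a minimal standalone sketch of the forward pass these tests exercise, mirroring the shapes from dummy_input and the init_dict above; it assumes the public diffusers export CogVideoXTransformer3DModel and its usual forward signature (hidden_states, encoder_hidden_states, timestep):

    import torch
    from diffusers import CogVideoXTransformer3DModel

    # Same init_dict as CogVideoX1_5TransformerTests above.
    model = CogVideoXTransformer3DModel(
        num_attention_heads=2,  # inner dim = 2 * 8 = 16, divisible by 16
        attention_head_dim=8,
        in_channels=4,
        out_channels=4,
        time_embed_dim=2,
        text_embed_dim=8,
        num_layers=1,
        sample_width=8,
        sample_height=8,
        sample_frames=8,
        patch_size=2,
        patch_size_t=2,  # temporal patching, the CogVideoX 1.5 addition
        temporal_compression_ratio=4,
        max_text_seq_length=8,
        use_rotary_positional_embeddings=True,
    )

    # Shapes mirror dummy_input: (batch, frames, channels, height, width).
    hidden_states = torch.randn(2, 1, 4, 8, 8)
    encoder_hidden_states = torch.randn(2, 8, 8)  # (batch, seq_len, text_embed_dim)
    timestep = torch.randint(0, 1000, (2,))

    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
    ).sample
    print(out.shape)

This is the same call the ModelTesterMixin common tests make with the dicts returned by prepare_init_args_and_inputs_for_common.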