Skip to content

Commit c4c99c3

Browse files
authored
[tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (huggingface#10270)
fix joint pos embedding device
1 parent 862a7d5 commit c4c99c3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def _get_positional_embeddings(
691691
output_type="pt",
692692
)
693693
pos_embedding = pos_embedding.flatten(0, 1)
694-
joint_pos_embedding = torch.zeros(
694+
joint_pos_embedding = pos_embedding.new_zeros(
695695
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
696696
)
697697
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

0 commit comments

Comments
 (0)