Skip to content

Commit 877cdc0

Browse files
committed
update vae
1 parent 9f6f3f6 commit 877cdc0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1089,6 +1089,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
10891089
return self.tiled_encode(x)
10901090

10911091
frame_batch_size = self.num_sample_frames_batch_size
1092+
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
10921093
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
10931094
enc = []
10941095
for i in range(num_batches):
@@ -1141,8 +1142,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
11411142
return self.tiled_decode(z, return_dict=return_dict)
11421143

11431144
frame_batch_size = self.num_latent_frames_batch_size
1145+
num_batches = num_frames // frame_batch_size
11441146
dec = []
1145-
for i in range(num_frames // frame_batch_size):
1147+
for i in range(num_batches):
11461148
remaining_frames = num_frames % frame_batch_size
11471149
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
11481150
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1234,6 +1236,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
12341236
for i in range(0, height, overlap_height):
12351237
row = []
12361238
for j in range(0, width, overlap_width):
1239+
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
12371240
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
12381241
time = []
12391242
for k in range(num_batches):
@@ -1311,8 +1314,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13111314
for i in range(0, height, overlap_height):
13121315
row = []
13131316
for j in range(0, width, overlap_width):
1317+
num_batches = num_frames // frame_batch_size
13141318
time = []
1315-
for k in range(num_frames // frame_batch_size):
1319+
for k in range(num_batches):
13161320
remaining_frames = num_frames % frame_batch_size
13171321
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
13181322
end_frame = frame_batch_size * (k + 1) + remaining_frames

0 commit comments

Comments (0)