@@ -1089,6 +1089,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
10891089 return self .tiled_encode (x )
10901090
10911091 frame_batch_size = self .num_sample_frames_batch_size
1092+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
10921093 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
10931094 enc = []
10941095 for i in range (num_batches ):
@@ -1141,8 +1142,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
11411142 return self .tiled_decode (z , return_dict = return_dict )
11421143
11431144 frame_batch_size = self .num_latent_frames_batch_size
1145+ num_batches = num_frames // frame_batch_size
11441146 dec = []
1145- for i in range (num_frames // frame_batch_size ):
1147+ for i in range (num_batches ):
11461148 remaining_frames = num_frames % frame_batch_size
11471149 start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames )
11481150 end_frame = frame_batch_size * (i + 1 ) + remaining_frames
@@ -1234,6 +1236,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
12341236 for i in range (0 , height , overlap_height ):
12351237 row = []
12361238 for j in range (0 , width , overlap_width ):
1239+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
12371240 num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
12381241 time = []
12391242 for k in range (num_batches ):
@@ -1311,8 +1314,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13111314 for i in range (0 , height , overlap_height ):
13121315 row = []
13131316 for j in range (0 , width , overlap_width ):
1317+ num_batches = num_frames // frame_batch_size
13141318 time = []
1315- for k in range (num_frames // frame_batch_size ):
1319+ for k in range (num_batches ):
13161320 remaining_frames = num_frames % frame_batch_size
13171321 start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames )
13181322 end_frame = frame_batch_size * (k + 1 ) + remaining_frames
0 commit comments