Skip to content

Commit dc622a9

Browse files
sywangyi and sayakpaul authored
fix crash if tiling mode is enabled (huggingface#12521)
* fix crash if tiling mode is enabled Signed-off-by: Wang, Yi A <[email protected]> * fmt Signed-off-by: Wang, Yi A <[email protected]> --------- Signed-off-by: Wang, Yi A <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent ecfbc8f commit dc622a9

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,9 +1337,18 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13371337
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
13381338
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
13391339
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1340-
1341-
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1342-
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1340+
tile_sample_stride_height = self.tile_sample_stride_height
1341+
tile_sample_stride_width = self.tile_sample_stride_width
1342+
if self.config.patch_size is not None:
1343+
sample_height = sample_height // self.config.patch_size
1344+
sample_width = sample_width // self.config.patch_size
1345+
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
1346+
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
1347+
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
1348+
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
1349+
else:
1350+
blend_height = self.tile_sample_min_height - tile_sample_stride_height
1351+
blend_width = self.tile_sample_min_width - tile_sample_stride_width
13431352

13441353
# Split z into overlapping tiles and decode them separately.
13451354
# The tiles have an overlap to avoid seams between tiles.
@@ -1353,7 +1362,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13531362
self._conv_idx = [0]
13541363
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
13551364
tile = self.post_quant_conv(tile)
1356-
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1365+
decoded = self.decoder(
1366+
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
1367+
)
13571368
time.append(decoded)
13581369
row.append(torch.cat(time, dim=2))
13591370
rows.append(row)
@@ -1369,11 +1380,15 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13691380
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
13701381
if j > 0:
13711382
tile = self.blend_h(row[j - 1], tile, blend_width)
1372-
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1383+
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
13731384
result_rows.append(torch.cat(result_row, dim=-1))
1374-
13751385
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
13761386

1387+
if self.config.patch_size is not None:
1388+
dec = unpatchify(dec, patch_size=self.config.patch_size)
1389+
1390+
dec = torch.clamp(dec, min=-1.0, max=1.0)
1391+
13771392
if not return_dict:
13781393
return (dec,)
13791394
return DecoderOutput(sample=dec)

0 commit comments

Comments (0)