@@ -1337,9 +1337,18 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
-
-        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
-        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+        tile_sample_stride_height = self.tile_sample_stride_height
+        tile_sample_stride_width = self.tile_sample_stride_width
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = self.tile_sample_min_height - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
@@ -1353,7 +1362,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
                     self._conv_idx = [0]
                     tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                     tile = self.post_quant_conv(tile)
-                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    decoded = self.decoder(
+                        tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                    )
                     time.append(decoded)
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
@@ -1369,11 +1380,15 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
                     tile = self.blend_h(row[j - 1], tile, blend_width)
-                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=-1))
-
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
         if not return_dict:
             return (dec,)
         return DecoderOutput(sample=dec)
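
For reference, here is a minimal, self-contained sketch of the blend-size arithmetic the new branch performs: when `patch_size` is set, the decoded tiles are still in patchified space, so the sample-space stride and minimum tile sizes are divided by `patch_size` before the overlap (blend) widths are computed. The function name and the concrete numbers below are illustrative only, not taken from the model config.

```python
def blend_sizes(min_height, min_width, stride_height, stride_width, patch_size=None):
    # Mirrors the patched branch: with patchification, every sample-space
    # quantity shrinks by patch_size before the overlap is measured.
    if patch_size is not None:
        stride_height //= patch_size
        stride_width //= patch_size
        blend_height = min_height // patch_size - stride_height
        blend_width = min_width // patch_size - stride_width
    else:
        blend_height = min_height - stride_height
        blend_width = min_width - stride_width
    return blend_height, blend_width


# Hypothetical tile settings: 256-px tiles with a 192-px stride.
print(blend_sizes(256, 256, 192, 192))                # (64, 64) -> 64 blended pixels per seam
print(blend_sizes(256, 256, 192, 192, patch_size=2))  # (32, 32) -> half as many in patchified space
```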
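The decode path now ends with an `unpatchify` call before clamping. That helper is not shown in this hunk; the sketch below illustrates one common pixel-shuffle-style interpretation of it (an assumption about its semantics, not the actual code in this file): the `patch_size * patch_size` factor is folded out of the channel axis and back into the spatial dimensions.

```python
import torch


def unpatchify_sketch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # x: (batch, channels * p * p, frames, height, width) in patchified space.
    b, cpp, f, h, w = x.shape
    p = patch_size
    c = cpp // (p * p)
    x = x.reshape(b, c, p, p, f, h, w)
    # Interleave the per-patch offsets back into the spatial dimensions.
    x = x.permute(0, 1, 4, 5, 2, 6, 3)  # (b, c, f, h, p, w, p)
    return x.reshape(b, c, f, h * p, w * p)


# Toy patchified decoder output with patch_size=2: 12 channels fold back to 3.
dec = torch.randn(1, 12, 2, 64, 64)
print(unpatchify_sketch(dec, patch_size=2).shape)  # torch.Size([1, 3, 2, 128, 128])
```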