1111# limitations under the License.
1212# ==============================================================================
1313
14- from dataclasses import dataclass
15- from typing import Tuple , Optional
1614import math
17- import random
15+ from dataclasses import dataclass
16+ from typing import Optional , Tuple # noqa: UP035
17+
1818import numpy as np
19- from einops import rearrange
2019import torch
21- from torch import Tensor , nn
2220import torch .nn .functional as F
23-
2421from diffusers .configuration_utils import ConfigMixin , register_to_config
2522from diffusers .models .modeling_outputs import AutoencoderKLOutput
2623from diffusers .models .modeling_utils import ModelMixin
27- from diffusers .utils .torch_utils import randn_tensor
2824from diffusers .utils import BaseOutput
25+ from diffusers .utils .torch_utils import randn_tensor
26+ from einops import rearrange
27+ from torch import Tensor , nn
28+
2929
3030class DiagonalGaussianDistribution (object ):
3131 def __init__ (self , parameters : torch .Tensor , deterministic : bool = False ):
@@ -57,6 +57,7 @@ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTens
5757 x = self .mean + self .std * sample
5858 return x
5959
60+
6061@dataclass
6162class DecoderOutput (BaseOutput ):
6263 sample : torch .FloatTensor
@@ -71,6 +72,7 @@ def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
7172 def create_custom_forward (module ):
7273 def custom_forward (* inputs ):
7374 return module (* inputs )
75+
7476 return custom_forward
7577
7678 if use_checkpointing :
@@ -81,7 +83,7 @@ def custom_forward(*inputs):
8183
8284class Conv3d (nn .Conv3d ):
8385 """
84- Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5.
86+ Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5.
8587 Only symmetric padding is supported.
8688 """
8789
@@ -102,9 +104,9 @@ def forward(self, input):
102104 value = 0 ,
103105 )
104106 if i > 0 :
105- padded_chunk [:, :, :self .padding [0 ]] = chunks [i - 1 ][:, :, - self .padding [0 ]:]
107+ padded_chunk [:, :, : self .padding [0 ]] = chunks [i - 1 ][:, :, - self .padding [0 ] :]
106108 if i < len (chunks ) - 1 :
107- padded_chunk [:, :, - self .padding [0 ]:] = chunks [i + 1 ][:, :, :self .padding [0 ]]
109+ padded_chunk [:, :, - self .padding [0 ] :] = chunks [i + 1 ][:, :, : self .padding [0 ]]
108110 else :
109111 padded_chunk = chunks [i ]
110112 padded_chunks .append (padded_chunk )
@@ -120,7 +122,8 @@ def forward(self, input):
120122
121123
122124class AttnBlock (nn .Module ):
123- """ Attention with torch sdpa implementation. """
125+ """Attention with torch sdpa implementation."""
126+
124127 def __init__ (self , in_channels : int ):
125128 super ().__init__ ()
126129 self .in_channels = in_channels
@@ -178,6 +181,7 @@ def forward(self, x):
178181 x = self .nin_shortcut (x )
179182 return x + h
180183
184+
181185class DownsampleDCAE (nn .Module ):
182186 def __init__ (self , in_channels : int , out_channels : int , add_temporal_downsample : bool = True ):
183187 super ().__init__ ()
@@ -198,6 +202,7 @@ def forward(self, x: Tensor):
198202 shortcut = shortcut .view (B , h .shape [1 ], self .group_size , T , H , W ).mean (dim = 2 )
199203 return h + shortcut
200204
205+
201206class UpsampleDCAE (nn .Module ):
202207 def __init__ (self , in_channels : int , out_channels : int , add_temporal_upsample : bool = True ):
203208 super ().__init__ ()
@@ -215,10 +220,12 @@ def forward(self, x: Tensor):
215220 shortcut = rearrange (shortcut , "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)" , r1 = r1 , r2 = 2 , r3 = 2 )
216221 return h + shortcut
217222
223+
218224class Encoder (nn .Module ):
219225 """
220226 The encoder network of AutoencoderKLConv3D.
221227 """
228+
222229 def __init__ (
223230 self ,
224231 in_channels : int ,
@@ -251,8 +258,9 @@ def __init__(
251258 down .block = block
252259
253260 add_spatial_downsample = bool (i_level < np .log2 (ffactor_spatial ))
254- add_temporal_downsample = (add_spatial_downsample and
255- bool (i_level >= np .log2 (ffactor_spatial // ffactor_temporal )))
261+ add_temporal_downsample = add_spatial_downsample and bool (
262+ i_level >= np .log2 (ffactor_spatial // ffactor_temporal )
263+ )
256264 if add_spatial_downsample or add_temporal_downsample :
257265 assert i_level < len (block_out_channels ) - 1
258266 block_out = block_out_channels [i_level + 1 ] if downsample_match_channel else block_in
@@ -280,7 +288,8 @@ def forward(self, x: Tensor) -> Tensor:
280288 for i_level in range (len (self .block_out_channels )):
281289 for i_block in range (self .num_res_blocks ):
282290 h = forward_with_checkpointing (
283- self .down [i_level ].block [i_block ], h , use_checkpointing = use_checkpointing )
291+ self .down [i_level ].block [i_block ], h , use_checkpointing = use_checkpointing
292+ )
284293 if hasattr (self .down [i_level ], "downsample" ):
285294 h = forward_with_checkpointing (self .down [i_level ].downsample , h , use_checkpointing = use_checkpointing )
286295
@@ -298,10 +307,12 @@ def forward(self, x: Tensor) -> Tensor:
298307 h += shortcut
299308 return h
300309
310+
301311class Decoder (nn .Module ):
302312 """
303313 The decoder network of AutoencoderKLConv3D.
304314 """
315+
305316 def __init__ (
306317 self ,
307318 z_channels : int ,
@@ -380,10 +391,12 @@ def forward(self, z: Tensor) -> Tensor:
380391 h = self .conv_out (h )
381392 return h
382393
394+
383395class AutoencoderKLConv3D (ModelMixin , ConfigMixin ):
384396 """
385397 Autoencoder model with KL-regularized latent space based on 3D convolutions.
386398 """
399+
387400 _supports_gradient_checkpointing = True
388401
389402 @register_to_config
@@ -402,8 +415,8 @@ def __init__(
402415 shift_factor : Optional [float ] = None ,
403416 downsample_match_channel : bool = True ,
404417 upsample_match_channel : bool = True ,
405- only_encoder : bool = False , # only build encoder for saving memory
406- only_decoder : bool = False , # only build decoder for saving memory
418+ only_encoder : bool = False , # only build encoder for saving memory
419+ only_decoder : bool = False , # only build decoder for saving memory
407420 ):
408421 super ().__init__ ()
409422 self .ffactor_spatial = ffactor_spatial
@@ -449,27 +462,29 @@ def __init__(
449462
450463 # use torch.compile for faster encode speed
451464 self .use_compile = False
452-
465+
453466 def _set_gradient_checkpointing (self , module , value = False ):
454467 if isinstance (module , (Encoder , Decoder )):
455468 module .gradient_checkpointing = value
456-
469+
457470 def blend_h (self , a : torch .Tensor , b : torch .Tensor , blend_extent : int ):
458471 blend_extent = min (a .shape [- 1 ], b .shape [- 1 ], blend_extent )
459472 for x in range (blend_extent ):
460- b [:, :, :, :, x ] = \
461- a [:, :, :, :, - blend_extent + x ] * (1 - x / blend_extent ) + b [:, :, :, :, x ] * (x / blend_extent )
473+ b [:, :, :, :, x ] = a [:, :, :, :, - blend_extent + x ] * (1 - x / blend_extent ) + b [:, :, :, :, x ] * (
474+ x / blend_extent
475+ )
462476 return b
463477
464478 def blend_v (self , a : torch .Tensor , b : torch .Tensor , blend_extent : int ):
465479 blend_extent = min (a .shape [- 2 ], b .shape [- 2 ], blend_extent )
466480 for y in range (blend_extent ):
467- b [:, :, :, y , :] = \
468- a [:, :, :, - blend_extent + y , :] * (1 - y / blend_extent ) + b [:, :, :, y , :] * (y / blend_extent )
481+ b [:, :, :, y , :] = a [:, :, :, - blend_extent + y , :] * (1 - y / blend_extent ) + b [:, :, :, y , :] * (
482+ y / blend_extent
483+ )
469484 return b
470485
471486 def spatial_tiled_decode (self , z : torch .Tensor ):
472- """ spatial tailing for frames """
487+ """spatial tailing for frames"""
473488 B , C , T , H , W = z .shape
474489 overlap_size = int (self .tile_latent_min_size * (1 - self .tile_overlap_factor )) # 8 * (1 - 0.25) = 6
475490 blend_extent = int (self .tile_sample_min_size * self .tile_overlap_factor ) # 256 * 0.25 = 64
@@ -479,7 +494,7 @@ def spatial_tiled_decode(self, z: torch.Tensor):
479494 for i in range (0 , H , overlap_size ):
480495 row = []
481496 for j in range (0 , W , overlap_size ):
482- tile = z [:, :, :, i : i + self .tile_latent_min_size , j : j + self .tile_latent_min_size ]
497+ tile = z [:, :, :, i : i + self .tile_latent_min_size , j : j + self .tile_latent_min_size ]
483498 decoded = self .decoder (tile )
484499 row .append (decoded )
485500 rows .append (row )
@@ -498,7 +513,7 @@ def spatial_tiled_decode(self, z: torch.Tensor):
498513 return dec
499514
500515 def temporal_tiled_decode (self , z : torch .Tensor ):
501- """ temporal tailing for frames """
516+ """temporal tailing for frames"""
502517 B , C , T , H , W = z .shape
503518 overlap_size = int (self .tile_latent_min_tsize * (1 - self .tile_overlap_factor )) # 8 * (1 - 0.25) = 6
504519 blend_extent = int (self .tile_sample_min_tsize * self .tile_overlap_factor ) # 64 * 0.25 = 16
@@ -507,9 +522,10 @@ def temporal_tiled_decode(self, z: torch.Tensor):
507522
508523 row = []
509524 for i in range (0 , T , overlap_size ):
510- tile = z [:, :, i : i + self .tile_latent_min_tsize , :, :]
525+ tile = z [:, :, i : i + self .tile_latent_min_tsize , :, :]
511526 if self .use_spatial_tiling and (
512- tile .shape [- 1 ] > self .tile_latent_min_size or tile .shape [- 2 ] > self .tile_latent_min_size ):
527+ tile .shape [- 1 ] > self .tile_latent_min_size or tile .shape [- 2 ] > self .tile_latent_min_size
528+ ):
513529 decoded = self .spatial_tiled_decode (tile )
514530 else :
515531 decoded = self .decoder (tile )
@@ -522,23 +538,27 @@ def temporal_tiled_decode(self, z: torch.Tensor):
522538 result_row .append (tile [:, :, :t_limit , :, :])
523539 dec = torch .cat (result_row , dim = - 3 )
524540 return dec
525-
541+
526542 def encode (self , x : Tensor , return_dict : bool = True ):
527543 """
528544 Encodes the input by passing through the encoder network.
529545 Support slicing and tiling for memory efficiency.
530546 """
547+
531548 def _encode (x ):
532549 if self .use_temporal_tiling and x .shape [- 3 ] > self .tile_sample_min_tsize :
533550 return self .temporal_tiled_encode (x )
534551 if self .use_spatial_tiling and (
535- x .shape [- 1 ] > self .tile_sample_min_size or x .shape [- 2 ] > self .tile_sample_min_size ):
552+ x .shape [- 1 ] > self .tile_sample_min_size or x .shape [- 2 ] > self .tile_sample_min_size
553+ ):
536554 return self .spatial_tiled_encode (x )
537555
538556 if self .use_compile :
557+
539558 @torch .compile
540559 def encoder (x ):
541560 return self .encoder (x )
561+
542562 return encoder (x )
543563 return self .encoder (x )
544564
@@ -567,17 +587,19 @@ def encoder(x):
567587 return (posterior ,)
568588
569589 return AutoencoderKLOutput (latent_dist = posterior )
570-
590+
571591 def decode (self , z : Tensor , return_dict : bool = True , generator = None ):
572592 """
573593 Decodes the input by passing through the decoder network.
574594 Support slicing and tiling for memory efficiency.
575595 """
596+
576597 def _decode (z ):
577598 if self .use_temporal_tiling and z .shape [- 3 ] > self .tile_latent_min_tsize :
578599 return self .temporal_tiled_decode (z )
579600 if self .use_spatial_tiling and (
580- z .shape [- 1 ] > self .tile_latent_min_size or z .shape [- 2 ] > self .tile_latent_min_size ):
601+ z .shape [- 1 ] > self .tile_latent_min_size or z .shape [- 2 ] > self .tile_latent_min_size
602+ ):
581603 return self .spatial_tiled_decode (z )
582604 return self .decoder (z )
583605
0 commit comments