 from .vae import DecoderOutput, DiagonalGaussianDistribution


-class LTXCausalConv3d(nn.Module):
+class LTXVideoCausalConv3d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class LTXResnetBlock3d(nn.Module):
+class LTXVideoResnetBlock3d(nn.Module):
     r"""
-    A 3D ResNet block used in the LTX model.
+    A 3D ResNet block used in the LTXVideo model.

     Args:
         in_channels (`int`):
@@ -117,21 +117,21 @@ def __init__(
         self.nonlinearity = get_activation(non_linearity)

         self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
-        self.conv1 = LTXCausalConv3d(
+        self.conv1 = LTXVideoCausalConv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
         )

         self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
         self.dropout = nn.Dropout(dropout)
-        self.conv2 = LTXCausalConv3d(
+        self.conv2 = LTXVideoCausalConv3d(
             in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
         )

         self.norm3 = None
         self.conv_shortcut = None
         if in_channels != out_channels:
             self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
-            self.conv_shortcut = LTXCausalConv3d(
+            self.conv_shortcut = LTXVideoCausalConv3d(
                 in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
             )

@@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class LTXUpsampler3d(nn.Module):
+class LTXVideoUpsampler3d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -170,7 +170,7 @@ def __init__(

         out_channels = in_channels * stride[0] * stride[1] * stride[2]

-        self.conv = LTXCausalConv3d(
+        self.conv = LTXVideoCausalConv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=3,
@@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class LTXDownBlock3D(nn.Module):
+class LTXVideoDownBlock3D(nn.Module):
     r"""
-    Down block used in the LTX model.
+    Down block used in the LTXVideo model.

     Args:
         in_channels (`int`):
@@ -235,7 +235,7 @@ def __init__(
         resnets = []
         for _ in range(num_layers):
             resnets.append(
-                LTXResnetBlock3d(
+                LTXVideoResnetBlock3d(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     dropout=dropout,
@@ -250,7 +250,7 @@ def __init__(
         if spatio_temporal_scale:
             self.downsamplers = nn.ModuleList(
                 [
-                    LTXCausalConv3d(
+                    LTXVideoCausalConv3d(
                         in_channels=in_channels,
                         out_channels=in_channels,
                         kernel_size=3,
@@ -262,7 +262,7 @@ def __init__(

         self.conv_out = None
         if in_channels != out_channels:
-            self.conv_out = LTXResnetBlock3d(
+            self.conv_out = LTXVideoResnetBlock3d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 dropout=dropout,
@@ -300,9 +300,9 @@ def create_forward(*inputs):


 # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
-class LTXMidBlock3d(nn.Module):
+class LTXVideoMidBlock3d(nn.Module):
     r"""
-    A middle block used in the LTX model.
+    A middle block used in the LTXVideo model.

     Args:
         in_channels (`int`):
@@ -335,7 +335,7 @@ def __init__(
         resnets = []
         for _ in range(num_layers):
             resnets.append(
-                LTXResnetBlock3d(
+                LTXVideoResnetBlock3d(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     dropout=dropout,
@@ -367,9 +367,9 @@ def create_forward(*inputs):
         return hidden_states


-class LTXUpBlock3d(nn.Module):
+class LTXVideoUpBlock3d(nn.Module):
     r"""
-    Up block used in the LTX model.
+    Up block used in the LTXVideo model.

     Args:
         in_channels (`int`):
@@ -410,7 +410,7 @@ def __init__(

         self.conv_in = None
         if in_channels != out_channels:
-            self.conv_in = LTXResnetBlock3d(
+            self.conv_in = LTXVideoResnetBlock3d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 dropout=dropout,
@@ -421,12 +421,12 @@ def __init__(

         self.upsamplers = None
         if spatio_temporal_scale:
-            self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
+            self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])

         resnets = []
         for _ in range(num_layers):
             resnets.append(
-                LTXResnetBlock3d(
+                LTXVideoResnetBlock3d(
                     in_channels=out_channels,
                     out_channels=out_channels,
                     dropout=dropout,
@@ -463,9 +463,9 @@ def create_forward(*inputs):
         return hidden_states


-class LTXEncoder3d(nn.Module):
+class LTXVideoEncoder3d(nn.Module):
     r"""
-    The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
+    The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
     representation.

     Args:
@@ -509,7 +509,7 @@ def __init__(

         output_channel = block_out_channels[0]

-        self.conv_in = LTXCausalConv3d(
+        self.conv_in = LTXVideoCausalConv3d(
             in_channels=self.in_channels,
             out_channels=output_channel,
             kernel_size=3,
@@ -524,7 +524,7 @@ def __init__(
             input_channel = output_channel
             output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]

-            down_block = LTXDownBlock3D(
+            down_block = LTXVideoDownBlock3D(
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=layers_per_block[i],
@@ -536,7 +536,7 @@ def __init__(
             self.down_blocks.append(down_block)

         # mid block
-        self.mid_block = LTXMidBlock3d(
+        self.mid_block = LTXVideoMidBlock3d(
             in_channels=output_channel,
             num_layers=layers_per_block[-1],
             resnet_eps=resnet_norm_eps,
@@ -546,14 +546,14 @@ def __init__(
         # out
         self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
         self.conv_act = nn.SiLU()
-        self.conv_out = LTXCausalConv3d(
+        self.conv_out = LTXVideoCausalConv3d(
             in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
         )

         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        r"""The forward method of the `LTXEncoder3D` class."""
+        r"""The forward method of the `LTXVideoEncoder3d` class."""

         p = self.patch_size
         p_t = self.patch_size_t
@@ -599,9 +599,10 @@ def create_forward(*inputs):
         return hidden_states


-class LTXDecoder3d(nn.Module):
+class LTXVideoDecoder3d(nn.Module):
     r"""
-    The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
+    The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+    sample.

     Args:
         in_channels (`int`, defaults to 128):
@@ -647,11 +648,11 @@ def __init__(
         layers_per_block = tuple(reversed(layers_per_block))
         output_channel = block_out_channels[0]

-        self.conv_in = LTXCausalConv3d(
+        self.conv_in = LTXVideoCausalConv3d(
             in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
         )

-        self.mid_block = LTXMidBlock3d(
+        self.mid_block = LTXVideoMidBlock3d(
             in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
         )

@@ -662,7 +663,7 @@ def __init__(
             input_channel = output_channel
             output_channel = block_out_channels[i]

-            up_block = LTXUpBlock3d(
+            up_block = LTXVideoUpBlock3d(
                 in_channels=input_channel,
                 out_channels=output_channel,
                 num_layers=layers_per_block[i + 1],
@@ -676,7 +677,7 @@ def __init__(
         # out
         self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
         self.conv_act = nn.SiLU()
-        self.conv_out = LTXCausalConv3d(
+        self.conv_out = LTXVideoCausalConv3d(
             in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
         )

@@ -777,7 +778,7 @@ def __init__(
     ) -> None:
         super().__init__()

-        self.encoder = LTXEncoder3d(
+        self.encoder = LTXVideoEncoder3d(
             in_channels=in_channels,
             out_channels=latent_channels,
             block_out_channels=block_out_channels,
@@ -788,7 +789,7 @@ def __init__(
             resnet_norm_eps=resnet_norm_eps,
             is_causal=encoder_causal,
         )
-        self.decoder = LTXDecoder3d(
+        self.decoder = LTXVideoDecoder3d(
             in_channels=latent_channels,
             out_channels=out_channels,
             block_out_channels=block_out_channels,
@@ -837,7 +838,7 @@ def __init__(
         self.tile_sample_stride_width = 448

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
+        if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
             module.gradient_checkpointing = value

     def enable_tiling(
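
Note on the renamed blocks: the sketch below is a minimal, self-contained illustration of the causal 3D convolution idea behind `LTXVideoCausalConv3d`, where all temporal padding goes on the past side so an output frame never depends on future frames. The `ToyCausalConv3d` name, the replicate padding mode, and the padding amounts are illustrative assumptions, not the diffusers implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCausalConv3d(nn.Module):
    """Illustrative causal 3D conv (an assumption-laden sketch, not the library code):
    temporal padding is applied only on the "past" side, so frame t never sees frames > t."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        self.time_pad = kernel_size - 1    # all temporal padding in front (the past)
        self.space_pad = kernel_size // 2  # symmetric spatial padding
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, channels, frames, height, width)
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_front, T_back)
        hidden_states = F.pad(
            hidden_states,
            (self.space_pad, self.space_pad, self.space_pad, self.space_pad, self.time_pad, 0),
            mode="replicate",
        )
        return self.conv(hidden_states)


x = torch.randn(1, 8, 9, 32, 32)        # (B, C, T, H, W)
print(ToyCausalConv3d(8, 16)(x).shape)  # torch.Size([1, 16, 9, 32, 32])
```

Because frame t only looks backward, the first frame can be processed on its own, which is one reason causal video VAEs can also treat a single image as a one-frame video.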
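The `LTXVideoUpsampler3d` hunk above expands channels by `stride[0] * stride[1] * stride[2]` before its convolution; the forward pass is not shown in this excerpt, but that expansion factor is exactly what a 3D depth-to-space rearrangement consumes. The helper below is a hypothetical illustration of such a rearrangement, not the model's code.

```python
import torch


def toy_depth_to_space_3d(x: torch.Tensor, stride=(2, 2, 2)) -> torch.Tensor:
    """Illustrative 3D depth-to-space: fold a factor of st * sh * sw from the
    channel axis back into the time/height/width axes."""
    b, c, t, h, w = x.shape
    st, sh, sw = stride
    c_out = c // (st * sh * sw)
    x = x.view(b, c_out, st, sh, sw, t, h, w)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)  # (b, c_out, t, st, h, sh, w, sw)
    return x.reshape(b, c_out, t * st, h * sh, w * sw)


y = toy_depth_to_space_3d(torch.randn(1, 64, 4, 8, 8))
print(y.shape)  # torch.Size([1, 8, 8, 16, 16])
```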