@@ -101,8 +101,6 @@ def _gather(input_, dim):
     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()

-    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
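For reference, the transpose/slice/transpose idiom that `_gather` uses to peel the first frame off an arbitrary dim can be checked on a single process. A minimal sketch (tensor shapes are illustrative, not taken from this commit):

import torch

x = torch.randn(1, 4, 9, 8, 8)   # (B, C, T, H, W) video tensor; dim=2 is time
dim = 2

# the first frame is handled separately; the remaining frames are what the collective moves
first = x.transpose(0, dim)[:1].transpose(0, dim).contiguous()
rest = x.transpose(0, dim)[1:].transpose(0, dim).contiguous()

print(first.shape, rest.shape)   # torch.Size([1, 4, 1, 8, 8]) torch.Size([1, 4, 8, 8, 8])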
@@ -127,27 +125,21 @@ def _gather(input_, dim):
 def _conv_split(input_, dim, kernel_size):
     cp_world_size = get_context_parallel_world_size()

-    # Bypass the function if context parallel is 1
     if cp_world_size == 1:
         return input_

-    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     cp_rank = get_context_parallel_rank()

     dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

     if cp_rank == 0:
         output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
     else:
-        # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
         output = input_.transpose(dim, 0)[
             cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
         ].transpose(dim, 0)
     output = output.contiguous()

-    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
-
     return output


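The split arithmetic above is easier to see on a toy tensor: rank 0 keeps the first dim_size + kernel_size frames and every later rank keeps the next dim_size frames, so the per-rank slices tile the sequence with no overlap. A single-process sketch (the fake_conv_split helper and sizes are illustrative, not from the commit):

import torch

def fake_conv_split(x, dim, kernel_size, cp_rank, cp_world_size):
    # same indexing as _conv_split, minus the distributed plumbing
    dim_size = (x.size(dim) - kernel_size) // cp_world_size
    x = x.transpose(dim, 0)
    if cp_rank == 0:
        out = x[: dim_size + kernel_size]
    else:
        out = x[cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size]
    return out.transpose(dim, 0).contiguous()

x = torch.arange(10.0).view(1, 1, 10, 1, 1)   # 10 frames, kernel_size=2, 4 "ranks" -> dim_size=2
parts = [fake_conv_split(x, dim=2, kernel_size=2, cp_rank=r, cp_world_size=4) for r in range(4)]
print([p.shape[2] for p in parts])              # [4, 2, 2, 2]: frames 0-3, 4-5, 6-7, 8-9
print(torch.equal(torch.cat(parts, dim=2), x))  # True: the slices tile the original exactly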
@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):

     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()
-
-    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
     if recv_rank % cp_world_size == cp_world_size - 1:
         recv_rank += cp_world_size

-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
-    # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
-    # req_recv.wait()
     recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
     if cp_rank < cp_world_size - 1:
         req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
     if cp_rank > 0:
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+

     if cp_rank == 0:
         if cache_padding is not None:
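The halo exchange kept above is guarded point-to-point: every rank except the last sends its trailing kernel_size - 1 frames to the next rank, and every rank except rank 0 receives them from the previous rank. A standalone sketch of that pattern on two CPU processes with the gloo backend (the worker function, tensor sizes, and port are assumptions, not part of the commit):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size, kernel_size=2):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    local = torch.full((3,), float(rank))     # each rank holds 3 "frames"
    halo = torch.empty(kernel_size - 1)       # buffer for the previous rank's tail

    req_send = req_recv = None
    if rank < world_size - 1:
        req_send = dist.isend(local[-(kernel_size - 1):].contiguous(), rank + 1)
    if rank > 0:
        req_recv = dist.irecv(halo, rank - 1)
    if req_recv is not None:
        req_recv.wait()                       # halo must be complete before it is read
    if req_send is not None:
        req_send.wait()

    if rank > 0:
        print(f"rank {rank} received halo from rank {rank - 1}: {halo.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

dist.irecv returns a work handle, so the receive has to be waited on before the halo buffer feeds the causal convolution.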
@@ -421,7 +405,6 @@ def forward(self, input_):


 def Normalize(in_channels, gather=False, **kwargs):
-    # same for 3D and 2D
     if gather:
         return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
     else:
@@ -468,8 +451,8 @@ def __init__(
             kernel_size=1,
         )

-    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
-        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
+    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
+        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
             zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
@@ -531,10 +514,11 @@ def __init__(
         self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time

-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
-                # print(x.shape)
+            if get_context_parallel_rank() == 0 and fake_cp:
+                print(x.shape)
+                breakpoint()
                 # split first frame
                 x_first, x_rest = x[:, :, 0], x[:, :, 1:]

@@ -545,8 +529,6 @@ def forward(self, x, fake_cp_rank0=True):
                     torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
                 ]
                 x_rest = torch.cat(interpolated_splits, dim=1)
-
-                # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
             else:
                 splits = torch.split(x, 32, dim=1)
@@ -555,13 +537,10 @@ def forward(self, x, fake_cp_rank0=True):
                 ]
                 x = torch.cat(interpolated_splits, dim=1)

-            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-
         else:
             # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

             splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
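The 32-channel torch.split used in this forward bounds the size of each interpolate call; for nearest-neighbor upsampling the chunked result matches a single full-tensor call exactly. A standalone check (the shape is illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 128, 16, 16)

splits = torch.split(x, 32, dim=1)
chunked = torch.cat(
    [F.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits], dim=1
)
full = F.interpolate(x, scale_factor=2.0, mode="nearest")

print(torch.equal(chunked, full))   # True: chunking only caps peak activation memory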
@@ -590,12 +569,12 @@ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
             self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time

-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
             x = rearrange(x, "b c t h w -> (b h w) c t")

-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+            if get_context_parallel_rank() == 0 and fake_cp:
                 # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]

@@ -693,32 +672,24 @@ def __init__(
                     padding=0,
                 )

-    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
         h = x

-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm1(h)
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

         h = nonlinearity(h)
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)

         if temb is not None:
             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm2(h)
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

         h = nonlinearity(h)
         h = self.dropout(h)
@@ -827,32 +798,33 @@ def __init__(
             kernel_size=3,
         )

-    def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
+
         # timestep embedding
         temb = None

         # downsampling
-        h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
+        hs = [self.conv_in(x)]
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+                h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    print("Attention not implemented")
                     h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
+                hs.append(self.down[i_level].downsample(hs[-1]))

         # middle
-        h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
-        h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.block_2(h, temb)

         # end
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         h = self.norm_out(h)
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
         h = nonlinearity(h)
-        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
+        h = self.conv_out(h)

         return h

@@ -895,11 +867,9 @@ def __init__(
             zq_ch = z_channels

         # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,) + tuple(ch_mult)
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

         self.conv_in = ContextParallelCausalConv3d(
             chan_in=z_channels,
@@ -955,11 +925,6 @@ def __init__(
             up.block = block
             up.attn = attn
             if i_level != 0:
-                # # Symmetrical enc-dec
-                if i_level <= self.temporal_compress_level:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
-                else:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 if i_level < self.num_resolutions - self.temporal_compress_level:
                     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 else:
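The retained condition gives compress_time=True to the highest temporal_compress_level resolution levels (i_level >= num_resolutions - temporal_compress_level), while the deleted branch selected the lowest levels instead. A quick enumeration under assumed sizes (num_resolutions=4 and temporal_compress_level=2 are illustrative):

num_resolutions, temporal_compress_level = 4, 2
for i_level in reversed(range(num_resolutions)):
    if i_level != 0:  # level 0 never gets an upsample block
        compress_time = not (i_level < num_resolutions - temporal_compress_level)
        print(f"level {i_level}: compress_time={compress_time}")
# level 3: compress_time=True
# level 2: compress_time=True
# level 1: compress_time=False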
@@ -974,7 +939,9 @@ def __init__(
             kernel_size=3,
         )

-    def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
         self.last_z_shape = z.shape

         # timestep embedding
@@ -987,25 +954,25 @@ def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

         # middle
-        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
-        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
+        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](
-                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
                 )
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
             if i_level != 0:
-                h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
+                h = self.up[i_level].upsample(h, fake_cp=use_cp)

         # end
         if self.give_pre_end:
             return h

-        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
         h = nonlinearity(h)
         h = self.conv_out(h, clear_cache=clear_fake_cp_cache)

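Both new forward signatures copy the flag into a module-level _USE_CP global, presumably so helpers that never see the forward arguments can read it. A self-contained toy of that pattern and its effect on shapes (ToyDecoder and maybe_split are illustrative names, not part of the commit):

import torch

_USE_CP = True

def maybe_split(x, world_size=2):
    # stand-in for the split/gather helpers: only split frames when CP is enabled
    if not _USE_CP or world_size == 1:
        return x
    return x.chunk(world_size, dim=2)[0]

class ToyDecoder(torch.nn.Module):
    def forward(self, z, use_cp=True):
        global _USE_CP
        _USE_CP = use_cp          # same pattern as the Encoder/Decoder forwards above
        return maybe_split(z)

dec = ToyDecoder()
print(dec(torch.randn(1, 4, 8, 4, 4), use_cp=False).shape)  # torch.Size([1, 4, 8, 4, 4])
print(dec(torch.randn(1, 4, 8, 4, 4), use_cp=True).shape)   # torch.Size([1, 4, 4, 4, 4])

Because the flag is process-global, it has to be rewritten at the top of every forward call and is not safe under concurrent forwards within one process.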