
Commit e7bcecf

remove wrong fake_cp

1 parent d8ee013

File tree

2 files changed: +34, -67 lines

sat/diffusion_video.py

Lines changed: 2 additions & 2 deletions
@@ -192,13 +192,13 @@ def decode_first_stage(self, z):
         for i in range(fake_cp_size):
             end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
 
-            fake_cp_rank0 = True if i == 0 else False
+            use_cp = True if i == 0 else False
             clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
             with torch.no_grad():
                 recon = self.first_stage_model.decode(
                     z_now[:, :, start_frame:end_frame].contiguous(),
                     clear_fake_cp_cache=clear_fake_cp_cache,
-                    fake_cp_rank0=fake_cp_rank0,
+                    use_cp=use_cp,
                 )
             recons.append(recon)
             start_frame = end_frame
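
Note: the loop above divides the latent_time frames into fake_cp_size contiguous chunks, giving each of the first latent_time % fake_cp_size chunks one extra frame. A minimal standalone sketch of that arithmetic with hypothetical sizes (not code from the repository):

# Sketch only: reproduces the chunk-boundary arithmetic of the loop above.
latent_time = 13    # hypothetical number of latent frames
fake_cp_size = 4    # hypothetical number of decode chunks

start_frame = 0
for i in range(fake_cp_size):
    end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
    print(i, (start_frame, end_frame))  # -> (0, 4), (4, 7), (7, 10), (10, 13)
    start_frame = end_frame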

sat/vae_modules/cp_enc_dec.py

Lines changed: 32 additions & 65 deletions
@@ -101,8 +101,6 @@ def _gather(input_, dim):
     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()
 
-    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@@ -127,27 +125,21 @@ def _gather(input_, dim):
 def _conv_split(input_, dim, kernel_size):
     cp_world_size = get_context_parallel_world_size()
 
-    # Bypass the function if context parallel is 1
     if cp_world_size == 1:
         return input_
 
-    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     cp_rank = get_context_parallel_rank()
 
     dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
 
     if cp_rank == 0:
         output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
     else:
-        # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
         output = input_.transpose(dim, 0)[
             cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
         ].transpose(dim, 0)
     output = output.contiguous()
 
-    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
-
     return output
 
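
Note on the slicing above: rank 0 keeps the leading kernel_size elements in addition to its dim_size share, while every other rank takes dim_size elements starting at cp_rank * dim_size + kernel_size. A small standalone sketch of just that index arithmetic, with invented sizes (not code from the repository):

# Sketch only: per-rank slice bounds as computed by _conv_split above.
total = 18        # hypothetical length along the split dimension
kernel_size = 2   # hypothetical temporal kernel size
cp_world_size = 4

dim_size = (total - kernel_size) // cp_world_size  # 4
for cp_rank in range(cp_world_size):
    if cp_rank == 0:
        lo, hi = 0, dim_size + kernel_size
    else:
        lo = cp_rank * dim_size + kernel_size
        hi = (cp_rank + 1) * dim_size + kernel_size
    print(cp_rank, (lo, hi))  # -> (0, 6), (6, 10), (10, 14), (14, 18)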

@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
 
     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()
-
-    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
     if recv_rank % cp_world_size == cp_world_size - 1:
         recv_rank += cp_world_size
 
-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
-    # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
-    # req_recv.wait()
     recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
     if cp_rank < cp_world_size - 1:
         req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
     if cp_rank > 0:
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
 
     if cp_rank == 0:
         if cache_padding is not None:
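
The surviving branch above is a guarded neighbour exchange: every rank except the last posts a non-blocking send of its trailing kernel_size - 1 frames to the next rank, and every rank except rank 0 posts the matching non-blocking receive from the previous rank. A simplified, self-contained sketch of that pattern over the default process group (hypothetical, not the repository's function; rank bookkeeping and caching are omitted):

# Sketch only: guarded halo exchange in the style of the code above.
# Assumes torch.distributed is already initialised.
import torch
import torch.distributed as dist

def exchange_tail(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    rank, world = dist.get_rank(), dist.get_world_size()
    halo = torch.empty_like(x[-kernel_size + 1 :]).contiguous()

    reqs = []
    if rank < world - 1:   # everyone but the last rank sends its tail forward
        reqs.append(dist.isend(x[-kernel_size + 1 :].contiguous(), rank + 1))
    if rank > 0:           # everyone but rank 0 receives the previous rank's tail
        reqs.append(dist.irecv(halo, rank - 1))
    for req in reqs:
        req.wait()         # complete both requests before using the halo buffer
    return halo            # contents are undefined on rank 0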
@@ -421,7 +405,6 @@ def forward(self, input_):
 
 
 def Normalize(in_channels, gather=False, **kwargs):
-    # same for 3D and 2D
     if gather:
         return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
     else:
@@ -468,8 +451,8 @@ def __init__(
             kernel_size=1,
         )
 
-    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
-        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
+    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
+        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
             zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
@@ -531,10 +514,11 @@ def __init__(
         self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time
 
-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
-                # print(x.shape)
+            if get_context_parallel_rank() == 0 and fake_cp:
+                print(x.shape)
+                breakpoint()
                 # split first frame
                 x_first, x_rest = x[:, :, 0], x[:, :, 1:]
 
@@ -545,8 +529,6 @@ def forward(self, x, fake_cp_rank0=True):
                     torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
                 ]
                 x_rest = torch.cat(interpolated_splits, dim=1)
-
-                # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
             else:
                 splits = torch.split(x, 32, dim=1)
@@ -555,13 +537,10 @@ def forward(self, x, fake_cp_rank0=True):
                 ]
                 x = torch.cat(interpolated_splits, dim=1)
 
-                # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-
         else:
             # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
 
             splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
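
The context lines above interpolate the feature map in channel groups of 32 rather than in a single call, presumably to bound peak memory during nearest-neighbour upsampling. A standalone sketch of that pattern with made-up shapes (not code from the repository):

# Sketch only: channel-chunked nearest-neighbour upsampling.
import torch

x = torch.randn(2, 128, 64, 64)            # hypothetical (B*T, C, H, W) feature map
splits = torch.split(x, 32, dim=1)         # 32 channels at a time
interpolated_splits = [
    torch.nn.functional.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits
]
x = torch.cat(interpolated_splits, dim=1)  # -> (2, 128, 128, 128)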
@@ -590,12 +569,12 @@ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=Non
         self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time
 
-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
             x = rearrange(x, "b c t h w -> (b h w) c t")
 
-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+            if get_context_parallel_rank() == 0 and fake_cp:
                 # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]
 
@@ -693,32 +672,24 @@ def __init__(
                 padding=0,
             )
 
-    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
         h = x
 
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm1(h)
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
 
         h = nonlinearity(h)
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)
 
         if temb is not None:
             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
 
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm2(h)
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
 
         h = nonlinearity(h)
         h = self.dropout(h)
@@ -827,32 +798,33 @@ def __init__(
             kernel_size=3,
         )
 
-    def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
+
         # timestep embedding
         temb = None
 
         # downsampling
-        h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
+        hs = [self.conv_in(x)]
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+                h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    print("Attention not implemented")
                     h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
+                hs.append(self.down[i_level].downsample(hs[-1]))
 
         # middle
-        h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
-        h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.block_2(h, temb)
 
         # end
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         h = self.norm_out(h)
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
         h = nonlinearity(h)
-        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
+        h = self.conv_out(h)
 
         return h
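
Rather than threading the flag through every submodule, the new forward stores it in a module-level _USE_CP global that nested layers can consult. A hedged sketch of that pattern; the consumer function here is hypothetical, since this diff does not show how _USE_CP is read:

# Sketch only: module-level flag pattern used by the new forward above.
# _gate_cp is a hypothetical illustration, not a function from the repository.
_USE_CP = True  # module-level default, overwritten on each forward call

def _gate_cp(x, split_fn):
    # Apply the context-parallel split only when the flag is enabled.
    return split_fn(x) if _USE_CP else x

class EncoderSketch:
    def forward(self, x, use_cp=True):
        global _USE_CP
        _USE_CP = use_cp  # nested modules read the flag instead of taking a kwarg
        # ... rest of the forward pass ...
        return x

One caveat of a module-level flag is that concurrent forward passes with different settings would overwrite each other.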

@@ -895,11 +867,9 @@ def __init__(
             zq_ch = z_channels
 
         # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,) + tuple(ch_mult)
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
 
         self.conv_in = ContextParallelCausalConv3d(
             chan_in=z_channels,
@@ -955,11 +925,6 @@ def __init__(
             up.block = block
             up.attn = attn
             if i_level != 0:
-                # # Symmetrical enc-dec
-                if i_level <= self.temporal_compress_level:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
-                else:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 if i_level < self.num_resolutions - self.temporal_compress_level:
                     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 else:
@@ -974,7 +939,9 @@ def __init__(
             kernel_size=3,
         )
 
-    def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
         self.last_z_shape = z.shape
 
         # timestep embedding
@@ -987,25 +954,25 @@ def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
 
         # middle
-        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
-        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
+        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](
-                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
                 )
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
             if i_level != 0:
-                h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
+                h = self.up[i_level].upsample(h, fake_cp=use_cp)
 
         # end
         if self.give_pre_end:
             return h
 
-        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
         h = nonlinearity(h)
         h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
 