Skip to content

Commit 4a3035d

Browse files
update 1105 sst test code with fake cp
1 parent 3a9af5b commit 4a3035d

File tree

5 files changed

+137
-97
lines changed

5 files changed

+137
-97
lines changed

sat/diffusion_video.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,31 @@ def decode_first_stage(self, z):
179179
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
180180
n_rounds = math.ceil(z.shape[0] / n_samples)
181181
all_out = []
182-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
183-
for n in range(n_rounds):
184-
if isinstance(self.first_stage_model.decoder, VideoDecoder):
185-
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
186-
else:
187-
kwargs = {}
188-
frame = z.shape[2] * 4 - 3
189-
if frame <= 9:
190-
use_cp = False
191-
else:
192-
use_cp = True
193-
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
194-
all_out.append(out)
182+
for n in range(n_rounds):
183+
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
184+
latent_time = z_now.shape[2] # check the time latent
185+
temporal_compress_times = 4
186+
187+
fake_cp_size = min(10, latent_time // 2)
188+
start_frame = 0
189+
190+
recons = []
191+
start_frame = 0
192+
for i in range(fake_cp_size):
193+
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
194+
195+
fake_cp_rank0 = True if i == 0 else False
196+
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
197+
with torch.no_grad():
198+
recon = self.first_stage_model.decode(
199+
z_now[:, :, start_frame:end_frame].contiguous(),
200+
clear_fake_cp_cache=clear_fake_cp_cache,
201+
fake_cp_rank0=fake_cp_rank0,
202+
)
203+
recons.append(recon)
204+
start_frame = end_frame
205+
recons = torch.cat(recons, dim=2)
206+
all_out.append(recons)
195207
out = torch.cat(all_out, dim=0)
196208
return out
197209

sat/dit_video_concat.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,6 @@ def __init__(
654654
time_interpolation=1.0,
655655
use_SwiGLU=False,
656656
use_RMSNorm=False,
657-
cfg_embed_dim=None,
658657
ofs_embed_dim=None,
659658
**kwargs,
660659
):
@@ -669,7 +668,6 @@ def __init__(
669668
self.hidden_size = hidden_size
670669
self.model_channels = hidden_size
671670
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
672-
self.cfg_embed_dim = cfg_embed_dim
673671
self.ofs_embed_dim = ofs_embed_dim
674672
self.num_classes = num_classes
675673
self.adm_in_channels = adm_in_channels
@@ -728,13 +726,6 @@ def _build_modules(self, module_configs):
728726
linear(self.ofs_embed_dim, self.ofs_embed_dim),
729727
)
730728

731-
if self.cfg_embed_dim is not None:
732-
self.cfg_embed = nn.Sequential(
733-
linear(self.cfg_embed_dim, self.cfg_embed_dim),
734-
nn.SiLU(),
735-
linear(self.cfg_embed_dim, self.cfg_embed_dim),
736-
)
737-
738729
if self.num_classes is not None:
739730
if isinstance(self.num_classes, int):
740731
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -848,14 +839,6 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
848839
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
849840
ofs_emb = self.ofs_embed(ofs_emb)
850841
emb = emb + ofs_emb
851-
if self.cfg_embed_dim is not None:
852-
cfg_emb = kwargs["scale_emb"]
853-
cfg_emb = self.cfg_embed(cfg_emb)
854-
emb = emb + cfg_emb
855-
856-
if "ofs" in kwargs.keys():
857-
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
858-
ofs_emb = self.ofs_embed(ofs_emb)
859842

860843
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
861844
kwargs["images"] = x

sat/inference.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
44

55
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
66

7-
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
7+
run_cmd="$environs python sample_video.py --base configs/test_cogvideox_5b.yaml configs/test_inference.yaml --seed $RANDOM"
88

99
echo ${run_cmd}
1010
eval ${run_cmd}

sat/sample_video.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,14 @@ def sampling_main(args, model_cls):
135135
sample_func = model.sample
136136
num_samples = [1]
137137
force_uc_zero_embeddings = ["txt"]
138-
138+
T, C = args.sampling_num_frames, args.latent_channels
139139
with torch.no_grad():
140140
for text, cnt in tqdm(data_iter):
141141
if args.image2video:
142142
# use with input image shape
143-
text, image_path = text.split('@@')
143+
text, image_path = text.split("@@")
144144
assert os.path.exists(image_path), image_path
145-
image = Image.open(image_path).convert('RGB')
145+
image = Image.open(image_path).convert("RGB")
146146
(img_W, img_H) = image.size
147147

148148
def nearest_multiple_of_16(n):
@@ -163,7 +163,7 @@ def nearest_multiple_of_16(n):
163163
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
164164
chained_trainsforms.append(TT.ToTensor())
165165
transform = TT.Compose(chained_trainsforms)
166-
image = transform(image).unsqueeze(0).to('cuda')
166+
image = transform(image).unsqueeze(0).to("cuda")
167167
image = image * 2.0 - 1.0
168168
image = image.unsqueeze(2).to(torch.bfloat16)
169169
image = model.encode_first_stage(image, None)
@@ -173,7 +173,7 @@ def nearest_multiple_of_16(n):
173173
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
174174
else:
175175
image_size = args.sampling_image_size
176-
T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
176+
H, W = image_size[0], image_size[1]
177177
F = 8 # 8x downsampled
178178
image = None
179179

@@ -183,11 +183,7 @@ def nearest_multiple_of_16(n):
183183
src = global_rank * mp_size
184184
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
185185
text = text_cast[0]
186-
value_dict = {
187-
'prompt': text,
188-
'negative_prompt': '',
189-
'num_frames': torch.tensor(T).unsqueeze(0)
190-
}
186+
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
191187

192188
batch, batch_uc = get_batch(
193189
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@@ -216,19 +212,15 @@ def nearest_multiple_of_16(n):
216212
for index in range(args.batch_size):
217213
if args.image2video:
218214
samples_z = sample_func(
219-
c,
220-
uc=uc,
221-
batch_size=1,
222-
shape=(T, C, H, W),
223-
ofs=torch.tensor([2.0]).to('cuda')
215+
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
224216
)
225217
else:
226218
samples_z = sample_func(
227219
c,
228220
uc=uc,
229221
batch_size=1,
230222
shape=(T, C, H // F, W // F),
231-
).to('cuda')
223+
).to("cuda")
232224

233225
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
234226
if args.only_save_latents:
@@ -250,11 +242,12 @@ def nearest_multiple_of_16(n):
250242
if mpu.get_model_parallel_rank() == 0:
251243
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
252244

253-
if __name__ == '__main__':
254-
if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
255-
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
256-
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
257-
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
245+
246+
if __name__ == "__main__":
247+
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
248+
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
249+
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
250+
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
258251
py_parser = argparse.ArgumentParser(add_help=False)
259252
known, args_list = py_parser.parse_known_args()
260253

0 commit comments

Comments
 (0)