Skip to content

Commit 3a9af5b

Browse files
update with test code
1 parent b6abbea commit 3a9af5b

File tree

11 files changed

+589
-424
lines changed

11 files changed

+589
-424
lines changed

sat/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
3636
group.add_argument("--input-dir", type=str, default=None)
3737
group.add_argument("--input-type", type=str, default="cli")
3838
group.add_argument("--input-file", type=str, default="input.txt")
39+
group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
3940
group.add_argument("--final-size", type=int, default=2048)
4041
group.add_argument("--sdedit", action="store_true")
4142
group.add_argument("--grid-num-rows", type=int, default=1)

sat/configs/images.jpg

35.3 KB
Loading

sat/diffusion_video.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,12 @@ def decode_first_stage(self, z):
185185
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
186186
else:
187187
kwargs = {}
188-
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
188+
frame = z.shape[2] * 4 - 3
189+
if frame <= 9:
190+
use_cp = False
191+
else:
192+
use_cp = True
193+
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
189194
all_out.append(out)
190195
out = torch.cat(all_out, dim=0)
191196
return out
@@ -218,6 +223,7 @@ def sample(
218223
shape: Union[None, Tuple, List] = None,
219224
prefix=None,
220225
concat_images=None,
226+
ofs=None,
221227
**kwargs,
222228
):
223229
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@@ -241,7 +247,7 @@ def sample(
241247
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
242248
)
243249

244-
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
250+
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
245251
samples = samples.to(self.dtype)
246252
return samples
247253

0 commit comments

Comments
 (0)