Skip to content

Commit d1e45fb

Browse files
Merge pull request #468 from THUDM/main
merge
2 parents 0ae12e3 + 4aebdb4 commit d1e45fb

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tools/parallel_inference/parallel_inference_xdit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,14 @@ def main():
6161
)
6262
if args.enable_sequential_cpu_offload:
6363
pipe.enable_model_cpu_offload(gpu_id=local_rank)
64-
pipe.vae.enable_tiling()
6564
else:
6665
device = torch.device(f"cuda:{local_rank}")
6766
pipe = pipe.to(device)
6867

68+
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
69+
pipe.vae.enable_slicing()
70+
pipe.vae.enable_tiling()
71+
6972
torch.cuda.reset_peak_memory_stats()
7073
start_time = time.time()
7174

0 commit comments

Comments
 (0)