
Commit 4aebdb4

Merge pull request #462 from DefTruth/main
[Parallel] Avoid OOM while batch size > 1
2 parents 3710a61 + bb69713 commit 4aebdb4

File tree

1 file changed: +4 −1 lines changed


tools/parallel_inference/parallel_inference_xdit.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -61,11 +61,14 @@ def main():
     )
     if args.enable_sequential_cpu_offload:
         pipe.enable_model_cpu_offload(gpu_id=local_rank)
-        pipe.vae.enable_tiling()
     else:
         device = torch.device(f"cuda:{local_rank}")
         pipe = pipe.to(device)
 
+    # Always enable tiling and slicing to avoid VAE OOM while batch size > 1
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
```
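The intent of the patch is that VAE slicing and tiling are now enabled on both code paths (CPU offload or direct GPU placement), rather than tiling being enabled only under sequential CPU offload. A minimal sketch of that control flow is below; the `FakePipe`/`FakeVAE` classes are hypothetical stand-ins for the real xDiT/diffusers pipeline, though `enable_slicing()` and `enable_tiling()` mirror the actual `diffusers` VAE methods (slicing decodes a batch one sample at a time, tiling decodes each latent in spatial tiles, both trading a little speed for lower peak memory):

```python
# Sketch of the patched configuration logic. FakePipe/FakeVAE are
# hypothetical stand-ins for the diffusers pipeline used in the script.

class FakeVAE:
    def __init__(self):
        self.slicing = False
        self.tiling = False

    def enable_slicing(self):
        # Decode the batch one sample at a time to cap peak VAE memory.
        self.slicing = True

    def enable_tiling(self):
        # Decode each latent in spatial tiles instead of all at once.
        self.tiling = True


class FakePipe:
    def __init__(self):
        self.vae = FakeVAE()
        self.offloaded = False

    def enable_model_cpu_offload(self, gpu_id=0):
        # Stand-in for moving submodules to CPU and paging them to this GPU.
        self.offloaded = True


def configure(pipe, enable_sequential_cpu_offload, local_rank=0):
    if enable_sequential_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
    # else: in the real script, pipe is moved to f"cuda:{local_rank}" here

    # Patched behaviour: slicing and tiling are enabled on BOTH paths,
    # so batch_size > 1 no longer OOMs during VAE decode.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    return pipe
```

Before the patch, a run without `--enable_sequential_cpu_offload` got neither optimization, which is where the batch-size-dependent OOM came from.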
