2 parents 0ae12e3 + 4aebdb4 commit d1e45fb
tools/parallel_inference/parallel_inference_xdit.py
@@ -61,11 +61,14 @@ def main():
     )
     if args.enable_sequential_cpu_offload:
         pipe.enable_model_cpu_offload(gpu_id=local_rank)
-        pipe.vae.enable_tiling()
     else:
         device = torch.device(f"cuda:{local_rank}")
         pipe = pipe.to(device)
+    # Always enable tiling and slicing to avoid VAE OOM while batch size > 1
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()