Skip to content

Commit 8f1c5bc

Browse files
authored
Remove all unnecessary torch.cuda.empty_cache (hao-ai-lab#606)
1 parent d3fd04a commit 8f1c5bc

File tree

7 files changed

+1
-30
lines changed

7 files changed

+1
-30
lines changed

fastvideo/v1/distributed/parallel_state.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
you can skip the model parallel initialization and destruction steps.
2424
"""
2525
import contextlib
26-
import gc
2726
import os
2827
import pickle
2928
import weakref
@@ -1016,15 +1015,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
10161015
if shutdown_ray:
10171016
import ray # Lazy import Ray
10181017
ray.shutdown()
1019-
gc.collect()
1020-
from fastvideo.v1.platforms import current_platform
1021-
if not current_platform.is_cpu():
1022-
torch.cuda.empty_cache()
1023-
try:
1024-
torch._C._host_emptyCache()
1025-
except AttributeError:
1026-
logger.warning(
1027-
"torch._C._host_emptyCache() only available in Pytorch >=2.5")
10281018

10291019

10301020
def in_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup,

fastvideo/v1/entrypoints/video_generator.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
diffusion models.
77
"""
88

9-
import gc
109
import math
1110
import os
1211
import time
@@ -277,5 +276,3 @@ def shutdown(self):
277276
"""
278277
self.executor.shutdown()
279278
del self.executor
280-
gc.collect()
281-
torch.cuda.empty_cache()

fastvideo/v1/models/vaes/common.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,6 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
239239

240240
results = torch.cat(local_results, dim=0).contiguous()
241241
del local_results
242-
torch.cuda.empty_cache()
243242
# first gather size to pad the results
244243
local_size = torch.tensor([results.size(0)],
245244
device=results.device,
@@ -253,7 +252,7 @@ def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
253252
padded_results = torch.zeros(max_size, device=results.device)
254253
padded_results[:results.size(0)] = results
255254
del results
256-
torch.cuda.empty_cache()
255+
257256
# Gather all results
258257
gathered_dim_metadata = [None] * world_size
259258
gathered_results = torch.zeros_like(padded_results).repeat(

fastvideo/v1/pipelines/stages/encoding.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,6 @@ def forward(
136136
self.maybe_free_model_hooks()
137137

138138
self.vae.to("cpu")
139-
torch.cuda.empty_cache()
140139

141140
return batch
142141

fastvideo/v1/pipelines/stages/image_encoding.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@
55
This module contains implementations of image encoding stages for diffusion pipelines.
66
"""
77

8-
import torch
9-
108
from fastvideo.v1.distributed import get_local_torch_device
119
from fastvideo.v1.fastvideo_args import FastVideoArgs
1210
from fastvideo.v1.forward_context import set_forward_context
@@ -68,7 +66,6 @@ def forward(
6866

6967
if fastvideo_args.use_cpu_offload:
7068
self.image_encoder.to('cpu')
71-
torch.cuda.empty_cache()
7269

7370
return batch
7471

fastvideo/v1/training/training_pipeline.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
import gc
32
import math
43
import os
54
import time
@@ -706,5 +705,3 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
706705
# Re-enable gradients for training
707706
training_args.inference_mode = False
708707
transformer.train()
709-
gc.collect()
710-
torch.cuda.empty_cache()

fastvideo/v1/worker/gpu_worker.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import contextlib
33
import faulthandler
4-
import gc
54
import multiprocessing as mp
65
import os
76
import signal
@@ -69,8 +68,6 @@ def init_device(self) -> None:
6968
torch.cuda.set_device(self.device)
7069

7170
# _check_if_gpu_supports_dtype(self.model_config.dtype)
72-
gc.collect()
73-
torch.cuda.empty_cache()
7471
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
7572

7673
os.environ["MASTER_ADDR"] = "localhost"
@@ -102,9 +99,6 @@ def shutdown(self) -> dict[str, Any]:
10299
if hasattr(self, 'pipeline') and self.pipeline is not None:
103100
# Clean up pipeline resources if needed
104101
pass
105-
# Release CUDA resources
106-
if torch.cuda.is_available():
107-
torch.cuda.empty_cache()
108102

109103
# Destroy the distributed environment
110104
cleanup_dist_env_and_memory(shutdown_ray=False)
@@ -133,8 +127,6 @@ def event_loop(self) -> None:
133127

134128
# Handle regular RPC calls
135129
if method_name == 'execute_forward':
136-
gc.collect()
137-
torch.cuda.empty_cache()
138130
forward_batch = recv_rpc['kwargs']['forward_batch']
139131
fastvideo_args = recv_rpc['kwargs']['fastvideo_args']
140132
output_batch = self.execute_forward(forward_batch,

0 commit comments

Comments
 (0)