Commit 30bad5c

Fix peak memory profiling (#2031)
1 parent 3fefe27 commit 30bad5c

File tree

2 files changed: 3 additions & 9 deletions


vllm/utils.py

Lines changed: 0 additions & 5 deletions
@@ -40,11 +40,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     return int(max_shared_mem)
 
 
-def get_gpu_memory(gpu: int = 0) -> int:
-    """Returns the total memory of the GPU in bytes."""
-    return torch.cuda.get_device_properties(gpu).total_memory
-
-
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
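
The deleted helper only wrapped the device's total-memory figure, which torch.cuda.mem_get_info() also returns as the second element of its (free, total) tuple, making the helper redundant. A minimal sketch of the overlap, assuming a CUDA device at index 0:

    import torch

    # Both report the device's total memory in bytes; mem_get_info()
    # additionally reports free memory, which the profiling fix needs.
    total_via_props = torch.cuda.get_device_properties(0).total_memory
    free_bytes, total_via_mem_info = torch.cuda.mem_get_info(0)
    print(total_via_props, total_via_mem_info)  # normally identical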

vllm/worker/worker.py

Lines changed: 3 additions & 4 deletions
@@ -13,7 +13,6 @@
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
-from vllm.utils import get_gpu_memory
 
 
 class Worker:
@@ -81,7 +80,6 @@ def profile_num_available_blocks(
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -90,8 +88,9 @@
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
-        peak_memory = torch.cuda.max_memory_allocated()
-        total_gpu_memory = get_gpu_memory()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        peak_memory = total_gpu_memory - free_gpu_memory
+
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
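
The crux of the fix: torch.cuda.max_memory_allocated() only tracks memory served by PyTorch's caching allocator, so peak usage from allocations made outside it (the CUDA context, NCCL buffers, other libraries) went uncounted. Querying torch.cuda.mem_get_info() right after the profiling pass measures occupancy for the whole device instead. A standalone sketch of the new measurement, with run_profile standing in for the worker's dummy forward pass (a hypothetical callable, not part of this commit):

    import torch

    def measure_device_peak_memory(run_profile, device: int = 0) -> int:
        """Device-wide memory occupancy after a profiling run, in bytes."""
        torch.cuda.empty_cache()        # release cached-but-unused blocks first
        run_profile()                   # e.g. a forward pass with dummy inputs
        torch.cuda.synchronize(device)  # wait for all kernels to finish
        # Unlike max_memory_allocated(), mem_get_info() sees every
        # allocation on the device, not just PyTorch's own.
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info(device)
        return total_gpu_memory - free_gpu_memory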
