Skip to content

Commit 7e45107

Browse files
authored
[Fix] Fix memory profiling when GPU is used by multiple processes (#2863)
1 parent 0c48b37 commit 7e45107

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

vllm/worker/worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def init_model(self, cupy_port: Optional[int] = None) -> None:
8484
torch.cuda.set_device(self.device)
8585

8686
_check_if_gpu_supports_dtype(self.model_config.dtype)
87+
torch.cuda.empty_cache()
88+
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
8789
else:
8890
raise RuntimeError(
8991
f"Not support device type: {self.device_config.device}")
@@ -126,7 +128,9 @@ def profile_num_available_blocks(
126128
# profiled peak memory.
127129
torch.cuda.synchronize()
128130
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
129-
peak_memory = total_gpu_memory - free_gpu_memory
131+
# NOTE(woosuk): Here we assume that the other processes using the same
132+
# GPU did not change their memory usage during the profiling.
133+
peak_memory = self.init_gpu_memory - free_gpu_memory
130134

131135
cache_block_size = CacheEngine.get_cache_block_size(
132136
block_size, cache_dtype, self.model_config, self.parallel_config)

0 commit comments

Comments (0)