1 file changed: +5 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,8 @@ def init_model(self, cupy_port: Optional[int] = None) -> None:
8484 torch .cuda .set_device (self .device )
8585
8686 _check_if_gpu_supports_dtype (self .model_config .dtype )
87+ torch .cuda .empty_cache ()
88+ self .init_gpu_memory = torch .cuda .mem_get_info ()[0 ]
8789 else :
8890 raise RuntimeError (
8991 f"Not support device type: { self .device_config .device } " )
@@ -126,7 +128,9 @@ def profile_num_available_blocks(
126128 # profiled peak memory.
127129 torch .cuda .synchronize ()
128130 free_gpu_memory , total_gpu_memory = torch .cuda .mem_get_info ()
129- peak_memory = total_gpu_memory - free_gpu_memory
131+ # NOTE(woosuk): Here we assume that the other processes using the same
132+ # GPU did not change their memory usage during the profiling.
133+ peak_memory = self .init_gpu_memory - free_gpu_memory
130134
131135 cache_block_size = CacheEngine .get_cache_block_size (
132136 block_size , cache_dtype , self .model_config , self .parallel_config )
You can’t perform that action at this time.
0 commit comments