Skip to content

Commit b40fca4

Browse files
authored
[fix] replacing torch.cuda.set_device with CUDA_VISIBLE_DEVICES (#85)
1 parent 8960b96 commit b40fca4

File tree

2 files changed: +17 additions, −2 deletions

ci/L0_multi_gpu_vllm/vllm_backend/vllm_multi_gpu_test.py

Lines changed: 16 additions & 1 deletion
@@ -77,14 +77,29 @@ def _test_vllm_multi_gpu_utilization(self, model_name: str):

         print("=============== After Loading vLLM Model ===============")
         vllm_model_used_gpus = 0
+        gpu_memory_utilizations = []
+
         for gpu_id in gpu_ids:
             memory_utilization = self.get_gpu_memory_utilization(gpu_id)
             print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
-            if memory_utilization > mem_util_before_loading_model[gpu_id]:
+            memory_delta = memory_utilization - mem_util_before_loading_model[gpu_id]
+            if memory_delta > 0:
                 vllm_model_used_gpus += 1
+                gpu_memory_utilizations.append(memory_delta)

         self.assertGreaterEqual(vllm_model_used_gpus, 2)

+        # Check if memory utilization is approximately equal across GPUs
+        if len(gpu_memory_utilizations) >= 2:
+            max_memory = max(gpu_memory_utilizations)
+            min_memory = min(gpu_memory_utilizations)
+            relative_diff = (max_memory - min_memory) / max_memory
+            self.assertLessEqual(
+                relative_diff,
+                0.1,
+                f"GPU memory utilization differs by {relative_diff:.2%} which exceeds the 10% threshold",
+            )
+
     def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True):
         user_data = UserData()
         stream = False
def _test_vllm_model(self, model_name: str, send_parameters_as_tensor: bool = True):
89104
user_data = UserData()
90105
stream = False

src/model.py

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ def _validate_device_config(self):
                 f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
             )
             # vLLM doesn't currently (v0.4.2) expose device selection in the APIs
-            torch.cuda.set_device(triton_device_id)
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(triton_device_id)

     def _setup_lora(self):
        self.enable_lora = False

0 commit comments

Comments
 (0)