@@ -1,9 +1,13 @@
 import asyncio
+import time
 from itertools import cycle
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import pytest
 import ray
+import torch
+from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
+                    nvmlInit)

 from tests.conftest import cleanup
 from vllm import LLM
@@ -13,7 +17,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData
+from vllm.sequence import Logprob, MultiModalData
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, random_uuid

@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
     test_name = request.node.name

     def generator_inner():
-        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
+
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(torch.cuda.device_count())),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )

         use_async = False
         if "use_async" in kwargs:
             use_async = kwargs.pop("use_async")
+        print(f'{use_async=}')

+        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
         llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
         set_random_seed(seed)

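For reference, the fixtures built on create_llm_generator are consumed by calling the generator and iterating it, so each test gets exactly one freshly created LLM after the memory check above passes. A minimal consumption sketch, assuming a baseline_llm_generator fixture wired through create_llm_generator (the test body, prompt, and parameters are illustrative, not part of this change):

    from vllm import SamplingParams

    def test_baseline_generates(baseline_llm_generator):
        sampling_params = SamplingParams(temperature=0.0, max_tokens=8)
        for llm in baseline_llm_generator():
            # One LLM per iteration; deleting it lets the next engine start
            # from (near-)empty GPU memory.
            outputs = llm.generate(["Hello, my name is"], sampling_params)
            assert outputs[0].outputs[0].text
            del llm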
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
     return tokens, token_ids


+def get_logprobs_from_llm_generator(
+        llm_generator, prompts,
+        sampling_params) -> List[List[Dict[int, Logprob]]]:
+    """Returns a dict of (token_id: Logprob) for each generated position, for
+    each sequence in the batch.
+    """
+    for llm in llm_generator():
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
+        del llm
+
+    return logprobs
+
+
 def run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
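In the List[List[Dict[int, Logprob]]] returned by the new get_logprobs_from_llm_generator helper, the outer list has one entry per prompt and the inner list one dict per generated position, provided sampling_params requests logprobs (via SamplingParams(logprobs=...); otherwise vLLM leaves output.outputs[0].logprobs as None). A hedged sketch of how a test might compare baseline and test runs over this structure (every name except the helper is illustrative):

    baseline_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)
    test_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    for seq_base, seq_test in zip(baseline_logprobs, test_logprobs):
        for pos_base, pos_test in zip(seq_base, seq_test):
            # Both runs should surface the same candidate token ids at each
            # position; Logprob.logprob carries the raw log-probability.
            assert set(pos_base) == set(pos_test)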
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         print(f'{i=} {baseline_token_ids=}')
         print(f'{i=} {spec_token_ids=}')
         assert baseline_token_ids == spec_token_ids
+
+
+def wait_for_gpu_memory_to_clear(devices: List[int],
+                                 threshold_bytes: int,
+                                 timeout_s: float = 120) -> None:
+    # Use nvml instead of pytorch to reduce measurement error from torch cuda
+    # context.
+    nvmlInit()
+    start_time = time.time()
+    while True:
+        output = {}
+        output_raw = {}
+        for device in devices:
+            dev_handle = nvmlDeviceGetHandleByIndex(device)
+            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+            gb_used = mem_info.used / 2**30
+            output_raw[device] = gb_used
+            output[device] = f'{gb_used:.02f}'
+
+        print('gpu memory used (GB): ', end='')
+        for k, v in output.items():
+            print(f'{k}={v}; ', end='')
+        print('')
+
+        dur_s = time.time() - start_time
+        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
+            print(f'Done waiting for free GPU memory on devices {devices=} '
+                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
+            break
+
+        if dur_s >= timeout_s:
+            raise ValueError(f'Memory of devices {devices=} not free after '
+                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
+
+        time.sleep(5)
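Since wait_for_gpu_memory_to_clear only needs pynvml and a device list, it can also guard tests directly. A minimal sketch wiring it into an autouse pytest fixture, assuming the same 2 GiB threshold as the call in generator_inner (the fixture name and values are illustrative):

    @pytest.fixture(autouse=True)
    def ensure_gpu_memory_is_free():
        # Fail fast before engine construction rather than OOM mid-test.
        # nvml sees memory held by any process, so leaks from dead workers
        # on the same GPU are caught as well.
        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,  # 2 GiB
            timeout_s=60,
        )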