Commit ab50275

[Speculative decoding] Support target-model logprobs (#4378)
1 parent 43c413e commit ab50275


15 files changed: +728, -87 lines changed
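
This commit threads logprobs computed by the target model through the speculative decoding path, so callers can request them exactly as in the non-speculative case. A minimal user-facing sketch; the model names, speculative settings, and top-k value below are illustrative assumptions, not values taken from this commit:

from vllm import LLM, SamplingParams

# Illustrative target/draft pair and speculative settings (assumptions).
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
)

# logprobs=2 requests the top-2 target-model logprobs per generated token.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32, logprobs=2)
outputs = llm.generate(["Hello, my name is"], sampling_params)

for position_logprobs in outputs[0].outputs[0].logprobs:
    # Each position is a dict mapping token id -> Logprob.
    print(position_logprobs)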

tests/spec_decode/e2e/conftest.py

Lines changed: 63 additions & 3 deletions
@@ -1,9 +1,13 @@
 import asyncio
+import time
 from itertools import cycle
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import pytest
 import ray
+import torch
+from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
+                    nvmlInit)
 
 from tests.conftest import cleanup
 from vllm import LLM
@@ -13,7 +17,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData
+from vllm.sequence import Logprob, MultiModalData
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, random_uuid
 
@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
     test_name = request.node.name
 
     def generator_inner():
-        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
+
+        wait_for_gpu_memory_to_clear(
+            devices=list(range(torch.cuda.device_count())),
+            threshold_bytes=2 * 2**30,
+            timeout_s=60,
+        )
 
         use_async = False
         if "use_async" in kwargs:
             use_async = kwargs.pop("use_async")
+        print(f'{use_async=}')
 
+        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
         llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
         set_random_seed(seed)
 
@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
     return tokens, token_ids
 
 
+def get_logprobs_from_llm_generator(
+        llm_generator, prompts,
+        sampling_params) -> List[List[Dict[int, Logprob]]]:
+    """Returns a dict of (token_id: Logprob) for each generated position, for
+    each sequence in the batch.
+    """
+    for llm in llm_generator():
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
+        del llm
+
+    return logprobs
+
+
 def run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
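
get_logprobs_from_llm_generator above returns, for each sequence in the batch, a list of per-position dicts mapping token id to Logprob. A hypothetical sketch of how a test could compare target-model logprobs between a baseline run and a speculative-decoding run; the helper name, parameters, and tolerance are assumptions, not part of this diff:

def assert_logprobs_close(baseline_llm_generator, test_llm_generator, prompts,
                          sampling_params, atol: float = 1e-3) -> None:
    # Hypothetical helper; relies on get_logprobs_from_llm_generator above.
    baseline = get_logprobs_from_llm_generator(baseline_llm_generator, prompts,
                                               sampling_params)
    test = get_logprobs_from_llm_generator(test_llm_generator, prompts,
                                           sampling_params)
    for seq_baseline, seq_test in zip(baseline, test):
        for pos_baseline, pos_test in zip(seq_baseline, seq_test):
            # Compare the logprob of every token id reported by both runs at
            # this position.
            for token_id in pos_baseline.keys() & pos_test.keys():
                assert abs(pos_baseline[token_id].logprob -
                           pos_test[token_id].logprob) <= atol
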
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         print(f'{i=} {baseline_token_ids=}')
         print(f'{i=} {spec_token_ids=}')
         assert baseline_token_ids == spec_token_ids
+
+
+def wait_for_gpu_memory_to_clear(devices: List[int],
+                                 threshold_bytes: int,
+                                 timeout_s: float = 120) -> None:
+    # Use nvml instead of pytorch to reduce measurement error from torch cuda
+    # context.
+    nvmlInit()
+    start_time = time.time()
+    while True:
+        output = {}
+        output_raw = {}
+        for device in devices:
+            dev_handle = nvmlDeviceGetHandleByIndex(device)
+            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+            gb_used = mem_info.used / 2**30
+            output_raw[device] = gb_used
+            output[device] = f'{gb_used:.02f}'
+
+        print('gpu memory used (GB): ', end='')
+        for k, v in output.items():
+            print(f'{k}={v}; ', end='')
+        print('')
+
+        dur_s = time.time() - start_time
+        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
+            print(f'Done waiting for free GPU memory on devices {devices=} '
+                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
+            break
+
+        if dur_s >= timeout_s:
+            raise ValueError(f'Memory of devices {devices=} not free after '
+                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
+
+        time.sleep(5)
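
wait_for_gpu_memory_to_clear polls NVML roughly every five seconds until every listed device reports usage at or below threshold_bytes, and raises once timeout_s elapses. Besides its call in generator_inner above, it could serve as a general test guard; a hypothetical pytest fixture wrapping it (the fixture name and thresholds are assumptions, not part of this diff):

import pytest
import torch

@pytest.fixture
def clear_gpu_memory():
    # Hypothetical fixture: block until every visible GPU reports <= 2 GiB
    # in use (or fail after 60 s), then run the test.
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )
    yield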
