[Speculative decoding] Support target-model logprobs (vllm-project#4378)
cadedaniel authored and robertgshaw2-neuralmagic committed May 6, 2024
1 parent 4b0f703 commit 6dd96ce
Showing 15 changed files with 728 additions and 87 deletions.
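
For orientation, here is a minimal sketch of how the capability named in the commit title might be exercised from the public API. The model names, speculative settings, and token counts are illustrative placeholders, and the engine arguments (speculative_model, num_speculative_tokens, use_v2_block_manager) are assumptions about the vLLM API around this commit rather than code taken from the diff:

    from vllm import LLM, SamplingParams

    # Placeholder target/draft pair; any supported combination should do.
    llm = LLM(
        model="meta-llama/Llama-2-7b-chat-hf",
        speculative_model="JackFram/llama-68m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )

    # logprobs=1 requests the top-1 logprob for each generated token; per the
    # commit title, these should now come from the target model even when
    # speculative decoding is enabled.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32, logprobs=1)
    outputs = llm.generate(["The future of AI is"], sampling_params)

    for output in outputs:
        # output.outputs[0].logprobs is a List[Dict[int, Logprob]], one dict
        # per generated position (cf. get_logprobs_from_llm_generator below).
        print(output.outputs[0].logprobs)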
66 changes: 63 additions & 3 deletions tests/spec_decode/e2e/conftest.py
@@ -1,9 +1,13 @@
import asyncio
import time
from itertools import cycle
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import pytest
import ray
import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                    nvmlInit)

from tests.conftest import cleanup
from vllm import LLM
@@ -13,7 +17,7 @@
from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

@@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
    test_name = request.node.name

    def generator_inner():
        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')

        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
        set_random_seed(seed)

@@ -188,6 +199,20 @@ def get_output_from_llm_generator(
    return tokens, token_ids


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.
    """
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
@@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids


def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    start_time = time.time()
    while True:
        output = {}
        output_raw = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
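
As a usage note, here is a hedged sketch of how the new get_logprobs_from_llm_generator helper might be called from a test. The fixture name, prompts, and sampling settings are illustrative assumptions, not code from this commit; it relies on SamplingParams and the helpers already imported or defined in this conftest:

    def test_returns_logprobs_per_position(baseline_llm_generator):
        # Assumed pytest fixture produced via create_llm_generator elsewhere
        # in this conftest; prompts and parameters are placeholders.
        prompts = ["Hello, my name is"] * 4
        sampling_params = SamplingParams(max_tokens=8,
                                         logprobs=2,
                                         temperature=0.0)

        batch_logprobs = get_logprobs_from_llm_generator(
            baseline_llm_generator, prompts, sampling_params)

        # One list per sequence, one {token_id: Logprob} dict per generated
        # token (possibly fewer than max_tokens if EOS is hit early).
        assert len(batch_logprobs) == len(prompts)
        assert all(0 < len(seq_logprobs) <= 8 for seq_logprobs in batch_logprobs)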