Skip to content

Commit

Permalink
tests: add lora adapter test
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrifiro committed Jul 18, 2024
1 parent 265812f commit 0b69d69
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
46 changes: 44 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
import requests
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
Expand Down Expand Up @@ -38,18 +39,59 @@ def monkeysession():


@pytest.fixture(scope="session")
def args(
monkeysession, grpc_server_thread_port, http_server_thread_port
def lora_enabled():
    """Report whether LoRA can be exercised (it is unsupported on cpu builds)."""
    running_on_cpu = vllm.config.is_cpu()
    return not running_on_cpu


@pytest.fixture(scope="session")
def requires_lora(lora_enabled):  # noqa: PT004
    """Skip the requesting test when LoRA support is unavailable."""
    if lora_enabled:
        return
    pytest.skip(reason="Lora is not enabled. (disabled on cpu)")


@pytest.fixture(scope="session")
def lora_adapter_name(requires_lora):
    """Name under which the test LoRA adapter is registered with the server."""
    adapter_name = "lora-test"
    return adapter_name


@pytest.fixture(scope="session")
def lora_adapter_path(requires_lora):
    """Download the test LoRA adapter and return its local filesystem path."""
    # Deferred import: huggingface_hub is only needed when the lora tests
    # actually run (requires_lora has already skipped on cpu).
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def args(  # noqa: PLR0913
    monkeysession,
    grpc_server_thread_port,
    http_server_thread_port,
    lora_enabled,
    lora_adapter_name,
    lora_adapter_path,
) -> argparse.Namespace:
    """Return parsed CLI arguments for the adapter/vLLM."""
    # avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments

    # Flags appended only when LoRA is usable on this platform: enable LoRA
    # support and register the downloaded test adapter under its test name.
    extra_args: list[str] = []
    if lora_enabled:
        extra_args.extend(
            (
                "--enable-lora",
                f"--lora-modules={lora_adapter_name}={lora_adapter_path}",
            )
        )

    # Replace sys.argv for the whole session so the adapter's own argument
    # parser sees only the flags below, not pytest's command line.
    monkeysession.setattr(
        sys,
        "argv",
        [
            "__main__.py",
            f"--grpc-port={grpc_server_thread_port}",
            f"--port={http_server_thread_port}",
            *extra_args,
        ],
    )

Expand Down
6 changes: 6 additions & 0 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def test_batched_generation_request(grpc_client, grpc_server_thread_port):

assert len(responses) == 2
assert all(response.text for response in responses)


def test_lora_request(grpc_client, lora_adapter_name):
    """Generating through the registered LoRA adapter yields non-empty text."""
    reply = grpc_client.make_request("hello", adapter_id=lora_adapter_name)

    assert reply.text
2 changes: 2 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def make_request(
text: str | list[str],
model_id: str | None = None,
max_new_tokens: int = 10,
adapter_id: str | None = None,
) -> GenerationResponse | Sequence[GenerationResponse]:
if single_request := isinstance(text, str):
text = [text]
Expand All @@ -127,6 +128,7 @@ def make_request(
params=Parameters(
stopping=StoppingCriteria(max_new_tokens=max_new_tokens),
),
adapter_id=adapter_id,
)

response = self.generation_service_stub.Generate(
Expand Down

0 comments on commit 0b69d69

Please sign in to comment.