[Core] Pipeline Parallel Support #4412

Merged (6 commits) on Jul 2, 2024
10 changes: 10 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -74,6 +74,16 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

- label: Pipeline Parallelism Test
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py


- label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
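A note on the CI step above: the four commands sweep the (TP, PP) layout and chunked prefill on a 4-GPU runner, with eager mode always enabled. A minimal sketch for reproducing the same matrix locally, assuming a 4-GPU machine, pytest, and a vLLM checkout (paths and env handling are illustrative, not part of this PR):

```python
import os
import subprocess

# Mirror of the CI matrix above: (TP_SIZE, PP_SIZE, CHUNKED_PREFILL).
# EAGER_MODE is enabled in every step; TP_SIZE defaults to 1 when unset.
MATRIX = [(2, 2, 1), (2, 2, 0), (1, 4, 1), (1, 4, 0)]

for tp, pp, chunked in MATRIX:
    env = dict(os.environ,
               TP_SIZE=str(tp),
               PP_SIZE=str(pp),
               EAGER_MODE="1",
               CHUNKED_PREFILL=str(chunked))
    subprocess.run(
        ["pytest", "-v", "-s", "distributed/test_pipeline_parallel.py"],
        cwd="tests",  # run from the repository's tests/ directory
        env=env,
        check=True,
    )
```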
14 changes: 13 additions & 1 deletion tests/async_engine/test_async_llm_engine.py
@@ -5,6 +5,7 @@
import torch

from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine

from ..utils import wait_for_gpu_memory_to_clear
@@ -23,15 +24,21 @@ def __init__(self):
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)

async def step_async(self):
async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
self.step_calls += 1
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []

async def process_model_inputs_async(self, *args, **kwargs):
pass

async def stop_remote_worker_execution_loop_async(self):
pass

def generate(self, request_id):
self.request_id = request_id

@@ -41,6 +48,7 @@ def stop_generating(self):
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
print(f'Request calls: {self.add_request_calls}')

async def add_request_async(self, **kwargs):
self.add_request_calls += 1
@@ -53,6 +61,9 @@ def abort_request(self, request_id):
def has_unfinished_requests(self):
return self.request_id is not None

def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):

@@ -76,6 +87,7 @@ async def test_new_requests_event():
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
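For context on the `virtual_engine` argument added to `step_async` above: with pipeline parallelism the engine keeps one scheduler per virtual engine and the async loop drives each of them, which is why the mock now accepts (and, for PP size 1, ignores) the index. A minimal sketch of that control flow, with hypothetical names rather than vLLM's actual implementation:

```python
import asyncio
from typing import List


class ToyVirtualEngineLoop:
    """Hypothetical sketch: one step coroutine per virtual engine."""

    def __init__(self, num_virtual_engines: int):
        self.num_virtual_engines = num_virtual_engines
        self.step_calls: List[int] = [0] * num_virtual_engines

    async def step_async(self, virtual_engine: int):
        # Each virtual engine schedules and executes its own batch.
        self.step_calls[virtual_engine] += 1
        await asyncio.sleep(0)  # yield so other virtual engines can run

    async def run_once(self):
        # Drive all virtual engines concurrently for one iteration.
        await asyncio.gather(*(self.step_async(ve)
                               for ve in range(self.num_virtual_engines)))


loop = ToyVirtualEngineLoop(num_virtual_engines=2)
asyncio.run(loop.run_once())
print(loop.step_calls)  # [1, 1]
```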
4 changes: 2 additions & 2 deletions tests/async_engine/test_openapi_server_ray.py
@@ -4,15 +4,15 @@
# and debugging.
import ray

from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()

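A brief aside on the `ray.init` change above: passing `runtime_env={"working_dir": VLLM_PATH}` ships the local checkout to every Ray worker, so in-repo modules (such as the tests' helper utilities) resolve on remote processes. A small sketch of the mechanism, with an illustrative path in place of `VLLM_PATH`:

```python
import ray

# Upload the local project directory to each Ray worker so imports of
# in-repo modules succeed there. The path is illustrative.
ray.init(runtime_env={"working_dir": "/path/to/vllm"})


@ray.remote
def worker_cwd() -> str:
    # Runs inside a Ray worker whose working directory is the uploaded copy.
    import os
    return os.getcwd()


print(ray.get(worker_cwd.remote()))
ray.shutdown()
```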
24 changes: 12 additions & 12 deletions tests/basic_correctness/test_preemption.py
@@ -56,8 +56,8 @@ def test_chunked_prefill_recompute(
max_num_seqs=max_num_seqs,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
@@ -91,10 +91,10 @@ def test_preemption(
disable_log_stats=False,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

check_outputs_equal(
outputs_0_lst=hf_outputs,
@@ -147,10 +147,10 @@ def test_swap(
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
@@ -214,8 +214,8 @@ def test_swap_infeasible(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
@@ -252,8 +252,8 @@ def test_preemption_infeasible(
sampling_params=sampling_params,
)

assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)

# Verify the request is ignored and not hang.
for req_output in req_outputs:
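The repeated `scheduler[0]` edits above reflect that `llm_engine.scheduler` is now a list with one scheduler per virtual engine; with pipeline parallelism disabled the list has a single entry, so indexing at 0 preserves the old behavior. A hedged sketch of the pattern, with attribute names taken from the tests and the aggregation itself purely illustrative:

```python
from typing import List


class TinyScheduler:
    """Stand-in for a per-virtual-engine scheduler with counters like the real one."""

    def __init__(self):
        self.artificial_preempt_cnt = 0
        self.num_cumulative_preemption = 0


# One scheduler per virtual engine; a single entry when PP is off.
schedulers: List[TinyScheduler] = [TinyScheduler() for _ in range(2)]

# These tests only inspect schedulers[0]; summing across all entries is one
# way to get engine-wide stats when PP_SIZE > 1.
total_preemptions = sum(s.num_cumulative_preemption for s in schedulers)
assert schedulers[0].artificial_preempt_cnt == 0
print(total_preemptions)
```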
20 changes: 17 additions & 3 deletions tests/distributed/test_comm_ops.py
@@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
(r + 1) for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)

@@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)

@@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}

if rank == 0:
if (rank % tp_size) == 0:
broadcast_tensor_dict(test_dict, src=0)
else:
recv_dict = broadcast_tensor_dict(src=0)
@@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target)
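The `rank % tp_size` edits above account for world sizes of `tp_size * pp_size` processes: a worker's position inside its tensor-parallel group is the remainder, assuming TP ranks are contiguous within each pipeline stage (which is what the modulo in the test implies). A small sketch of that decomposition, with hypothetical helper names:

```python
def decompose_rank(rank: int, tp_size: int):
    """Split a global rank into (pp_group_index, tp_local_rank), assuming
    tensor-parallel ranks are laid out contiguously within each PP group."""
    pp_rank = rank // tp_size
    tp_rank = rank % tp_size
    return pp_rank, tp_rank


# Example: tp_size=2, pp_size=2 -> 4 ranks total.
for rank in range(4):
    print(rank, decompose_rank(rank, tp_size=2))
# 0 -> (0, 0), 1 -> (0, 1), 2 -> (1, 0), 3 -> (1, 1)
```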
149 changes: 149 additions & 0 deletions tests/distributed/test_pipeline_parallel.py
@@ -0,0 +1,149 @@
import os

import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray

from ..utils import VLLM_PATH, RemoteOpenAIServer

# downloading lora to test lora requests

# any model with a chat template should work here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
TP_SIZE = int(os.getenv("TP_SIZE", 1))
PP_SIZE = int(os.getenv("PP_SIZE", 1))

pytestmark = pytest.mark.asyncio


@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()


@pytest.fixture(scope="module")
def server(ray_ctx):
args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
str(TP_SIZE),
"--distributed-executor-backend",
"ray",
]
if CHUNKED_PREFILL:
args += [
"--enable-chunked-prefill",
]
if EAGER_MODE:
args += [
"--enforce-eager",
]
return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)


@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()


async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)

# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5


@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
model_name: str):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text

# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"

# test streaming
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
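The fixture above drives `RemoteOpenAIServer` with the new `--pipeline-parallel-size` flag alongside `--tensor-parallel-size` and the Ray backend. Outside the test harness, roughly the same server can be launched directly; a hedged sketch under the assumption that the OpenAI-compatible entrypoint and flags match those the fixture passes through (the port value is arbitrary):

```python
import subprocess
import sys

# Start the OpenAI-compatible server with TP=2 x PP=2 over the Ray backend,
# mirroring the arguments used by the test fixture above.
cmd = [
    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
    "--model", "meta-llama/Meta-Llama-3-8B",
    "--dtype", "bfloat16",
    "--tensor-parallel-size", "2",
    "--pipeline-parallel-size", "2",
    "--distributed-executor-backend", "ray",
    "--enforce-eager",
    "--port", "8000",
]
subprocess.run(cmd, check=True)
```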
8 changes: 4 additions & 4 deletions tests/engine/output_processor/test_multi_step.py
@@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
@@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
@@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
@@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,