# Integrate vLLM with example LoRA and Mistral #3077
@@ -0,0 +1,15 @@

# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of the [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

- demo1: [Mistral](mistral)
- demo2: [lora](lora)

### Supported vLLM Configuration

* LLMEngine configuration:
  vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/engine/arg_utils.py#L15) is defined in the `handler/vllm_engine_config` section of `model-config.yaml`.

* Sampling parameters for text generation:
  vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in JSON format; see [prompt.json](lora/prompt.json) for an example.
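The sketch below is a rough illustration (not part of the example) of how those two configuration surfaces map onto vLLM objects: a `vllm_engine_config` dictionary becomes `EngineArgs`, and a request payload becomes `SamplingParams`, mirroring the `_set_attr_value` helper in the handler shown next. The helper name and the model path here are illustrative assumptions.

```python
# Minimal sketch: map config/request dicts onto vLLM objects, as the handler does.
from vllm import EngineArgs, SamplingParams


def apply_known_fields(obj, config: dict):
    # Copy only keys that already exist as attributes on the target object.
    known = vars(obj)
    for key, value in config.items():
        if key in known:
            setattr(obj, key, value)
    return obj


# `handler/vllm_engine_config` from model-config.yaml (values are illustrative).
engine_args = apply_known_fields(
    EngineArgs(model="model/path-to-llama-2-7b-hf"),
    {"enable_lora": True, "max_loras": 4, "max_num_seqs": 16, "max_model_len": 250},
)

# Per-request sampling parameters, e.g. taken from prompt.json.
sampling_params = apply_known_fields(
    SamplingParams(), {"temperature": 0.8, "max_tokens": 128, "logprobs": 1}
)
```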
@@ -0,0 +1,132 @@

```python
import json
import logging
import pathlib

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BaseVLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model = None
        self.model_dir = None
        self.lora_ids = {}
        self.adapters = None
        self.initialized = False

    def initialize(self, ctx):
        # Per-request state (generated text length, stopping criteria) lives in the context cache.
        ctx.cache = {}

        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(
            ctx.model_yaml_config.get("handler", {})
        )
        self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
        self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
        self.initialized = True
    def preprocess(self, requests):
        for req_id, req_data in zip(self.context.request_ids.values(), requests):
            if req_id not in self.context.cache:
                data = req_data.get("data") or req_data.get("body")
                if isinstance(data, (bytes, bytearray)):
                    # The body arrives as raw JSON bytes; parse it into a dict.
                    data = json.loads(data)

                prompt = data.get("prompt")
                sampling_params = self._get_sampling_params(data)
                lora_request = self._get_lora_request(data)
                self.context.cache[req_id] = {
                    "text_len": 0,
                    "stopping_criteria": self._create_stopping_criteria(req_id),
                }
                self.vllm_engine.add_request(
                    req_id, prompt, sampling_params, lora_request=lora_request
                )

        return requests
    def inference(self, input_batch):
        # Advance the vLLM engine by one step; this produces at most one new
        # token for every request currently scheduled by the engine.
        inference_outputs = self.vllm_engine.step()
        results = {}

        for output in inference_outputs:
            req_id = output.request_id
            results[req_id] = {
                # Return only the text generated since the previous step.
                "text": output.outputs[0].text[
                    self.context.cache[req_id]["text_len"] :
                ],
                "tokens": output.outputs[0].token_ids[-1],
                "finished": output.finished,
            }
            self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)

        return [results[i] for i in self.context.request_ids.values()]

    def postprocess(self, inference_outputs):
        self.context.stopping_criteria = [
            self.context.cache[req_id]["stopping_criteria"]
            for req_id in self.context.request_ids.values()
        ]

        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", {})
        if len(model) == 0:
            model_path = handler_config.get("model_path", {})
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = str(pathlib.Path(self.model_dir).joinpath(model_path))
        logger.info(f"EngineArgs model={model}")
        vllm_engine_config = EngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)
        return vllm_engine_config

    def _get_sampling_params(self, req_data: dict):
        sampling_params = SamplingParams()
        self._set_attr_value(sampling_params, req_data)

        return sampling_params

    def _get_lora_request(self, req_data: dict):
        adapter_name = req_data.get("lora_adapter", "")

        if len(adapter_name) > 0:
            adapter_path = self.adapters.get(adapter_name, "")
            assert len(adapter_path) > 0, f"{adapter_name} misses adapter path"
            lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
            adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
            logger.info(f"adapter_path={adapter_path}")
            return LoRARequest(adapter_name, lora_id, adapter_path)

        return None

    def _clean_up(self, req_id):
        del self.context.cache[req_id]

    def _create_stopping_criteria(self, req_id):
        class StoppingCriteria(object):
            def __init__(self, outer, req_id):
                self.req_id = req_id
                self.outer = outer

            def __call__(self, res):
                if res["finished"]:
                    self.outer._clean_up(self.req_id)
                return res["finished"]

        return StoppingCriteria(outer=self, req_id=req_id)

    def _set_attr_value(self, obj, config: dict):
        # Only copy keys that correspond to existing attributes on the target object.
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)
```
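For reference, here is a hedged sketch of what a single request payload looks like by the time `preprocess` consumes it, and the per-step result shape that `inference` produces for that request. The field names are taken from the handler above; the values are illustrative.

```python
# A single request as seen by preprocess(): prompt plus any SamplingParams
# fields, and optionally the name of a configured LoRA adapter.
request_payload = {
    "prompt": "A robot may not injure a human being",
    "temperature": 0.8,
    "max_tokens": 128,
    "lora_adapter": "adapter_1",  # optional; must match an entry under handler/adapters
}

# The per-step result returned by inference() for that request.
per_step_result = {
    "text": " or, through inaction",  # text generated since the previous step
    "tokens": 393,                    # id of the most recently generated token
    "finished": False,                # True once the sequence completes
}
```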
@@ -0,0 +1,48 @@

# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on the base model `Llama-2-7b-hf` plus the LoRA adapter `llama-2-7b-sql-lora-test` with continuous batching.

### Step 1: Download Model from HuggingFace

Log in with a HuggingFace account
```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-chat-hf --use_auth_token True
mkdir adapters && cd adapters
python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True
cd ..
```

### Step 2: Generate model artifacts

Add the downloaded paths to `model_path:` and `adapter_1:` in `model-config.yaml`, then run the following.

```bash
torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model llama-7b-lora
mv adapters llama-7b-lora
```

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-7b-lora model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora
```

### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -m lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
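Alternatively, if you want a bare-bones client instead of the bundled test script, a minimal streaming sketch along the following lines should work. It assumes TorchServe's default inference port 8080 and that the model is reachable under the name used in the test command above (`lora`); adjust the URL if your model was registered under a different name (for example `llama-7b-lora`).

```python
# Minimal streaming client sketch (assumptions: default inference port 8080,
# model name "lora"); not part of the example itself.
import json

import requests

payload = {
    "prompt": "A robot may not injure a human being",
    "max_tokens": 128,
    "temperature": 0.8,
    "lora_adapter": "adapter_1",
}

response = requests.post(
    "http://localhost:8080/predictions/lora",
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
    stream=True,
)
response.raise_for_status()

for chunk in response.iter_content(chunk_size=None):
    if chunk:
        # Each chunk is expected to carry the per-step output produced by the handler.
        print(chunk.decode("utf-8"), flush=True)
```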
@@ -0,0 +1,20 @@

```yaml
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true

handler:
  model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
  vllm_engine_config:
    enable_lora: true
    max_loras: 4
    max_cpu_loras: 4
    max_num_seqs: 16
    max_model_len: 250

  adapters:
    adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
```

Review comment on `max_num_seqs: 16`: vLLM uses PagedAttention, which typically allows for larger batch sizes. We need to figure out a way to saturate the engine, as setting `batchSize == max_num_seqs` will lead to under-utilization.

Review comment: We could use a similar strategy as for micro-batching to always have enough requests available for the engine. Preferred would be an async mode which just routes all requests to the backend and gets replies asynchronously (as discussed earlier).
@@ -0,0 +1,9 @@

```json
{
  "prompt": "A robot may not injure a human being",
  "max_new_tokens": 50,
  "temperature": 0.8,
  "logprobs": 1,
  "prompt_logprobs": 1,
  "max_tokens": 128,
  "lora_adapter": "adapter_1"
}
```