Integrate vllm with example Lora and Mistral #3077

Merged: 17 commits, May 3, 2024
34 changes: 26 additions & 8 deletions examples/large_models/utils/test_llm_streaming_response.py
@@ -26,27 +26,39 @@ def _predict(self):
for chunk in response.iter_content(chunk_size=None):
if chunk:
data = orjson.loads(chunk)
combined_text += data["text"]
combined_text += data.get("text", "")
self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n")

def _get_url(self):
return f"http://localhost:8080/predictions/{self.args.model}"

def _format_payload(self):
prompt = _load_curl_like_data(self.args.prompt_text)
prompt_list = prompt.split(" ")
prompt_input = _load_curl_like_data(self.args.prompt_text)
if self.args.prompt_json:
prompt_input = orjson.loads(prompt_input)
prompt = prompt_input.get("prompt", None)
assert prompt is not None
prompt_list = prompt.split(" ")
rt = int(prompt_input.get("max_new_tokens", self.args.max_tokens))
else:
prompt_list = prompt_input.split(" ")
rt = self.args.max_tokens
rp = len(prompt_list)
rt = self.args.max_tokens
if self.args.prompt_randomize:
rp = random.randint(0, max_prompt_random_tokens)
rt = rp + self.args.max_tokens
for _ in range(rp):
prompt_list.insert(0, chr(ord("a") + random.randint(0, 25)))
cur_prompt = " ".join(prompt_list)
return {
"prompt": cur_prompt,
"max_new_tokens": rt,
}
if self.args.prompt_json:
prompt_input["prompt"] = cur_prompt
prompt_input["max_new_tokens"] = rt
return prompt_input
else:
return {
"prompt": cur_prompt,
"max_new_tokens": rt,
}


def _load_curl_like_data(text):
@@ -106,6 +118,12 @@ def parse_args():
default=1,
help="Execute the number of prediction in each thread",
)
parser.add_argument(
"--prompt-json",
action=argparse.BooleanOptionalAction,
default=False,
help="Flag the imput prompt is a json format with prompt parameters",
> **Collaborator comment:** typo
)

return parser.parse_args()

15 changes: 15 additions & 0 deletions examples/large_models/vllm/Readme.md
@@ -0,0 +1,15 @@
# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

- demo1: [Mistral](mistral)
- demo2: [lora](lora)

### Supported vLLM Configuration
* LLMEngine configuration:
vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/engine/arg_utils.py#L15) is defined in the `handler/vllm_engine_config` section of `model-config.yaml`.


* Sampling parameters for text generation:
vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in JSON format per request, for example in [prompt.json](lora/prompt.json).
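
Taken together, a minimal sketch of the handler section, modeled on the lora example in this PR (the snapshot path is a placeholder, and any EngineArgs field can be placed under `vllm_engine_config`):

```yaml
handler:
  # local path to the downloaded weights, relative to the model archive
  model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/<snapshot-hash>"
  vllm_engine_config:
    # forwarded field-by-field to vLLM EngineArgs
    max_num_seqs: 16
    max_model_len: 250
```

Per-request SamplingParams are then supplied as top-level keys of the request JSON, as in [prompt.json](lora/prompt.json).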
132 changes: 132 additions & 0 deletions examples/large_models/vllm/base_vllm_handler.py
@@ -0,0 +1,132 @@
import json
import logging
import pathlib

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BaseVLLMHandler(BaseHandler):
    def __init__(self):
        super().__init__()

        self.vllm_engine = None
        self.model = None
        self.model_dir = None
        self.lora_ids = {}
        self.adapters = None
        self.initialized = False

    def initialize(self, ctx):
        ctx.cache = {}

        self.model_dir = ctx.system_properties.get("model_dir")
        vllm_engine_config = self._get_vllm_engine_config(
            ctx.model_yaml_config.get("handler", {})
        )
        self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
        self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
        self.initialized = True

    def preprocess(self, requests):
        for req_id, req_data in zip(self.context.request_ids.values(), requests):
            if req_id not in self.context.cache:
                data = req_data.get("data") or req_data.get("body")
                # The request body may arrive as raw JSON bytes; parse it into a dict.
                if isinstance(data, (bytes, bytearray)):
                    data = json.loads(data)

                prompt = data.get("prompt")
                # Sampling parameters and the LoRA adapter name are read from the
                # parsed request body (see lora/prompt.json for an example payload).
                sampling_params = self._get_sampling_params(data)
                lora_request = self._get_lora_request(data)
                self.context.cache[req_id] = {
                    "text_len": 0,
                    "stopping_criteria": self._create_stopping_criteria(req_id),
                }
                self.vllm_engine.add_request(
                    req_id, prompt, sampling_params, lora_request=lora_request
                )

        return requests

    def inference(self, input_batch):
        # step() advances the engine by one iteration and returns outputs for
        # all in-flight requests.
        inference_outputs = self.vllm_engine.step()
        results = {}

        for output in inference_outputs:
            req_id = output.request_id
            # Only return the text generated since the last step (streaming delta).
            results[req_id] = {
                "text": output.outputs[0].text[
                    self.context.cache[req_id]["text_len"] :
                ],
                "tokens": output.outputs[0].token_ids[-1],
                "finished": output.finished,
            }
            self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)

        return [results[i] for i in self.context.request_ids.values()]

    def postprocess(self, inference_outputs):
        # Per-request stopping criteria signal when a streaming request is done
        # and evict finished requests from the cache.
        self.context.stopping_criteria = [
            self.context.cache[req_id]["stopping_criteria"]
            for req_id in self.context.request_ids.values()
        ]

        return inference_outputs

    def _get_vllm_engine_config(self, handler_config: dict):
        vllm_engine_params = handler_config.get("vllm_engine_config", {})
        model = vllm_engine_params.get("model", {})
        if len(model) == 0:
            # Fall back to a local path relative to the model archive when no
            # model is given in vllm_engine_config.
            model_path = handler_config.get("model_path", {})
            assert (
                len(model_path) > 0
            ), "please define model in vllm_engine_config or model_path in handler"
            model = str(pathlib.Path(self.model_dir).joinpath(model_path))
        logger.info(f"EngineArgs model={model}")
        vllm_engine_config = EngineArgs(model=model)
        self._set_attr_value(vllm_engine_config, vllm_engine_params)
        return vllm_engine_config

    def _get_sampling_params(self, req_data: dict):
        sampling_params = SamplingParams()
        self._set_attr_value(sampling_params, req_data)

        return sampling_params

    def _get_lora_request(self, req_data: dict):
        adapter_name = req_data.get("lora_adapter", "")

        if len(adapter_name) > 0:
            adapter_path = self.adapters.get(adapter_name, "")
            assert len(adapter_path) > 0, f"{adapter_name} is missing its adapter path"
            lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
            adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
            logger.info(f"adapter_path={adapter_path}")
            return LoRARequest(adapter_name, lora_id, adapter_path)

        return None

    def _clean_up(self, req_id):
        del self.context.cache[req_id]

    def _create_stopping_criteria(self, req_id):
        class StoppingCriteria(object):
            def __init__(self, outer, req_id):
                self.req_id = req_id
                self.outer = outer

            def __call__(self, res):
                if res["finished"]:
                    self.outer._clean_up(self.req_id)
                return res["finished"]

        return StoppingCriteria(outer=self, req_id=req_id)

    def _set_attr_value(self, obj, config: dict):
        items = vars(obj)
        for k, v in config.items():
            if k in items:
                setattr(obj, k, v)
48 changes: 48 additions & 0 deletions examples/large_models/vllm/lora/Readme.md
@@ -0,0 +1,48 @@
# Example showing inference with vLLM on LoRA model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on the base model `Llama-2-7b-hf` with the LoRA adapter `llama-2-7b-sql-lora-test`, using continuous batching.

### Step 1: Download Model from HuggingFace

Login with a HuggingFace account
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-hf --use_auth_token True
mkdir adapters && cd adapters
python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True
cd ..
```

### Step 2: Generate model artifacts

Add the downloaded path to "model_path:" and "adapter_1:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model llama-7b-lora
mv adapters llama-7b-lora
```
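
For reference, these keys correspond to the following entries in `model-config.yaml` (paths as used later in this PR; your snapshot hashes may differ):

```yaml
handler:
  model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
  adapters:
    adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
```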

### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-7b-lora model_store
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora
```

### Step 5: Run inference

```bash
python ../../utils/test_llm_streaming_response.py -m llama-7b-lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
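
Alternatively, a single request can be sent directly with `curl` by posting the same JSON body to the TorchServe inference API (a sketch; the model name must match the one registered in Step 4, and the response is streamed back in chunks):

```bash
curl -X POST -H "Content-Type: application/json" \
  -d @prompt.json \
  http://localhost:8080/predictions/llama-7b-lora
```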
20 changes: 20 additions & 0 deletions examples/large_models/vllm/lora/model-config.yaml
@@ -0,0 +1,20 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true

handler:
model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
vllm_engine_config:
enable_lora: true
max_loras: 4
max_cpu_loras: 4
max_num_seqs: 16
> **Collaborator comment:** vLLM uses paged attention, which typically allows for larger batch sizes. We need to figure out a way to saturate the engine, as setting `batchSize == max_num_seqs` will lead to underutilization.
>
> **Collaborator comment:** We could use a strategy similar to micro-batching to always have enough requests available for the engine. Preferred would be an async mode which routes all requests to the backend and receives replies asynchronously (as discussed earlier).
    max_model_len: 250

  adapters:
    adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
9 changes: 9 additions & 0 deletions examples/large_models/vllm/lora/prompt.json
@@ -0,0 +1,9 @@
{
  "prompt": "A robot may not injure a human being",
  "max_new_tokens": 50,
  "temperature": 0.8,
  "logprobs": 1,
  "prompt_logprobs": 1,
  "max_tokens": 128,
  "lora_adapter": "adapter_1"
}
39 changes: 14 additions & 25 deletions examples/large_models/vllm/mistral/Readme.md
@@ -1,9 +1,8 @@
# Example showing inference with vLLM with mistralai/Mistral-7B-v0.1 model
# Example showing inference with vLLM on Mistral model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on `mistralai/Mistral-7B-v0.1` model.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/)
This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `mistralai/Mistral-7B-v0.1` with continuous batching.

### Step 1: Login to HuggingFace
### Step 1: Download Model from HuggingFace

Login with a HuggingFace account
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../../Huggingface_accelerate/Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1
python ../../utils/Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1 --use_auth_token True
```
The model will be saved under `model/models--mistralai--Mistral-7B-v0.1/`.

### Step 2: Generate MAR file
### Step 2: Generate model artifacts

Add the downloaded path to " model_path:" in `model-config.yaml` and run the following.
Add the downloaded path to "model_path:" in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name mistral7b --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format tgz
torch-model-archiver --model-name mistral --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
mv model mistral
```
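
The mistral `model-config.yaml` is not part of this diff; as a rough sketch of the key being edited (the snapshot path below is a placeholder for whatever `Download_model.py` produces):

```yaml
handler:
  model_path: "model/models--mistralai--Mistral-7B-v0.1/snapshots/<snapshot-hash>"
```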

### Step 3: Add the mar file to model store
### Step 3: Add the model artifacts to model store

```bash
mkdir model_store
mv mistral7b.tar.gz model_store
mv mistral model_store
```

### Step 3: Start torchserve

### Step 4: Start torchserve

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models mistral7b.tar.gz
torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models mistral
```

### Step 4: Run inference
### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/mistral7b" -T sample_text.txt
```

results in the following output
```
Mayonnaise is made of eggs, oil, vinegar, salt and pepper. Using an electric blender, combine all the ingredients and beat at high speed for 4 to 5 minutes.

Try it with some mustard and paprika mixed in, and a bit of sweetener if you like. But use real mayonnaise or it isn’t the same. Marlou

What in the world is mayonnaise?
python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
```
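
As with the lora example, a single request can also be sent directly with `curl` (a sketch; `max_tokens` is a vLLM SamplingParams field picked up by the handler):

```bash
curl -X POST -H "Content-Type: application/json" \
  -d '{"prompt": "A robot may not injure a human being", "max_tokens": 50}' \
  http://localhost:8080/predictions/mistral
```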