Integrate vllm with example Lora and Mistral #3077

Merged
17 commits merged on May 3, 2024
7 changes: 7 additions & 0 deletions examples/large_models/vllm/Readme.md
@@ -0,0 +1,7 @@
# Example showing inference with vLLM

This folder contains multiple demonstrations showcasing the integration of [vLLM](https://github.com/vllm-project/vllm) with TorchServe, running inference with the `mistralai/Mistral-7B-v0.1` model and multiple LoRA adapters.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

- demo1: [Mistral](mistral)
- demo2: [LoRA](lora)
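To make the request shape concrete, here is a minimal client sketch (not part of this PR): it assumes the Mistral example has already been archived and registered with TorchServe under the hypothetical model name `mistral`, and that the inference API listens on the default port 8080. Keys in the JSON body that match `vllm.SamplingParams` attributes are applied by the handler; with continuous batching enabled, the response may arrive as a stream of per-step chunks.

```python
import requests

payload = {
    "prompt": "The capital of France is",
    # keys matching vllm.SamplingParams attributes are picked up by _get_sampling_params
    "temperature": 0.8,
    "max_tokens": 64,
}

# hypothetical model name "mistral"; TorchServe inference API on the default port 8080
response = requests.post("http://localhost:8080/predictions/mistral", json=payload)
print(response.text)
```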
122 changes: 122 additions & 0 deletions examples/large_models/vllm/base_vllm_handler.py
@@ -0,0 +1,122 @@
import json
import logging
import pathlib

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BaseVLLMHandler(BaseHandler):
def __init__(self):
super().__init__()

self.vllm_engine = None
self.model = None
self.lora_ids = {}
self.context = None
self.initialized = False

def initialize(self, ctx):
self.context = ctx
vllm_engine_config = self._get_vllm_engine_config(
ctx.model_yaml_config.get("handler", {})
)
self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
self.initialized = True

def preprocess(self, requests):
for req_id, req_data in zip(self.context.request_ids.values(), requests):
if req_id not in self.context.cache:
data = req_data.get("data") or request.get("body")
if isinstance(data, (bytes, bytearray)):
data = data.decode("utf-8")

prompt = data.get("prompt")
                # sampling parameters and the optional LoRA adapter are read from the parsed body
                sampling_params = self._get_sampling_params(data)
                lora_request = self._get_lora_request(req_id, data)
self.context.cache[req_id] = {
"stopping_criteria": self._create_stopping_criteria(req_id),
}
self.vllm_engine.add_request(
req_id, prompt, sampling_params, lora_request
)

return requests

    def inference(self, input_batch):
        # run one continuous-batching step of the vLLM engine for all queued requests
        inference_outputs = self.vllm_engine.step()
        results = {}

        for output in inference_outputs:
            req_id = output.request_id
            results[req_id] = {
                "output": {
                    "text": output.outputs[0].text,
                    "tokens": output.outputs[0].token_ids[-1],
                },
                "stopping_criteria": self.context.cache[req_id]["stopping_criteria"](
                    output
                ),
            }

        return [results[i] for i in self.context.request_ids.values()]

def postprocess(self, inference_outputs):
self.context.stopping_criteria = []
results = []
for output in inference_outputs:
self.context.stopping_criteria.append(output["stopping_criteria"])
results.append(output["output"])

return results

def _get_vllm_engine_config(self, handler_config: dict):
vllm_engine_params = handler_config.get("vllm_engine_config", {})
vllm_engine_config = EngineArgs()
self._set_attr_value(vllm_engine_config, vllm_engine_params)
return vllm_engine_config

def _get_sampling_params(self, req_data: dict):
sampling_params = SamplingParams()
self._set_attr_value(sampling_params, req_data)

return sampling_params

def _get_lora_request(self, req_id, req_data: dict):
lora_request_params = req_data.get("adapter", None)

if lora_request_params:
lora_name = lora_request_params.get("name", None)
lora_path = lora_request_params.get("path", None)
if lora_name and lora_path:
                # remember the id assigned to each adapter name so repeated requests reuse it
                lora_id = self.lora_ids.setdefault(lora_name, len(self.lora_ids) + 1)
return LoRARequest(lora_name, lora_id, lora_path)
else:
logger.error(f"request_id={req_id} missed adapter name or path")

return None

def _create_stopping_criteria(self, req_id):
class StoppingCriteria(object):
def __init__(self, cache, req_id):
self.req_id = req_id
self.cache = cache

def __call__(self, res):
return res.finished

def clean_up(self):
del self.cache[self.req_id]

return StoppingCriteria(self.context.cache, req_id)

    def _set_attr_value(self, obj, config: dict):
        # copy config entries onto the target object (EngineArgs or SamplingParams)
        # when the key names an existing attribute; resolve model_path relative to
        # the model directory and assign it to the "model" attribute
        for k, v in config.items():
            if hasattr(obj, k):
                setattr(obj, k, v)
            elif k == "model_path":
                model_dir = self.context.system_properties.get("model_dir")
                model_path = pathlib.Path(model_dir).joinpath(v)
                setattr(obj, "model", str(model_path))
15 changes: 15 additions & 0 deletions examples/large_models/vllm/lora/model-config.yaml
@@ -0,0 +1,15 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
batchSize: 16
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"
continuousBatching: true

handler:
model: huggyllama/llama-7b
tensor_parallel_size: 4
enable_lora: true
max_loras: 4
max_cpu_loras: 4
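For the LoRA configuration above, here is a hedged client sketch (not part of the PR) of a request that selects an adapter. The handler's `_get_lora_request` expects an `adapter` object carrying a `name` and a `path` to the adapter weights; the adapter name, the path, and the registered model name `lora` below are hypothetical placeholders.

```python
import requests

payload = {
    "prompt": "Write a haiku about paged attention.",
    "max_tokens": 64,
    "adapter": {
        "name": "my-lora",                              # hypothetical adapter name
        "path": "/home/model-server/adapters/my-lora",  # hypothetical path to LoRA weights
    },
}

# assumes the LoRA example is registered under the hypothetical model name "lora"
response = requests.post("http://localhost:8080/predictions/lora", json=payload)
print(response.text)
```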
1 change: 1 addition & 0 deletions examples/large_models/vllm/requirements.txt
@@ -0,0 +1 @@
vllm