Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable running PyTorch models #207

Merged
merged 85 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
12ce0a3
refactor to separate TVM specific bits from paged_cache_model
masahi Jan 11, 2024
7a84f15
fix
masahi Jan 11, 2024
f454b7b
Remove engine config change for now
masahi Jan 11, 2024
afde741
make mypy happy with TextGenerator impl by Model
masahi Jan 12, 2024
c49ef45
stub
masahi Jan 11, 2024
d9ac72f
wip
masahi Jan 11, 2024
acbf825
wip
masahi Jan 11, 2024
fef750f
wip
masahi Jan 12, 2024
25a567e
PT model memory profiling works
masahi Jan 12, 2024
3d06f68
get rid of vllm prepare_inputs
masahi Jan 12, 2024
3cafc8b
wip
masahi Jan 12, 2024
34f77ef
model runs but nan output
masahi Jan 12, 2024
afb4d4f
mypy improvement
masahi Jan 12, 2024
e7212a5
runs e2e but the result is garbage
masahi Jan 12, 2024
f27e3b3
working
masahi Jan 12, 2024
2316e37
minor
masahi Jan 12, 2024
9b985e8
do sampling by mlc function
masahi Jan 12, 2024
f2dcc48
Merge branch 'batch-serving' into pt-model
masahi Jan 13, 2024
4d73e63
merge fix
masahi Jan 13, 2024
15a0d3b
wip parallel sampling
masahi Jan 13, 2024
959019d
fix test
masahi Jan 13, 2024
b6050d9
wip
masahi Jan 15, 2024
ff8eb27
fix
masahi Jan 15, 2024
8696df5
wip
masahi Jan 16, 2024
0af3a70
wip
masahi Jan 16, 2024
90ffccd
wip
masahi Jan 18, 2024
32686d8
attach cache_blocks to model
masahi Jan 18, 2024
de2631b
change get_num_cache_blocks signature
masahi Jan 18, 2024
618ca62
wip
masahi Jan 18, 2024
9ce2f47
wip
masahi Jan 19, 2024
c14c0e9
wip
masahi Jan 19, 2024
4328440
refactor
masahi Jan 19, 2024
9fb6358
update for qwen
masahi Jan 19, 2024
0c40fe8
Merge branch 'batch-serving' into pt-model
masahi Jan 19, 2024
7d89811
mergei fix
masahi Jan 19, 2024
0bbc41a
clean
masahi Jan 19, 2024
08a63ca
KV cache refactor to decouple cache blocks and metadata about them
masahi Jan 19, 2024
1dee091
Merge branch 'kv-refactor' into pt-model
masahi Jan 19, 2024
e098d0b
update for KV refactor
masahi Jan 19, 2024
02b7c1b
updated for the latest vllm
masahi Jan 20, 2024
5cefe97
Merge branch 'batch-serving' into pt-model
masahi Jan 20, 2024
c470c36
qwen and phi supported
masahi Jan 20, 2024
a4612da
Make num_shards configuable via engine config
masahi Jan 20, 2024
4564bd0
unify Model and ModelRpcClient classes
masahi Jan 20, 2024
196026c
support PT model in server
masahi Jan 20, 2024
de68a84
Merge branch 'batch-serving' into pt-model
masahi Jan 21, 2024
b502654
properly allocate port
masahi Jan 21, 2024
686780c
refactor engine creation
masahi Jan 22, 2024
bebd7b2
fix sync point
masahi Jan 22, 2024
ed46b5e
do not create executor at each step
masahi Jan 22, 2024
62918dd
remove dup obtain calls
masahi Jan 22, 2024
b98bdce
Merge branch 'batch-serving' into pt-model
masahi Feb 1, 2024
e144517
fix
masahi Feb 1, 2024
e6abcc7
Merge branch 'batch-serving' into pt-model
masahi Feb 2, 2024
dfbf359
use sample_from_logits
masahi Feb 2, 2024
04da3bb
enable TCP NoDelay option to fix slow socket recv issue
masahi Feb 6, 2024
5dfecb2
Replace TCP with Unix domain socket
masahi Feb 6, 2024
e4bbad9
clean and add note on RPC overhead
masahi Feb 6, 2024
ee9cdc9
clean
masahi Feb 6, 2024
2071749
RPC process join works
masahi Feb 6, 2024
15a90d0
fix mypy
masahi Feb 6, 2024
1f56ee9
Merge branch 'batch-serving' into pt-model
masahi Feb 13, 2024
90284fa
merge fix
masahi Feb 13, 2024
52ad1ad
wip test fix
masahi Feb 13, 2024
61b680e
fix
masahi Feb 13, 2024
f128fe6
Properly verify sampling params in api handler
masahi Feb 13, 2024
568583a
Create model artifact config before module initialization
masahi Feb 13, 2024
762012d
fix engine start
masahi Feb 13, 2024
72f3707
Merge branch 'sampling-params-init-fix' into pt-model
masahi Feb 13, 2024
3128329
Merge branch 'batch-serving' into pt-model
masahi Feb 14, 2024
dc5fb6e
fix
masahi Feb 14, 2024
ebe0b4e
black
masahi Feb 14, 2024
4b2de70
properly handle import failure
masahi Feb 14, 2024
f09d458
add titoken dep
masahi Feb 14, 2024
c9ac5ba
revert logprob change
masahi Feb 14, 2024
f1cf274
restored tokenizer.is_fast assert but commented out
masahi Feb 14, 2024
eaa53a7
Merge branch 'batch-serving' into pt-model
masahi Feb 15, 2024
1336fb8
fix vocab siz
masahi Feb 15, 2024
6186ef2
properly account for logits storage in memory profiling
masahi Feb 15, 2024
2229324
Merge branch 'batch-serving' into pt-model
masahi Feb 20, 2024
aa4d477
merge fix
masahi Feb 20, 2024
8bb96ed
validate num_shards in engine creation
masahi Feb 20, 2024
cf0813d
replace print with structlog
masahi Feb 20, 2024
f716851
add peak memory log for tvm as well
masahi Feb 20, 2024
992b1a0
add tokenizer.is_fast warning on creation
masahi Feb 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions serve/mlc_serve/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class MLCServeEngineConfig:
min_decode_steps: int = 32
max_decode_steps: int = 48
init_timeout: int = 120
model_type: str = "tvm" # "tvm", "torch"
num_shards: Optional[int] = None # Need to be specified for if model_type is "torch"
masahi marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _from_json(config_cls, json_obj: Dict[Any, Any]):
Expand All @@ -57,6 +59,11 @@ def get_engine_config(dict_config):
assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

if engine_config.model_type == "torch":
assert (
engine_config.num_shards is not None
), "num_shards in MLCServeEngineConfig needs to be provided for PT models."

return engine_config


Expand Down
2 changes: 0 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def detokenize_incrementally(
prefix_begin_offset = generation_sequence.prefix_begin_offset
prefix_end_offset = generation_sequence.prefix_end_offset

assert tokenizer.is_fast
masahi marked this conversation as resolved.
Show resolved Hide resolved

prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_begin_offset:prefix_end_offset]
)
Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/model_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_cache(self) -> KVCache:
The returned value should be passed to Executor.generate_text.
"""

def allocate(self, request_id: RequestId, num_tokens: int, num_sequnces: int):
def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
"""
Allocate cache space for request, raise error if there is no space.
"""
Expand Down
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/staging_engine_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def run_generation_loop_worker(

try:
model_module = model_module_loader(**model_module_loader_kwargs)
LOG.info("Model is initalized.")
LOG.info("Model is initialized.")
worker = GenerationLoopWorker(model_module=model_module)
except:
LOG.exception("An error raised in model initialization.")
Expand All @@ -370,6 +370,7 @@ def handle_command():
while True:
cmd = command_queue.get()
if isinstance(cmd, ShutdownCommand):
del worker.text_generator
break
elif isinstance(cmd, AddRequestsCommand):
worker.add(cmd.request_states)
Expand Down
17 changes: 17 additions & 0 deletions serve/mlc_serve/model/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from transformers import AutoConfig

from dataclasses import dataclass
from typing import Optional
from pathlib import Path
import os
import json
import inspect
Expand Down Expand Up @@ -57,3 +60,17 @@ def get_model_artifact_config(model_artifact_path):
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)


def get_hf_config(model_path: Path) -> AutoConfig:
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

if not hasattr(hf_config, "num_key_value_heads") and hasattr(
hf_config, "num_attention_heads"
):
hf_config.num_key_value_heads = hf_config.num_attention_heads

if not hasattr(hf_config, "sliding_window"):
hf_config.sliding_window = None

return hf_config
104 changes: 67 additions & 37 deletions serve/mlc_serve/model/model_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ def get_gpu_memory(gpu: int = 0) -> int:


def get_num_cache_blocks(
model,
used_memory_bytes,
block_size,
seq_lens,
num_layers,
num_kv_heads,
head_size,
gpu_memory_utilization=0.9, # the default used by vllm
):
used_memory_bytes = model.profile_memory_usage(seq_lens)
cache_block_size = CacheManager.get_cache_block_size(
block_size, num_layers, num_kv_heads, head_size
)
Expand Down Expand Up @@ -85,22 +83,18 @@ def sample_from_logits(
requests: Sequence[RequestType],
sampling_state: SamplingState,
vocab_size: int,
copy_stream: torch.cuda.Stream,
torch_dtype: torch.dtype,
torch_dev: str,
past_decode_tokens: List[List[int]],
prompt_masks: List[torch.Tensor],
) -> List[TextGenerationResult]:
batch_size = logits.shape[0]
assert batch_size == len(requests)

# Convert to torch tensors if logits are in tvm ndarray
if isinstance(logits, tvm.nd.NDArray):
logits = torch.from_dlpack(logits)

# synchronization point for sampling tensors
# wait until all the tensors are loaded on GPU
torch.cuda.current_stream().wait_stream(copy_stream)
masahi marked this conversation as resolved.
Show resolved Hide resolved

# Logit processing for constraint sampling e.g., JSON Mode
for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
if request.sampling_params.logits_processor is not None:
Expand Down Expand Up @@ -140,6 +134,7 @@ def sample_from_logits(
" or element < 0"
)
logits = torch.from_dlpack(logits)

for i in range(batch_size):
sequence_id = sequence_ids[i]
logits_per_token = logits[i]
Expand All @@ -149,16 +144,14 @@ def sample_from_logits(
# NOTE: Rerun the preparation for simplicity.
# Assume this code path is taken rarely and the recomputation overhead is
# marginal.
with torch.cuda.stream(copy_stream):
new_sampling_state = SamplingState.from_sampling_params(
[sampling_param],
[past_decode_tokens_per_request],
[prompt_mask],
torch_dtype,
torch_dev,
vocab_size,
)
torch.cuda.current_stream().wait_stream(copy_stream)
new_sampling_state = SamplingState.from_sampling_params(
[sampling_param],
[past_decode_tokens_per_request],
[prompt_mask],
torch_dtype,
torch_dev,
vocab_size,
)
masahi marked this conversation as resolved.
Show resolved Hide resolved
maybe_sampling_output: Optional[SamplingOutput] = sample(
torch.unsqueeze(logits_per_token, 0),
new_sampling_state,
Expand All @@ -169,6 +162,7 @@ def sample_from_logits(
logprob_info = maybe_sampling_output.logprob_infos[0]
# Valid sample
request = requests[i]

if maybe_sampling_output is not None:
outputs.extend(
prepare_textgen_result(
Expand Down Expand Up @@ -200,24 +194,39 @@ def prepare_inputs(
all_decode_block_tables,
sliding_window,
is_prefill,
block_size,
num_decode_query_tokens=1,
for_vllm=False,
):
if for_vllm:
torch_int_dtype = torch.long
else:
torch_int_dtype = torch.int

block_tables = []
seq_lens = []
input_ids = []
slot_mapping = []
positions = []
max_num_blocks_per_seq = 0
indices_within_window = []
start_idx = 0
max_prompt_len = -1
max_context_len = -1

for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)):
if is_prefill:
input_ids += token_ids
prompt_len = len(token_ids)
seq_lens.append(prompt_len)
positions += range(prompt_len)
slot_mapping += all_slot_mappings[sequence_id]
max_prompt_len = max(max_prompt_len, prompt_len)

if for_vllm:
input_ids.append(token_ids)
positions.append(list(range(prompt_len)))
slot_mapping.append(all_slot_mappings[sequence_id])
else:
input_ids += token_ids
positions += range(prompt_len)
slot_mapping += all_slot_mappings[sequence_id]

if sliding_window:
indices_within_window += range(
Expand All @@ -228,44 +237,65 @@ def prepare_inputs(

else:
seq_len = prompt_lens[i] + len(token_ids)
input_ids += token_ids[-num_decode_query_tokens:]

for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))
if for_vllm:
assert num_decode_query_tokens == 1
input_ids.append([token_ids[-1]])
positions.append([seq_len - 1])
slot_mapping.append([all_slot_mappings[sequence_id][-1]])
else:
input_ids += token_ids[-num_decode_query_tokens:]

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]
for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]

block_table = all_decode_block_tables[sequence_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table.get_blocks())

if sliding_window:
seq_lens.append(min(seq_len, sliding_window))
else:
seq_lens.append(seq_len)

max_context_len = max(max_context_len, seq_lens[-1])

def _do_pad(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we started considering vllm's tensor layout, what do you think about unifying it? It seems like upstream mlc-llm also uses 2D inputs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this also could help our cuda graph integration?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't verified if 2D inputs is better for performance, and how much cuda graph actually helps.

The upstream input looks like 2D but it is always either (1, num_total_token) or (batch_size, 1). So their 2D input is essentially 1D.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think it is worth visiting imo. But not now, in the future. Although there might not be performance boost, it would be nice to unify the layout with upstream unless there is reason.

x: List[List[int]],
max_len: int,
pad_val: int,
) -> List[List[int]]:
def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]:
assert len(x) <= max_len
return x + [pad_val] * (max_len - len(x))

return [_pad_to_max(x_i, max_len, pad_val) for x_i in x]

if for_vllm and is_prefill:
input_ids = _do_pad(input_ids, max_prompt_len, 0)
positions = _do_pad(positions, max_prompt_len, 0)
slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1)

def to_torch(arr, torch_dtype):
return torch.tensor(arr, dtype=torch_dtype, device="cuda")

input_ids = to_torch(input_ids, torch.int)
positions = to_torch(positions, torch.int)
input_ids = to_torch(input_ids, torch_int_dtype)
positions = to_torch(positions, torch_int_dtype)
seq_lens = to_torch(seq_lens, torch.int)
slot_mapping = to_torch(slot_mapping, torch.int)
slot_mapping = to_torch(slot_mapping, torch_int_dtype)

if is_prefill and sliding_window:
indices_within_window = to_torch(indices_within_window, torch.int)
else:
indices_within_window = None

if not is_prefill:
max_block_table_len = (
max_context_len + block_size - 1
) // block_size

def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))

padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in block_tables
]
padded_block_tables = _do_pad(block_tables, max_block_table_len, 0)
block_tables = to_torch(padded_block_tables, torch.int)
else:
block_tables = None
Expand Down
14 changes: 11 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .base import ModelArtifactConfig
from .paged_cache_manager import CacheManager
from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer
from .torch_model import init_torch_model
from .tvm_model import init_tvm_model

from ..engine import MLCServeEngineConfig
Expand Down Expand Up @@ -81,15 +82,22 @@ def __init__(
engine_config: MLCServeEngineConfig,
model_artifact_config: ModelArtifactConfig
):
# TODO(masahi): Make the model type configurable.
model, cache_manager = init_tvm_model(model_artifact_config, engine_config)
if engine_config.model_type == "tvm":
model, cache_manager = init_tvm_model(model_artifact_config, engine_config)
tokenizer_module = HfTokenizerModule(model_artifact_path.joinpath("model"))
elif engine_config.model_type == "torch":
model, cache_manager = init_torch_model(
model_artifact_path, engine_config
)
tokenizer_module = HfTokenizerModule(model_artifact_path)
else:
raise RuntimeError(f"Unknown model type {engine_config.model_type}")

self.engine_config = engine_config
self.model_artifact_config = model_artifact_config
self.text_generator = PagedCacheModelTextGenerator(model)
self.cache_manager = cache_manager

tokenizer_module = HfTokenizerModule(model_artifact_path)
self.tokenizer = tokenizer_module.tokenizer
self.conversation_template = tokenizer_module.conversation_template

Expand Down
12 changes: 9 additions & 3 deletions serve/mlc_serve/model/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,15 +479,15 @@ def adjust_logits(logits: torch.Tensor, sampling_state: SamplingState, vocab_siz

@dataclass
class SamplingOutput:
next_tokens: list[int]
logprob_infos: list[Optional[RawLogprobsInfo]]
next_tokens: np.ndarray
logprob_infos: List[Optional[RawLogprobsInfo]]


def sample(
logits: torch.Tensor,
sampling_state: SamplingState,
check_safety: bool = False,
) -> SamplingOutput:
) -> Optional[SamplingOutput]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
Expand All @@ -504,21 +504,26 @@ def _is_safe_to_sample(prob_like):
)

next_tokens = np.empty((batch_size,), dtype=np.int64)

if sampling_state.has_greedy:
res_greedy = torch.argmax(logits[mask_greedy_t], -1)
np_mask_greedy = mask_greedy_t.cpu().numpy()
next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()

probs_random = None

if sampling_state.has_random:
probs_random = torch.softmax(logits[mask_random_t], dim=-1)

if check_safety and not _is_safe_to_sample(probs_random):
return None

res_random = torch.multinomial(probs_random, 1, True)[:, 0]
np_mask_random = mask_random_t.cpu().numpy()
next_tokens[np_mask_random] = res_random.cpu().numpy()

logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size

if sampling_state.has_logprob:
# If everything is random sampling, save one extra softmax
if not sampling_state.has_greedy:
Expand All @@ -535,6 +540,7 @@ def _is_safe_to_sample(prob_like):
mask = sampling_state.sampling_tensors.mask_top_logprob
top_tokens = all_top_tokens[mask]
top_logprobs = all_top_logprobs[mask]

for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
next_token = next_tokens[batch_idx]
assert sampling_state.sampling_params[batch_idx].logprobs
Expand Down
Loading
Loading