diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index c87e7d9429..de1fba57c0 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -34,6 +34,8 @@ class MLCServeEngineConfig:
     min_decode_steps: int = 32
     max_decode_steps: int = 48
     init_timeout: int = 120
+    model_type: str = "tvm"  # "tvm", "torch"
+    num_shards: Optional[int] = None  # Needs to be specified if model_type is "torch"
 
     @classmethod
     def _from_json(config_cls, json_obj: Dict[Any, Any]):
@@ -57,6 +59,11 @@ def get_engine_config(dict_config):
     assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
     assert engine_config.max_decode_steps > engine_config.min_decode_steps
 
+    if engine_config.model_type == "torch":
+        assert (
+            engine_config.num_shards is not None
+        ), "num_shards in MLCServeEngineConfig needs to be provided for PT models."
+
     return engine_config
 
 
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index d3a4c00dec..a381d72f69 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -134,8 +134,6 @@ def detokenize_incrementally(
     prefix_begin_offset = generation_sequence.prefix_begin_offset
     prefix_end_offset = generation_sequence.prefix_end_offset
 
-    assert tokenizer.is_fast
-
     prefix_text = tokenizer.convert_tokens_to_string(
         output_tokens[prefix_begin_offset:prefix_end_offset]
     )
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 829bfedeab..302cd623ff 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -104,7 +104,7 @@ def get_cache(self) -> KVCache:
         The returned value should be passed to Executor.generate_text.
         """
 
-    def allocate(self, request_id: RequestId, num_tokens: int, num_sequnces: int):
+    def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
         """
         Allocate cache space for request, raise error if there is no space.
""" diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index ad5e3089aa..84e7ce8b28 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -358,7 +358,7 @@ def run_generation_loop_worker( try: model_module = model_module_loader(**model_module_loader_kwargs) - LOG.info("Model is initalized.") + LOG.info("Model is initialized.") worker = GenerationLoopWorker(model_module=model_module) except: LOG.exception("An error raised in model initialization.") @@ -370,6 +370,7 @@ def handle_command(): while True: cmd = command_queue.get() if isinstance(cmd, ShutdownCommand): + del worker.text_generator break elif isinstance(cmd, AddRequestsCommand): worker.add(cmd.request_states) diff --git a/serve/mlc_serve/model/base.py b/serve/mlc_serve/model/base.py index e47cf61267..63daf7895c 100644 --- a/serve/mlc_serve/model/base.py +++ b/serve/mlc_serve/model/base.py @@ -1,5 +1,8 @@ +from transformers import AutoConfig + from dataclasses import dataclass from typing import Optional +from pathlib import Path import os import json import inspect @@ -57,3 +60,17 @@ def get_model_artifact_config(model_artifact_path): json_object["paged_kv_cache_type"] = "vllm" return ModelArtifactConfig._from_json(json_object) + + +def get_hf_config(model_path: Path) -> AutoConfig: + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if not hasattr(hf_config, "num_key_value_heads") and hasattr( + hf_config, "num_attention_heads" + ): + hf_config.num_key_value_heads = hf_config.num_attention_heads + + if not hasattr(hf_config, "sliding_window"): + hf_config.sliding_window = None + + return hf_config diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 7cc499c5a4..d38edb88cd 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -30,15 +30,13 @@ def get_gpu_memory(gpu: int = 0) -> int: def get_num_cache_blocks( - model, + used_memory_bytes, block_size, - seq_lens, num_layers, num_kv_heads, head_size, gpu_memory_utilization=0.9, # the default used by vllm ): - used_memory_bytes = model.profile_memory_usage(seq_lens) cache_block_size = CacheManager.get_cache_block_size( block_size, num_layers, num_kv_heads, head_size ) @@ -85,7 +83,6 @@ def sample_from_logits( requests: Sequence[RequestType], sampling_state: SamplingState, vocab_size: int, - copy_stream: torch.cuda.Stream, torch_dtype: torch.dtype, torch_dev: str, past_decode_tokens: List[List[int]], @@ -93,14 +90,11 @@ def sample_from_logits( ) -> List[TextGenerationResult]: batch_size = logits.shape[0] assert batch_size == len(requests) + # Convert to torch tensors if logits are in tvm ndarray if isinstance(logits, tvm.nd.NDArray): logits = torch.from_dlpack(logits) - # synchronization point for sampling tensors - # wait until all the tensors are loaded on GPU - torch.cuda.current_stream().wait_stream(copy_stream) - # Logit processing for constraint sampling e.g., JSON Mode for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)): if request.sampling_params.logits_processor is not None: @@ -140,6 +134,7 @@ def sample_from_logits( " or element < 0" ) logits = torch.from_dlpack(logits) + for i in range(batch_size): sequence_id = sequence_ids[i] logits_per_token = logits[i] @@ -149,16 +144,14 @@ def sample_from_logits( # NOTE: Rerun the preparation for simplicity. 
# Assume this code path is taken rarely and the recomputation overhead is # marginal. - with torch.cuda.stream(copy_stream): - new_sampling_state = SamplingState.from_sampling_params( - [sampling_param], - [past_decode_tokens_per_request], - [prompt_mask], - torch_dtype, - torch_dev, - vocab_size, - ) - torch.cuda.current_stream().wait_stream(copy_stream) + new_sampling_state = SamplingState.from_sampling_params( + [sampling_param], + [past_decode_tokens_per_request], + [prompt_mask], + torch_dtype, + torch_dev, + vocab_size, + ) maybe_sampling_output: Optional[SamplingOutput] = sample( torch.unsqueeze(logits_per_token, 0), new_sampling_state, @@ -169,6 +162,7 @@ def sample_from_logits( logprob_info = maybe_sampling_output.logprob_infos[0] # Valid sample request = requests[i] + if maybe_sampling_output is not None: outputs.extend( prepare_textgen_result( @@ -200,24 +194,39 @@ def prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, + block_size, num_decode_query_tokens=1, + for_vllm=False, ): + if for_vllm: + torch_int_dtype = torch.long + else: + torch_int_dtype = torch.int + block_tables = [] seq_lens = [] input_ids = [] slot_mapping = [] positions = [] - max_num_blocks_per_seq = 0 indices_within_window = [] start_idx = 0 + max_prompt_len = -1 + max_context_len = -1 for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)): if is_prefill: - input_ids += token_ids prompt_len = len(token_ids) seq_lens.append(prompt_len) - positions += range(prompt_len) - slot_mapping += all_slot_mappings[sequence_id] + max_prompt_len = max(max_prompt_len, prompt_len) + + if for_vllm: + input_ids.append(token_ids) + positions.append(list(range(prompt_len))) + slot_mapping.append(all_slot_mappings[sequence_id]) + else: + input_ids += token_ids + positions += range(prompt_len) + slot_mapping += all_slot_mappings[sequence_id] if sliding_window: indices_within_window += range( @@ -228,15 +237,21 @@ def prepare_inputs( else: seq_len = prompt_lens[i] + len(token_ids) - input_ids += token_ids[-num_decode_query_tokens:] - for i in range(num_decode_query_tokens): - positions.append(seq_len - (num_decode_query_tokens - i)) + if for_vllm: + assert num_decode_query_tokens == 1 + input_ids.append([token_ids[-1]]) + positions.append([seq_len - 1]) + slot_mapping.append([all_slot_mappings[sequence_id][-1]]) + else: + input_ids += token_ids[-num_decode_query_tokens:] - slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:] + for i in range(num_decode_query_tokens): + positions.append(seq_len - (num_decode_query_tokens - i)) + + slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:] block_table = all_decode_block_tables[sequence_id] - max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) block_tables.append(block_table.get_blocks()) if sliding_window: @@ -244,13 +259,31 @@ def prepare_inputs( else: seq_lens.append(seq_len) + max_context_len = max(max_context_len, seq_lens[-1]) + + def _do_pad( + x: List[List[int]], + max_len: int, + pad_val: int, + ) -> List[List[int]]: + def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]: + assert len(x) <= max_len + return x + [pad_val] * (max_len - len(x)) + + return [_pad_to_max(x_i, max_len, pad_val) for x_i in x] + + if for_vllm and is_prefill: + input_ids = _do_pad(input_ids, max_prompt_len, 0) + positions = _do_pad(positions, max_prompt_len, 0) + slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1) + def to_torch(arr, torch_dtype): return torch.tensor(arr, 
dtype=torch_dtype, device="cuda") - input_ids = to_torch(input_ids, torch.int) - positions = to_torch(positions, torch.int) + input_ids = to_torch(input_ids, torch_int_dtype) + positions = to_torch(positions, torch_int_dtype) seq_lens = to_torch(seq_lens, torch.int) - slot_mapping = to_torch(slot_mapping, torch.int) + slot_mapping = to_torch(slot_mapping, torch_int_dtype) if is_prefill and sliding_window: indices_within_window = to_torch(indices_within_window, torch.int) @@ -258,14 +291,11 @@ def to_torch(arr, torch_dtype): indices_within_window = None if not is_prefill: + max_block_table_len = ( + max_context_len + block_size - 1 + ) // block_size - def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) - - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) - for block_table in block_tables - ] + padded_block_tables = _do_pad(block_tables, max_block_table_len, 0) block_tables = to_torch(padded_block_tables, torch.int) else: block_tables = None diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index c8da1382c4..7994b00c5b 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -5,6 +5,7 @@ from .base import ModelArtifactConfig from .paged_cache_manager import CacheManager from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer +from .torch_model import init_torch_model from .tvm_model import init_tvm_model from ..engine import MLCServeEngineConfig @@ -81,15 +82,22 @@ def __init__( engine_config: MLCServeEngineConfig, model_artifact_config: ModelArtifactConfig ): - # TODO(masahi): Make the model type configurable. - model, cache_manager = init_tvm_model(model_artifact_config, engine_config) + if engine_config.model_type == "tvm": + model, cache_manager = init_tvm_model(model_artifact_config, engine_config) + tokenizer_module = HfTokenizerModule(model_artifact_path.joinpath("model")) + elif engine_config.model_type == "torch": + model, cache_manager = init_torch_model( + model_artifact_path, engine_config + ) + tokenizer_module = HfTokenizerModule(model_artifact_path) + else: + raise RuntimeError(f"Unknown model type {engine_config.model_type}") self.engine_config = engine_config self.model_artifact_config = model_artifact_config self.text_generator = PagedCacheModelTextGenerator(model) self.cache_manager = cache_manager - tokenizer_module = HfTokenizerModule(model_artifact_path) self.tokenizer = tokenizer_module.tokenizer self.conversation_template = tokenizer_module.conversation_template diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 10267238c4..341fa5bd40 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -479,15 +479,15 @@ def adjust_logits(logits: torch.Tensor, sampling_state: SamplingState, vocab_siz @dataclass class SamplingOutput: - next_tokens: list[int] - logprob_infos: list[Optional[RawLogprobsInfo]] + next_tokens: np.ndarray + logprob_infos: List[Optional[RawLogprobsInfo]] def sample( logits: torch.Tensor, sampling_state: SamplingState, check_safety: bool = False, -) -> SamplingOutput: +) -> Optional[SamplingOutput]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -504,21 +504,26 @@ def _is_safe_to_sample(prob_like): ) next_tokens = np.empty((batch_size,), dtype=np.int64) + if sampling_state.has_greedy: res_greedy = 
torch.argmax(logits[mask_greedy_t], -1) np_mask_greedy = mask_greedy_t.cpu().numpy() next_tokens[np_mask_greedy] = res_greedy.cpu().numpy() probs_random = None + if sampling_state.has_random: probs_random = torch.softmax(logits[mask_random_t], dim=-1) + if check_safety and not _is_safe_to_sample(probs_random): return None + res_random = torch.multinomial(probs_random, 1, True)[:, 0] np_mask_random = mask_random_t.cpu().numpy() next_tokens[np_mask_random] = res_random.cpu().numpy() logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size + if sampling_state.has_logprob: # If everything is random sampling, save one extra softmax if not sampling_state.has_greedy: @@ -535,6 +540,7 @@ def _is_safe_to_sample(prob_like): mask = sampling_state.sampling_tensors.mask_top_logprob top_tokens = all_top_tokens[mask] top_logprobs = all_top_logprobs[mask] + for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices): next_token = next_tokens[batch_idx] assert sampling_state.sampling_params[batch_idx].logprobs diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index 68bb978110..3e4a9a0551 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -1,8 +1,11 @@ +import structlog from typing import List from transformers import AutoTokenizer from ..engine import ChatMessage from pathlib import Path +LOG = structlog.stdlib.get_logger(__name__) + class Tokenizer: def __init__(self, hf_tokenizer, skip_special_tokens=True): @@ -46,11 +49,17 @@ def apply(self, messages: list[ChatMessage]) -> str: class HfTokenizerModule: - def __init__(self, model_artifact_path: Path): + def __init__(self, tokenizer_path: Path): hf_tokenizer = AutoTokenizer.from_pretrained( - model_artifact_path.joinpath("model"), - revision=None, tokenizer_revision=None, - trust_remote_code=False, + tokenizer_path, + trust_remote_code=True, + revision=None, + tokenizer_revision=None, ) self.tokenizer = Tokenizer(hf_tokenizer) self.conversation_template = ConversationTemplate(hf_tokenizer) + + if not self.tokenizer.is_fast: + LOG.warn("tokenizer.is_fast is false. 
Some models using an external tokenizer package, " + "such as QWen, might hit this condition but that does not imply that their " + "tokenizers are slow.") diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py new file mode 100644 index 0000000000..04381cd3b4 --- /dev/null +++ b/serve/mlc_serve/model/torch_model.py @@ -0,0 +1,647 @@ +import time +import os +import tempfile +import socket +from typing import List, Tuple, Sequence +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor + +import structlog + +import torch +import torch.multiprocessing as multiprocessing + +from transformers import AutoConfig + +try: + from vllm.model_executor.layers.sampler import get_logits + from vllm.model_executor.models.llama import LlamaForCausalLM + from vllm.model_executor.models.qwen import QWenLMHeadModel + from vllm.model_executor.models.phi import PhiForCausalLM + from vllm.model_executor.models.mistral import MistralForCausalLM + from vllm.model_executor.models.mixtral import MixtralForCausalLM + from vllm.model_executor import InputMetadata, SamplingMetadata + from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel, + ) + import rpyc + from rpyc.utils.classic import obtain + from rpyc.utils.server import ThreadedServer + from rpyc.utils.factory import unix_connect + + support_torch_model = True + +except ImportError: + support_torch_model = False + + +from .base import get_hf_config +from .paged_cache_manager import KVCacheInfo, CacheManager +from .model_common import ( + prepare_inputs, + get_num_cache_blocks, + sample_from_logits, +) + +from ..engine import ( + get_prompt_sequence_id, + MLCServeEngineConfig, +) +from ..engine.model_module import ( + DecodeRequest, + PrefillRequest, + TextGenerationResult, + TextGenerator, + RequestType, +) +from .sampler import SamplingState + + +LOG = structlog.stdlib.get_logger(__name__) + + +def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): + element_size = 2 + x = 16 // element_size + + key_block_shape = (num_heads, head_size // x, block_size, x) + value_block_shape = (num_heads, head_size, block_size) + + gpu_cache = [] + for _ in range(num_layers): + key_blocks = torch.empty( + size=(num_gpu_blocks, *key_block_shape), + dtype=torch.float16, + device="cuda", + ) + value_blocks = torch.empty( + size=(num_gpu_blocks, *value_block_shape), + dtype=torch.float16, + device="cuda", + ) + gpu_cache.append((key_blocks, value_blocks)) + return gpu_cache + + +def profile_memory_usage(pt_model, seq_lens, num_hidden_layers, vocab_size): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + + for seq_len in seq_lens: + prompt_tokens = [0] * seq_len + + input_tokens.append(prompt_tokens) + input_positions.append(list(range(seq_len))) + slot_mapping.append([0] * seq_len) + + selected_token_indices: List[int] = [] + + max_prompt_len = max(seq_lens) + seq_start = 0 + + for prompt_len in seq_lens: + selected_token_indices.append(seq_start + prompt_len - 1) + seq_start += max_prompt_len + + input_ids = torch.cuda.LongTensor(input_tokens) + positions = torch.cuda.LongTensor(input_positions) + slot_mapping_tensor = torch.cuda.LongTensor(slot_mapping) + prompt_lens_tensor = torch.cuda.LongTensor(seq_lens) + + peak_memory_before = torch.cuda.max_memory_allocated() + + input_metadata = InputMetadata( + is_prompt=True, + 
+        slot_mapping=slot_mapping_tensor,
+        prompt_lens=prompt_lens_tensor,
+        max_seq_len=None,
+        start_loc=None,
+        max_context_len=0,
+        context_lens=torch.cuda.IntTensor([]),
+        block_tables=torch.cuda.IntTensor([]),
+        use_cuda_graph=False,
+    )
+
+    sampling_metadata = SamplingMetadata(
+        seq_groups=None,
+        seq_data=None,
+        prompt_lens=seq_lens,
+        selected_token_indices=torch.tensor(
+            selected_token_indices, dtype=torch.long, device="cuda"
+        ),
+        categorized_sample_indices=None,
+    )
+
+    kv_caches = [(None, None)] * num_hidden_layers
+
+    with torch.no_grad():
+        hidden_states = pt_model.forward(
+            input_ids,
+            positions,
+            kv_caches,
+            input_metadata,
+        )
+
+        _ = get_logits(
+            pt_model.lm_head.weight,
+            hidden_states,
+            sampling_metadata,
+            vocab_size,
+        )
+
+    torch.cuda.synchronize()
+
+    peak_memory = torch.cuda.max_memory_allocated()
+    LOG.info(f"peak memory during profiling: {(peak_memory - peak_memory_before) / 1e9} GB")
+
+    torch.cuda.empty_cache()
+
+    return peak_memory
+
+
+def profile_and_init_cache(
+    pt_model,
+    hf_config,
+    num_shards,
+    max_num_batched_tokens,
+):
+    num_kv_heads = hf_config.num_key_value_heads // num_shards
+    num_hidden_layers = hf_config.num_hidden_layers
+    head_size = hf_config.hidden_size // hf_config.num_attention_heads
+
+    block_size = 16
+
+    if max_num_batched_tokens > 0:
+        LOG.info("Running memory profiling.")
+        seq_lens = [1] * max_num_batched_tokens
+        used_memory_bytes = profile_memory_usage(
+            pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size
+        )
+        num_blocks = get_num_cache_blocks(
+            used_memory_bytes,
+            block_size,
+            hf_config.num_hidden_layers,
+            num_kv_heads,
+            head_size,
+        )
+    else:
+        num_blocks = 500
+
+    LOG.info(f"Using {num_blocks} cache blocks.")
+
+    cache_blocks = init_cache_blocks(
+        head_size,
+        hf_config.num_hidden_layers,
+        num_kv_heads,
+        block_size,
+        num_blocks,
+    )
+
+    LOG.info("Allocated KV cache blocks.")
+
+    return cache_blocks, num_blocks
+
+
+def load_model(hf_config, model_path):
+    model_map = {
+        "LlamaForCausalLM": LlamaForCausalLM,
+        "PhiForCausalLM": PhiForCausalLM,
+        "QWenLMHeadModel": QWenLMHeadModel,  # requires tiktoken package
+        "MistralForCausalLM": MistralForCausalLM,
+        "MixtralForCausalLM": MixtralForCausalLM,
+    }
+
+    arch = hf_config.architectures[0]
+
+    if arch not in model_map:
+        raise RuntimeError(f"Unsupported model: {arch}")
+
+    with torch.device("cuda"):
+        torch.set_default_dtype(torch.float16)
+        model = model_map[arch](hf_config)
+        model.load_weights(model_path, None, "auto", None)
+        return model
+
+
+def generate(
+    requests: Sequence[RequestType],
+    cache_info: KVCacheInfo,
+    pt_model,
+    cache_blocks,
+    sliding_window,
+    vocab_size,
+) -> List[TextGenerationResult]:
+    if len(requests) == 0:
+        return []
+
+    is_prefill = isinstance(requests[0], PrefillRequest)
+
+    all_token_ids = []
+    sequence_ids = []
+    prompt_lens = []
+    sampling_params = []
+    past_decode_tokens = []
+    prompt_masks = []
+
+    for request in requests:
+        if isinstance(request, PrefillRequest):
+            sequence_ids.append(get_prompt_sequence_id(request.request_id))
+            prompt_lens.append(len(request.token_ids))
+            past_decode_tokens.append([vocab_size])
+        elif isinstance(request, DecodeRequest):
+            sequence_ids.append(request.sequence_id)
+            prompt_lens.append(request.prompt_token_counts)
+            past_decode_tokens.append([vocab_size, *request.token_ids])
+        else:
+            raise RuntimeError(f"Unsupported request type {request}")
+
+        all_token_ids.append(request.token_ids)
+        sampling_params.append(request.sampling_params)
+        prompt_masks.append(request.prompt_mask)
+
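+    # For prefill, only the logits at each prompt's last token are sampled.
+    # Inputs are padded to max_prompt_len below, so the flat index of
+    # sequence i's last token is i * max_prompt_len + prompt_len - 1.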
selected_token_indices: List[int] = [] + + if is_prefill: + max_prompt_len = max(prompt_lens) + seq_start = 0 + + for prompt_len in prompt_lens: + selected_token_indices.append(seq_start + prompt_len - 1) + seq_start += max_prompt_len + + torch.cuda.nvtx.range_push(f"Prepare input") + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + _, + block_tables, + ) = prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + cache_info.slot_mappings, + cache_info.decode_block_tables, + sliding_window, + is_prefill, + cache_info.block_size, + for_vllm=True, + ) + + torch.cuda.nvtx.range_pop() + + input_shape = input_ids.shape + + if block_tables is None: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") + block_tables = torch.cuda.IntTensor([]) + context_lens = torch.cuda.IntTensor([]) + max_context_len = 0 + else: + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + context_lens = seq_lens + max_context_len = torch.max(seq_lens) + prompt_lens = [] + + prompt_lens = torch.cuda.LongTensor(prompt_lens) + + input_metadata = InputMetadata( + is_prompt=is_prefill, + slot_mapping=slot_mapping, + prompt_lens=prompt_lens, + max_seq_len=None, + start_loc=None, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=False, + ) + + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=prompt_lens, + selected_token_indices=torch.tensor( + selected_token_indices, dtype=torch.long, device="cuda" + ), + categorized_sample_indices=None, + ) + + with torch.no_grad(): + hidden_states = pt_model.model( + input_ids, + positions, + cache_blocks, + input_metadata, + ) + + logits = get_logits( + pt_model.lm_head.weight, + hidden_states, + sampling_metadata, + vocab_size, + ) + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + sampling_metadata = SamplingState.from_sampling_params( + sampling_params, + past_decode_tokens, + prompt_masks, + torch.float32, + "cuda", + vocab_size, + ) + + return sample_from_logits( + logits, + sequence_ids, + requests, + sampling_metadata, + vocab_size, + torch.float32, + "cuda", + past_decode_tokens, + prompt_masks, + ) + + +if support_torch_model: + + class ModelRpcServer(rpyc.Service): + def exposed_init_model( + self, + tp_rank: int, + num_shards: int, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + master_port: int, + ) -> int: + hf_config = obtain(hf_config) + engine_config = obtain(engine_config) + model_path = obtain(model_path) + + self.vocab_size = hf_config.vocab_size + self.sliding_window = hf_config.sliding_window + + # This was taken from vLLM + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + torch.cuda.set_device(tp_rank) + + os.environ["MASTER_ADDR"] = str("127.0.0.1") + os.environ["MASTER_PORT"] = str(master_port) + + torch.distributed.init_process_group( + backend="nccl", + world_size=num_shards, + rank=tp_rank, + ) + initialize_model_parallel(num_shards) + + # A small all_reduce for warmup. 
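+            # (The first collective call is what actually creates the NCCL
+            # communicators, so doing it here keeps that cost out of the
+            # serving path.)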
+ torch.distributed.all_reduce(torch.zeros(1).cuda()) + + self.pt_model = load_model(hf_config, model_path) + + self.cache_blocks, num_blocks = profile_and_init_cache( + self.pt_model, + hf_config, + num_shards, + engine_config.max_num_batched_tokens, + ) + + return num_blocks + + def exposed_generate( + self, + requests: Sequence[RequestType], + cache: KVCacheInfo, + ) -> List[TextGenerationResult]: + # TODO(masahi): Currently, obtaining inputs is the bottleneck. + # We should switch to the architecture used by Disco and vLLM as of + # https://github.com/vllm-project/vllm/pull/2221 + torch.cuda.nvtx.range_push(f"Obtain input") + requests = obtain(requests) + cache = obtain(cache) + torch.cuda.nvtx.range_pop() + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, + self.sliding_window, + self.vocab_size, + ) + + +def _init_service(socket_path): + t = ThreadedServer( + ModelRpcServer(), + socket_path=socket_path, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + t.start() + + +def start_model_process(socket_path): + proc = multiprocessing.Process(target=_init_service, args=(socket_path,)) + proc.start() + + time.sleep(1) + + repeat_count = 0 + conn = None + + while repeat_count < 20: + try: + conn = unix_connect( + socket_path, config={"allow_pickle": True, "sync_request_timeout": 600} + ) + break + except FileNotFoundError: + time.sleep(1) + repeat_count += 1 + + if repeat_count == 20: + raise RuntimeError("init rpc env error!") + + assert proc.is_alive() + return conn, proc + + +class ModelRpcClient: + def __init__( + self, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + ports: List[int], + ): + assert engine_config.num_shards is not None + + self.num_shards = engine_config.num_shards + + master_port = ports[-1] + self.executor = ThreadPoolExecutor(self.num_shards) + self.socket_paths = [tempfile.mktemp() for _ in range(self.num_shards)] + + self.model_servers = [] + self.connections = [] + self.procs = [] + + for conn, proc in self.executor.map(start_model_process, self.socket_paths): + self.model_servers.append(conn.root) + self.connections.append(conn) + self.procs.append(proc) + + def init_model(i): + return self.model_servers[i].init_model( + i, + self.num_shards, + model_path, + hf_config, + engine_config, + master_port, + ) + + rets = self.executor.map(init_model, range(self.num_shards)) + self.num_blocks = obtain(list(rets)[0]) + + def __del__(self): + self.executor.shutdown() + + for conn in self.connections: + conn.close() + + for proc in self.procs: + proc.terminate() + proc.join() + + def generate( + self, + requests: Sequence[RequestType], + cache: KVCacheInfo, + ) -> List[TextGenerationResult]: + def _generate(i): + # This calls ModelRpcServer.exposed_generate(...) via RPC. 
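+            # The result comes back as an rpyc netref; the caller below
+            # materializes it with obtain() and keeps only the first shard's
+            # output, since every shard runs the same generate() call.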
+ return self.model_servers[i].generate(requests, cache) + + res = [obtain(x) for x in self.executor.map(_generate, range(self.num_shards))] + return res[0] + + +# Taken from sgl-project/sglang +def alloc_usable_network_port(num): + port_list = [] + for port in range(10000, 65536): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + port_list.append(port) + except socket.error: + pass + + if len(port_list) == num: + break + + return port_list + + +class Model: + def __init__( + self, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + ): + if engine_config.num_shards and engine_config.num_shards > 1: + num_needed_ports = 1 # For torch distributed master port + ports = alloc_usable_network_port(num_needed_ports) + assert len(ports) == num_needed_ports, "Not enough ports available." + self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config, ports) + self.num_blocks = self.model_rpc.num_blocks + self.cache_blocks = None # Owned by each remote shard + else: + ports = alloc_usable_network_port(1) + assert len(ports) == 1 + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"tcp://localhost:{ports[0]}", + ) + initialize_model_parallel(1, 1) + + self.pt_model = load_model(hf_config, model_path) + self.cache_blocks, self.num_blocks = profile_and_init_cache( + self.pt_model, + hf_config, + 1, + engine_config.max_num_batched_tokens, + ) + self.model_rpc = None + + self.vocab_size = hf_config.vocab_size + self.sliding_window = hf_config.sliding_window + + def __del__(self): + if self.model_rpc: + del self.model_rpc + + def generate( + self, + requests: Sequence[RequestType], + cache: KVCacheInfo, + ) -> List[TextGenerationResult]: + if self.model_rpc is None: + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, + self.sliding_window, + self.vocab_size, + ) + + return self.model_rpc.generate(requests, cache) + + +def init_torch_model( + model_path: Path, engine_config: MLCServeEngineConfig +) -> Tuple[TextGenerator, CacheManager]: + if not support_torch_model: + raise RuntimeError( + "Running PyTorch models requires vLLM from " + "https://github.com/octoml/vllm/tree/for-mlc-serve installed. " + "Furthermore, rpyc is needed for multi-gpu support." 
+        )
+
+    hf_config = get_hf_config(model_path)
+
+    if engine_config.num_shards is None:
+        raise RuntimeError("num_shards needs to be specified for PyTorch models.")
+
+    model = Model(model_path, hf_config, engine_config)
+
+    cache_manager = CacheManager(
+        model.num_blocks,
+        16,
+        hf_config.sliding_window,
+    )
+
+    return model, cache_manager
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index ae4b204c82..67fa7eefbd 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -33,6 +33,7 @@
 )
 from .sampler import SamplingState
 
+
 LOG = structlog.stdlib.get_logger(__name__)
 
 
@@ -106,6 +107,7 @@ def _prepare_inputs(
     all_decode_block_tables,
     sliding_window,
     is_prefill,
+    block_size,
     num_decode_query_tokens=1,
 ):
     (
@@ -123,6 +125,7 @@ def _prepare_inputs(
         all_decode_block_tables,
         sliding_window,
         is_prefill,
+        block_size,
         num_decode_query_tokens,
     )
 
@@ -159,15 +162,13 @@ def __init__(
         self.num_shards = config.num_shards
 
         # TODO(@sunggg): Find a better way
-        if config.model_type == "llama":
-            self.torch_dtype = torch.float32
-        elif config.model_type == "mistral" or config.model_type == "mixtral":
+        if config.model_type in ["llama", "mistral", "mixtral"]:
             self.torch_dtype = torch.float32
         else:
             assert 0, f"{config.model_type} is NOT supported yet"
 
         self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.torch_dev: str = "cuda"
+        self.torch_dev = "cuda"
 
         if self.sliding_window:
             self.block_sliding_window = self.sliding_window // block_size
@@ -252,6 +253,8 @@ def profile_memory_usage(self, seq_lens):
 
         vm_alloc_after = self.get_used_memory()
 
+        LOG.info(f"peak memory during profiling: {(vm_alloc_after - vm_alloc_before) / 1e9} GB")
+
         return self.get_param_nbytes() + (vm_alloc_after - vm_alloc_before)
 
     def generate_multi_query(
@@ -339,13 +342,17 @@ def generate_multi_query(
             torch.cuda.nvtx.range_pop()
 
         last_query_logits = torch.from_dlpack(logits)[last_query_offsets]
+
+        # synchronization point for sampling tensors
+        # wait until all the tensors are loaded on GPU
+        torch.cuda.current_stream().wait_stream(self._copy_stream)
+
         return sample_from_logits(
             last_query_logits,
             sequence_ids,
             requests,
             sampling_state,
             self.vocab_size,
-            self._copy_stream,
             self.torch_dtype,
             self.torch_dev,
             past_decode_tokens,
@@ -433,10 +440,12 @@ def generate(
             cache.decode_block_tables,
             self.sliding_window,
             is_prefill,
+            cache.block_size,
             num_decode_query_tokens,
         )
 
         input_shape = input_ids.shape
+
         if self.disco_session:
             input_ids = copy_to_worker_0(self.disco_session, input_ids)
             positions = copy_to_worker_0(self.disco_session, positions)
@@ -527,13 +536,16 @@ def generate(
                 # TODO(masahi, yelite): Proper logic for handling multi-query logits (speculative decoding).
return [] + # synchronization point for sampling tensors + # wait until all the tensors are loaded on GPU + torch.cuda.current_stream().wait_stream(self._copy_stream) + return sample_from_logits( logits, sequence_ids, requests, sampling_state, self.vocab_size, - self._copy_stream, self.torch_dtype, self.torch_dev, past_decode_tokens, @@ -576,10 +588,11 @@ def init_tvm_model( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") try: + seq_lens = [1] * engine_config.max_num_batched_tokens + used_memory_bytes = model.profile_memory_usage(seq_lens) num_blocks = get_num_cache_blocks( - model, + used_memory_bytes, block_size, - [1] * engine_config.max_num_batched_tokens, model_artifact_config.num_hidden_layers, num_kv_heads, head_size, diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index 77ce740b10..578f5893b3 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -10,7 +10,11 @@ from mlc_serve.logging_utils import configure_logging from mlc_serve.engine.staging_engine import StagingInferenceEngine from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.model.base import get_model_artifact_config +from mlc_serve.model.base import ( + get_model_artifact_config, + ModelArtifactConfig, + get_hf_config, +) from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule @@ -30,6 +34,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--max-decode-steps", type=int, default=56) parser.add_argument("--debug-logging", action="store_true") parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models return parser @@ -48,26 +53,66 @@ def postproc_mlc_serve_args(args): def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceEngine: + model_type = "tvm" + num_shards = None + + if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): + model_type = "torch" + num_shards = args.num_shards + + assert ( + num_shards is not None + ), "--num-shards needs to be provided for PT models." 
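+        # PyTorch multi-GPU serving spawns one worker process per shard, and
+        # those workers use CUDA; the "spawn" start method (rather than the
+        # default "fork") is required once CUDA may be initialized in the
+        # parent process.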
+ + if num_shards > 1: + import torch + + torch.multiprocessing.set_start_method("spawn") + engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, "max_num_batched_tokens": args.max_num_batched_tokens, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, + "model_type": model_type, + "num_shards": num_shards, } ) - model_artifact_config = get_model_artifact_config(args.model_artifact_path) + + if model_type == "tvm": + model_artifact_config = get_model_artifact_config(args.model_artifact_path) + else: + hf_config = get_hf_config(args.model_artifact_path) + + model_artifact_config = ModelArtifactConfig( + model_artifact_path=str(args.model_artifact_path), + num_shards=num_shards, + quantization=None, + max_context_length=hf_config.max_position_embeddings, + vocab_size=hf_config.vocab_size, + sliding_window=hf_config.sliding_window, + num_key_value_heads=hf_config.num_key_value_heads // num_shards, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ) engine: InferenceEngine if args.use_staging_engine: + if model_type == "tvm": + tokenizer_path = args.model_artifact_path.joinpath("model") + else: + tokenizer_path = args.model_artifact_path + engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(args.model_artifact_path), + tokenizer_module=HfTokenizerModule(tokenizer_path), model_module_loader=PagedCacheModelModule, model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, "engine_config": engine_config, - "model_artifact_config": model_artifact_config + "model_artifact_config": model_artifact_config, }, ) @@ -78,7 +123,7 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE PagedCacheModelModule( model_artifact_path=args.model_artifact_path, engine_config=engine_config, - model_artifact_config=model_artifact_config + model_artifact_config=model_artifact_config, ) ) diff --git a/serve/pyproject.toml b/serve/pyproject.toml index f79af292c1..b75d16aa44 100644 --- a/serve/pyproject.toml +++ b/serve/pyproject.toml @@ -10,6 +10,11 @@ fastapi = ">=0.103.1" pydantic = ">=1.8.0" prometheus-client = ">=0.18.0" outlines = "0.0.23" +# PT model dep +rpyc = "5.3.1" +tiktoken = "0.5.2" # for Qwen +# TODO vllm dep +# https://github.com/octoml/vllm/tree/for-mlc-serve [tool.poetry.group.dev.dependencies] pytest = "^7.4.2"