From 12ce0a3b2ffb344889bcf66fc8e59dcf54b17fdc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 19:31:06 +0000 Subject: [PATCH 01/73] refactor to separate TVM specific bits from paged_cache_model --- serve/mlc_serve/engine/base.py | 1 + serve/mlc_serve/engine/model_module.py | 4 +- serve/mlc_serve/model/model_common.py | 249 ++++++++ serve/mlc_serve/model/paged_cache_model.py | 655 +-------------------- serve/mlc_serve/model/tvm_model.py | 500 ++++++++++++++++ 5 files changed, 762 insertions(+), 647 deletions(-) create mode 100644 serve/mlc_serve/model/model_common.py create mode 100644 serve/mlc_serve/model/tvm_model.py diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index bf678268fe..9fed40d046 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -30,6 +30,7 @@ class MLCServeEngineConfig: min_decode_steps: int = 32 max_decode_steps: int = 48 init_timeout: int = 120 + model_type: str = "tvm" # "tvm", "torch" @classmethod def _from_json(config_cls, json_obj: Dict[Any, Any]): diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 9b018c6cc4..71d10c8e07 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -2,7 +2,7 @@ Required interfaces for the actual inference capability in InferenceEngine. """ from dataclasses import dataclass -from typing import Optional, Protocol, Union, List +from typing import Optional, Protocol, Union, List, Sequence from .base import ChatMessage, RequestId, MLCServeEngineConfig, RequestState, SequenceId from ..model.base import ModelArtifactConfig @@ -117,7 +117,7 @@ class TextGenerator(Protocol): def generate( self, - requests: List[Union[PrefillRequest, DecodeRequest]], + requests: Sequence[Union[PrefillRequest, DecodeRequest]], kv_cache: KVCache, ) -> List[TextGenerationResult]: """ diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py new file mode 100644 index 0000000000..e72dd0dbfe --- /dev/null +++ b/serve/mlc_serve/model/model_common.py @@ -0,0 +1,249 @@ +from typing import List, Union, Optional + +import structlog +import numpy as np +import torch +import tvm + +from .paged_cache_manager import CacheManager +from ..engine import ( + SamplingType, + SamplingParams, +) + +LOG = structlog.stdlib.get_logger(__name__) + + +def get_gpu_memory(gpu: int = 0) -> int: + return torch.cuda.get_device_properties(gpu).total_memory + + +def get_num_cache_blocks( + model, + seq_lens, + num_layers, + num_kv_heads, + head_size, + gpu_memory_utilization=0.9, # the default used by vllm +): + used_memory_bytes = model.profile_memory_usage(seq_lens) + cache_block_size = CacheManager.get_cache_block_size( + num_layers, num_kv_heads, head_size + ) + total_vram = get_gpu_memory() + return int( + (total_vram * gpu_memory_utilization - used_memory_bytes) // cache_block_size + ) + + +def _apply_top_p_top_k(logits, top_ps, top_ks): + p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) + k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) + logits_sort, logits_idx = logits.sort(dim=-1, descending=True) + + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) + logits_sort[top_p_mask] = -float("inf") + + # Apply top-k. + # Create a mask for the top-k elements. 
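+    # (Clarifying note, an editor-added comment:) sorted positions whose rank index is
+    # >= k for their row are masked to -inf, so only each row's k largest logits survive.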
+ top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) + top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) + top_k_mask = top_k_mask >= k.unsqueeze(dim=1) + logits_sort[top_k_mask] = -float("inf") + + # Re-sort the probabilities. + logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) + return logits + + +def sample( + logits: Union[tvm.nd.NDArray, torch.Tensor], + sampling_params: List[SamplingParams], + vocab_size: int, + check_safety=False, +) -> Optional[np.ndarray]: + def _is_safe_to_sample(prob_like): + return ( + torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) + == 0 + ) + + logits = torch.from_dlpack(logits) + num_seq = len(sampling_params) + + mask_random = torch.tensor( + [p.sampling_type == SamplingType.RANDOM for p in sampling_params], + dtype=torch.bool, + ) + mask_greedy = torch.logical_not(mask_random) + + logits_greedy = logits[mask_greedy] + + if logits_greedy.shape[0] > 0: + res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() + + if logits_greedy.shape[0] == num_seq: + return res_greedy + + temperatures = [] + top_ps = [] + top_ks = [] + divide_by_temperature = False + do_top_p = False + do_top_k = False + + for i in range(num_seq): + param = sampling_params[i] + freq = param.appeared_tokens_freq + + if param.sampling_type == SamplingType.RANDOM: + temperatures.append(param.temperature) + top_ps.append(param.top_p) + top_ks.append(param.top_k if param.top_k != -1 else vocab_size) + + divide_by_temperature |= temperatures[-1] != 1.0 + do_top_p |= top_ps[-1] < 1.0 + do_top_k |= top_ks[-1] != vocab_size + + # TODO(vvchernov): need to strictly define order of using penalties and logit bias or + # prohibit simultaneous using of them. At the latter case it can be LogitProcessor + if ( + not param.presence_penalty == 0.0 or not param.frequency_penalty == 0 + ) and bool(freq): + index = torch.from_numpy(np.array(list(freq.keys()))).to( + device=logits.device + ) + src = ( + torch.from_numpy(np.array(list(freq.values()))) + .type_as(logits) + .to(device=logits.device) + ) + logits[i][index] -= ( + src * param.frequency_penalty + param.presence_penalty + ) + + if not param.repetition_penalty == 1.0 and bool(freq): + index = torch.from_numpy(np.array(list(freq.keys()))).to( + device=logits.device + ) + logits[i][index] /= param.repetition_penalty + + if param.logit_bias: + logits[i][param.logit_bias_index] += ( + torch.Tensor(param.logit_bias_value) + .type_as(logits) + .to(device=logits.device) + ) + + logits_random = logits[mask_random] + + if divide_by_temperature: + t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) + logits_random.div_(t.unsqueeze(dim=1)) + + if do_top_p or do_top_k: + logits_random = _apply_top_p_top_k(logits_random, top_ps, top_ks) + + probs = torch.softmax(logits_random, dim=-1) + + if check_safety and not _is_safe_to_sample(probs): + return None + + res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] + + if logits_random.shape[0] == num_seq: + return res_random + + res = np.empty((num_seq,), dtype=np.int32) + res[mask_random] = res_random + + if logits_greedy.shape[0] > 0: + res[mask_greedy] = res_greedy + + return res + + +def prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + all_slot_mappings, + all_decode_block_tables, + sliding_window, + is_prefill, +): + block_tables = [] + seq_lens = [] + input_ids = [] + slot_mapping = [] + positions = [] + max_num_blocks_per_seq = 0 + indices_within_window = [] + start_idx 
= 0 + + for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)): + if is_prefill: + input_ids += token_ids + prompt_len = len(token_ids) + seq_lens.append(prompt_len) + positions += range(prompt_len) + slot_mapping += all_slot_mappings[sequence_id] + + if sliding_window: + indices_within_window += range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + ) + start_idx += prompt_len + + else: + input_ids.append(token_ids[-1]) + seq_len = prompt_lens[i] + len(token_ids) + positions.append(seq_len - 1) + block_table = all_decode_block_tables[sequence_id] + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + block_tables.append(block_table.get_blocks()) + slot_mapping.append(all_slot_mappings[sequence_id][-1]) + + if sliding_window: + seq_lens.append(min(seq_len, sliding_window)) + else: + seq_lens.append(seq_len) + + def to_torch(arr, torch_dtype): + return torch.tensor(arr, dtype=torch_dtype, device="cuda") + + input_ids = to_torch(input_ids, torch.int) + positions = to_torch(positions, torch.int) + seq_lens = to_torch(seq_lens, torch.int) + slot_mapping = to_torch(slot_mapping, torch.int) + + if is_prefill and sliding_window: + indices_within_window = to_torch(indices_within_window, torch.int) + else: + indices_within_window = None + + if not is_prefill: + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) + for block_table in block_tables + ] + block_tables = to_torch(padded_block_tables, torch.int) + else: + block_tables = None + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 64c291aa1c..fc1de355da 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -1,602 +1,26 @@ -import math -import os -from typing import List, Union, Optional +from typing import Union from pathlib import Path - import structlog -import numpy as np -import torch -import tvm -from tvm import relax -from tvm.runtime import disco as di from .base import get_model_artifact_config -from .paged_cache_manager import KVCache, CacheManager +from .paged_cache_manager import CacheManager from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer -from ..engine import ( - SamplingType, - MLCServeEngineConfig, - SamplingParams, - SequenceId, - PROMPT_SEQEUNCE_INDEX, - get_prompt_sequence_id, -) +from .tvm_model import init_tvm_model + +from ..engine import MLCServeEngineConfig from ..engine.model_module import ( DecodeRequest, PrefillRequest, TextGenerationResult, + TextGenerator, ) from ..engine.model_module import ModelModule LOG = structlog.stdlib.get_logger(__name__) -def _apply_top_p_top_k(logits, top_ps, top_ks): - p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) - k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) - logits_sort, logits_idx = logits.sort(dim=-1, descending=True) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - logits_sort[top_p_mask] = -float("inf") - - # Apply top-k. - # Create a mask for the top-k elements. 
- top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) - top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - logits_sort[top_k_mask] = -float("inf") - - # Re-sort the probabilities. - logits = torch.gather(logits_sort, dim=-1, index=torch.argsort(logits_idx, dim=-1)) - return logits - - -def sample( - logits: Union[tvm.nd.NDArray, torch.Tensor], - sampling_params: List[SamplingParams], - vocab_size: int, - check_safety=False, -) -> Optional[np.ndarray]: - def _is_safe_to_sample(prob_like): - return ( - torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) - == 0 - ) - - logits = torch.from_dlpack(logits) - num_seq = len(sampling_params) - - mask_random = torch.tensor( - [p.sampling_type == SamplingType.RANDOM for p in sampling_params], - dtype=torch.bool, - ) - mask_greedy = torch.logical_not(mask_random) - - logits_greedy = logits[mask_greedy] - - if logits_greedy.shape[0] > 0: - res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() - - if logits_greedy.shape[0] == num_seq: - return res_greedy - - temperatures = [] - top_ps = [] - top_ks = [] - divide_by_temperature = False - do_top_p = False - do_top_k = False - - for i in range(num_seq): - param = sampling_params[i] - freq = param.appeared_tokens_freq - - if param.sampling_type == SamplingType.RANDOM: - temperatures.append(param.temperature) - top_ps.append(param.top_p) - top_ks.append(param.top_k if param.top_k != -1 else vocab_size) - - divide_by_temperature |= temperatures[-1] != 1.0 - do_top_p |= top_ps[-1] < 1.0 - do_top_k |= top_ks[-1] != vocab_size - - # TODO(vvchernov): need to strictly define order of using penalties and logit bias or - # prohibit simultaneous using of them. At the latter case it can be LogitProcessor - if (not param.presence_penalty == 0.0 or not param.frequency_penalty == 0) and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device) - src = torch.from_numpy(np.array(list(freq.values()))).type_as(logits).to(device=logits.device) - logits[i][index] -= src * param.frequency_penalty + param.presence_penalty - - if not param.repetition_penalty == 1.0 and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device) - logits[i][index] /= param.repetition_penalty - - if param.logit_bias: - logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device) - - logits_random = logits[mask_random] - - if divide_by_temperature: - t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device) - logits_random.div_(t.unsqueeze(dim=1)) - - if do_top_p or do_top_k: - logits_random = _apply_top_p_top_k(logits_random, top_ps, top_ks) - - probs = torch.softmax(logits_random, dim=-1) - - if check_safety and not _is_safe_to_sample(probs): - return None - - res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] - - if logits_random.shape[0] == num_seq: - return res_random - - res = np.empty((num_seq,), dtype=np.int32) - res[mask_random] = res_random - - if logits_greedy.shape[0] > 0: - res[mask_greedy] = res_greedy - - return res - - -def load_disco_module(artifact_path, lib_path, num_shards): - sess = di.ProcessSession(num_workers=num_shards, entrypoint="tvm.exec.disco_worker") - devices = range(num_shards) - sess.init_ccl("nccl", *devices) - module = sess.load_vm_module(lib_path) - - loader_create = sess.get_global_func("runtime.disco.ShardLoader") - metadata_path = 
os.path.join(artifact_path, "params", "ndarray-cache.json") - with open(metadata_path, "r", encoding="utf-8") as f: - ndarray_cache_metadata = f.read() - - loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) - loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded") - params = loader_load(loader) - - return module, params, sess - - -def copy_to_worker_0(sess: di.Session, host_array): - x_array = sess.empty(host_array.shape, host_array.dtype) - sess.copy_to_worker_0(host_array, x_array) - return x_array - - -def broadcast_from_worker_0(sess: di.Session, src, shape, dtype): - dst = sess.empty(shape, dtype) - sess.broadcast_from_worker0(src, dst) - return dst - - -def get_tvm_model(config, dev): - LOG.info(f"Loading parameters from {config.model_artifact_path}.") - lib_path = os.path.join(config.model_artifact_path, config.library_name) - - if config.num_shards == 1: - ex = tvm.runtime.load_module(lib_path) - vm = relax.VirtualMachine(ex, dev) - - from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel - _params, _meta = tvmjs.load_ndarray_cache(f"{config.model_artifact_path}/params", dev) - params = [] - for i in range(_meta["ParamSize"]): - params.append(_params[f"param_{i}"]) - - return vm.module, params, None - - return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) - - -def _prepare_inputs( - sequence_ids, - all_token_ids, - prompt_lens, - all_slot_mappings, - all_decode_block_tables, - sliding_window, - is_prefill, -): - block_tables = [] - seq_lens = [] - input_ids = [] - slot_mapping = [] - positions = [] - max_num_blocks_per_seq = 0 - indices_within_window = [] - start_idx = 0 - - for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)): - if is_prefill: - input_ids += token_ids - prompt_len = len(token_ids) - seq_lens.append(prompt_len) - positions += range(prompt_len) - slot_mapping += all_slot_mappings[sequence_id] - - if sliding_window: - indices_within_window += range( - start_idx + max(0, prompt_len - sliding_window), - start_idx + prompt_len, - ) - start_idx += prompt_len - - else: - input_ids.append(token_ids[-1]) - seq_len = prompt_lens[i] + len(token_ids) - positions.append(seq_len - 1) - block_table = all_decode_block_tables[sequence_id] - max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) - block_tables.append(block_table.get_blocks()) - slot_mapping.append(all_slot_mappings[sequence_id][-1]) - - if sliding_window: - seq_lens.append(min(seq_len, sliding_window)) - else: - seq_lens.append(seq_len) - - def to_ndarray_via_torch(arr, torch_dtype): - return tvm.nd.from_dlpack(torch.tensor(arr, dtype=torch_dtype, device="cuda")) - - input_ids = to_ndarray_via_torch(input_ids, torch.int) - positions = to_ndarray_via_torch(positions, torch.int) - seq_lens = to_ndarray_via_torch(seq_lens, torch.int) - slot_mapping = to_ndarray_via_torch(slot_mapping, torch.int) - - if is_prefill and sliding_window: - indices_within_window = to_ndarray_via_torch(indices_within_window, torch.int) - else: - indices_within_window = None - - if not is_prefill: - - def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) - - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) - for block_table in block_tables - ] - block_tables = to_ndarray_via_torch(padded_block_tables, torch.int) - else: - block_tables = None - - return ( - input_ids, - positions, - seq_lens, - slot_mapping, - indices_within_window, - 
block_tables, - ) - - -class Model: - def __init__( - self, - config, - dev, - ): - self.mod, self.params, self.disco_session = get_tvm_model(config, dev) - self.dev = dev - self.vocab_size = config.vocab_size - self.sliding_window = config.sliding_window - self.num_shards = config.num_shards - - if self.sliding_window: - self.block_sliding_window = self.sliding_window // CacheManager.block_size - else: - self.block_sliding_window = None - - if self.disco_session: - self.copy_cache_blocks_func = self.disco_session.get_global_func( - "tvm.contrib.vllm.copy_blocks" - ) - else: - self.copy_cache_blocks_func = tvm.get_global_func( - "tvm.contrib.vllm.copy_blocks" - ) - - def get_used_memory(self): - if self.disco_session: - params = self.params.debug_get_from_remote(0) - - get_used_memory_func = self.disco_session.get_global_func( - "vm.memory_manager.get_used_memory" - ) - # For Disco, we explicitly query the device 0. - peak_memory = get_used_memory_func( - tvm.device("cuda", 0) - ).debug_get_from_remote(0) - - # TODO: temp hack to switch the VM allocator to eager recycling mode on all devices - for i in range(1, self.num_shards): - get_used_memory_func(tvm.device("cuda", i)).debug_get_from_remote(i) - else: - params = self.params - - get_used_memory_func = tvm.get_global_func( - "vm.memory_manager.get_used_memory" - ) - peak_memory = get_used_memory_func(self.dev) - - param_bytes = sum( - math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params - ) - - return peak_memory + param_bytes - - def profile_memory_usage(self, seq_lens): - input_ids = [0] * sum(seq_lens) - positions = [] - - for s in seq_lens: - positions += range(s) - - input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), self.dev) - positions = tvm.nd.array(np.array(positions, dtype="int32"), self.dev) - seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), self.dev) - - if self.disco_session: - input_ids = copy_to_worker_0(self.disco_session, input_ids) - positions = copy_to_worker_0(self.disco_session, positions) - seq_lens = copy_to_worker_0(self.disco_session, seq_lens) - - self.mod["evaluate"](input_ids, positions, seq_lens, self.params) - - return self.get_used_memory() - - def generate( - self, - requests: Union[List[PrefillRequest], List[DecodeRequest]], - cache: KVCache, - ) -> List[TextGenerationResult]: - if len(requests) == 0: - return [] - - is_prefill = isinstance(requests[0], PrefillRequest) - - all_token_ids = [] - sampling_params = [] - sequence_ids = [] - prompt_lens = [] - num_sequences = [] - - for request in requests: - if isinstance(request, PrefillRequest): - sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) - else: - sequence_ids.append(request.sequence_id) - prompt_lens.append(request.prompt_token_counts) - - all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - indices_within_window, - block_tables, - ) = _prepare_inputs( - sequence_ids, - all_token_ids, - prompt_lens, - cache.slot_mappings, - cache.decode_block_tables, - self.sliding_window, - is_prefill, - ) - - input_shape = input_ids.shape - - if self.disco_session: - input_ids = copy_to_worker_0(self.disco_session, input_ids) - positions = copy_to_worker_0(self.disco_session, positions) - seq_lens = copy_to_worker_0(self.disco_session, seq_lens) - slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) - - if is_prefill: - 
torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") - - if self.sliding_window: - if self.disco_session: - indices_within_window = copy_to_worker_0( - self.disco_session, indices_within_window - ) - - out = self.mod["prefill"]( - input_ids, - positions, - seq_lens, - cache.cache_blocks, - slot_mapping, - indices_within_window, - self.params, - ) - else: - out = self.mod["prefill"]( - input_ids, - positions, - seq_lens, - cache.cache_blocks, - slot_mapping, - self.params, - ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[ - 0 - ] # Ignore returned KV cache since it is updated in-place anyway. - else: - torch.cuda.nvtx.range_push(f"forward decode {input_shape}") - - if self.disco_session: - block_tables = copy_to_worker_0(self.disco_session, block_tables) - - out = self.mod["decode"]( - input_ids, - positions, - seq_lens, - cache.cache_blocks, - slot_mapping, - block_tables, - self.params, - ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] - - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - if is_prefill and cache.pending_copy_from_to: - block_mapping = tvm.nd.array( - np.array(cache.pending_copy_from_to, dtype="int64") - ) - - if self.disco_session: - block_mapping = broadcast_from_worker_0( - self.disco_session, - copy_to_worker_0(self.disco_session, block_mapping), - block_mapping.shape, - "int64", - ) - - self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) - cache.pending_copy_from_to = [] - - try: - next_tokens = sample(logits, sampling_params, self.vocab_size) - assert next_tokens is not None - outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) - ): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - ) - ) - - return outputs - except RuntimeError: - # Fallback to per-token sampling in case some logits values are corrupted. 
- outputs = [] - err_msg = ( - "Error from sampling: probability tensor contains either `inf`, `nan`" - " or element < 0" - ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token = sample( - torch.unsqueeze(logits_per_token, 0), - [sampling_param], - self.vocab_size, - check_safety=True, - ) - - if maybe_new_token is not None: - new_token = maybe_new_token[0] - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[], - error=err_msg, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - ) - ) - - return outputs - - -def get_gpu_memory(gpu: int = 0) -> int: - return torch.cuda.get_device_properties(gpu).total_memory - - -def get_num_cache_blocks( - model, - seq_lens, - num_layers, - num_kv_heads, - head_size, - gpu_memory_utilization=0.9, # the default used by vllm -): - used_memory_bytes = model.profile_memory_usage(seq_lens) - cache_block_size = CacheManager.get_cache_block_size( - num_layers, num_kv_heads, head_size - ) - total_vram = get_gpu_memory() - return int( - (total_vram * gpu_memory_utilization - used_memory_bytes) // cache_block_size - ) - - class PagedCacheModelTextGenerator: - def __init__(self, model: Model): + def __init__(self, model: TextGenerator): self.model = model def generate( @@ -628,69 +52,10 @@ def __init__( ): model_artifact_config = get_model_artifact_config(model_artifact_path) - dev = tvm.device("cuda", 0) - - model = Model(model_artifact_config, dev) - - if model_artifact_config.num_shards > 1: - model.disco_session.sync_worker_0() - - num_kv_heads = ( - model_artifact_config.num_key_value_heads - // model_artifact_config.num_shards - ) - head_size = ( - model_artifact_config.hidden_size - // model_artifact_config.num_attention_heads - ) - - if engine_config.max_num_batched_tokens > 0: - LOG.info("Running memory profiling.") - num_blocks = get_num_cache_blocks( - model, - [engine_config.max_input_len] * engine_config.max_num_sequences, - model_artifact_config.num_hidden_layers, - num_kv_heads, - head_size, - ) + if engine_config.model_type == "tvm": + model, cache_manager = init_tvm_model(model_artifact_config, engine_config) else: - num_blocks = 500 - - num_cache_slots = num_blocks * CacheManager.block_size - - if num_cache_slots <= engine_config.max_num_batched_tokens: - raise RuntimeError( - f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" - f" only {num_blocks} cache blocks can be allocated. The number of" - f" available cache slots is {num_cache_slots}, not enough for" - f" {engine_config.max_num_batched_tokens} tokens. 
Try reducing" - " --max_input_len or --max_num_sequences." - ) - - LOG.info(f"Using {num_blocks} cache blocks.") - - if model.disco_session: - init_cache_func = model.disco_session.get_global_func( - "tvm.contrib.vllm.allocate_kv_cache" - ) - else: - init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - - cache_blocks = init_cache_func( - head_size, - model_artifact_config.num_hidden_layers, - num_kv_heads, - CacheManager.block_size, - num_blocks, - ) - - cache_manager = CacheManager( - cache_blocks, - num_blocks, - model_artifact_config.sliding_window, - ) - - LOG.info("Allocated KV cache blocks.") + raise RuntimeError(f"Unknown model type {engine_config.model_type}") self.engine_config = engine_config self.model_artifact_config = model_artifact_config diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py new file mode 100644 index 0000000000..c639bbf459 --- /dev/null +++ b/serve/mlc_serve/model/tvm_model.py @@ -0,0 +1,500 @@ +import math +import os +from typing import List, Union, Tuple + +import structlog +import numpy as np +import torch +import tvm +from tvm import relax +from tvm.runtime import disco as di + +from .base import ModelArtifactConfig +from .paged_cache_manager import KVCache, CacheManager +from .model_common import ( + sample, + prepare_inputs, + get_num_cache_blocks, +) + +from ..engine import ( + SequenceId, + PROMPT_SEQEUNCE_INDEX, + get_prompt_sequence_id, + MLCServeEngineConfig, +) +from ..engine.model_module import ( + DecodeRequest, + PrefillRequest, + TextGenerationResult, + TextGenerator, +) + +LOG = structlog.stdlib.get_logger(__name__) + + +def load_disco_module(artifact_path, lib_path, num_shards): + sess = di.ProcessSession(num_workers=num_shards, entrypoint="tvm.exec.disco_worker") + devices = range(num_shards) + sess.init_ccl("nccl", *devices) + module = sess.load_vm_module(lib_path) + + loader_create = sess.get_global_func("runtime.disco.ShardLoader") + metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") + with open(metadata_path, "r", encoding="utf-8") as f: + ndarray_cache_metadata = f.read() + + loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) + loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded") + params = loader_load(loader) + + return module, params, sess + + +def copy_to_worker_0(sess: di.Session, host_array): + x_array = sess.empty(host_array.shape, host_array.dtype) + sess.copy_to_worker_0(host_array, x_array) + return x_array + + +def broadcast_from_worker_0(sess: di.Session, src, shape, dtype): + dst = sess.empty(shape, dtype) + sess.broadcast_from_worker0(src, dst) + return dst + + +def get_tvm_model(config, dev): + LOG.info(f"Loading parameters from {config.model_artifact_path}.") + lib_path = os.path.join(config.model_artifact_path, config.library_name) + + if config.num_shards == 1: + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, dev) + + from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel + + _params, _meta = tvmjs.load_ndarray_cache( + f"{config.model_artifact_path}/params", dev + ) + params = [] + for i in range(_meta["ParamSize"]): + params.append(_params[f"param_{i}"]) + + return vm.module, params, None + + return load_disco_module(config.model_artifact_path, lib_path, config.num_shards) + + +def _prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + all_slot_mappings, + all_decode_block_tables, + sliding_window, + is_prefill, +): + ( + input_ids, + 
positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + all_slot_mappings, + all_decode_block_tables, + sliding_window, + is_prefill, + ) + + for inp in [ + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ]: + if inp: + inp = tvm.nd.from_dlpack(inp) + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) + + +class Model: + def __init__( + self, + config, + dev, + ): + self.mod, self.params, self.disco_session = get_tvm_model(config, dev) + self.dev = dev + self.vocab_size = config.vocab_size + self.sliding_window = config.sliding_window + self.num_shards = config.num_shards + + if self.sliding_window: + self.block_sliding_window = self.sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + if self.disco_session: + self.copy_cache_blocks_func = self.disco_session.get_global_func( + "tvm.contrib.vllm.copy_blocks" + ) + else: + self.copy_cache_blocks_func = tvm.get_global_func( + "tvm.contrib.vllm.copy_blocks" + ) + + def get_used_memory(self): + if self.disco_session: + params = self.params.debug_get_from_remote(0) + + get_used_memory_func = self.disco_session.get_global_func( + "vm.memory_manager.get_used_memory" + ) + # For Disco, we explicitly query the device 0. + peak_memory = get_used_memory_func( + tvm.device("cuda", 0) + ).debug_get_from_remote(0) + + # TODO: temp hack to switch the VM allocator to eager recycling mode on all devices + for i in range(1, self.num_shards): + get_used_memory_func(tvm.device("cuda", i)).debug_get_from_remote(i) + else: + params = self.params + + get_used_memory_func = tvm.get_global_func( + "vm.memory_manager.get_used_memory" + ) + peak_memory = get_used_memory_func(self.dev) + + param_bytes = sum( + math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params + ) + + return peak_memory + param_bytes + + def profile_memory_usage(self, seq_lens): + input_ids = [0] * sum(seq_lens) + positions = [] + + for s in seq_lens: + positions += range(s) + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), self.dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), self.dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), self.dev) + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + + self.mod["evaluate"](input_ids, positions, seq_lens, self.params) + + return self.get_used_memory() + + def generate( + self, + requests: Union[List[PrefillRequest], List[DecodeRequest]], + cache: KVCache, + ) -> List[TextGenerationResult]: + if len(requests) == 0: + return [] + + is_prefill = isinstance(requests[0], PrefillRequest) + + all_token_ids = [] + sampling_params = [] + sequence_ids = [] + prompt_lens = [] + num_sequences = [] + + for request in requests: + if isinstance(request, PrefillRequest): + sequence_ids.append(get_prompt_sequence_id(request.request_id)) + num_sequences.append(request.num_sequence) + else: + sequence_ids.append(request.sequence_id) + prompt_lens.append(request.prompt_token_counts) + + all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = _prepare_inputs( + sequence_ids, + 
all_token_ids, + prompt_lens, + cache.slot_mappings, + cache.decode_block_tables, + self.sliding_window, + is_prefill, + ) + + input_shape = input_ids.shape + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + + if is_prefill: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") + + if self.sliding_window: + if self.disco_session: + indices_within_window = copy_to_worker_0( + self.disco_session, indices_within_window + ) + + out = self.mod["prefill"]( + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + indices_within_window, + self.params, + ) + else: + out = self.mod["prefill"]( + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] # Ignore returned KV cache since it is updated in-place anyway. + else: + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + + if self.disco_session: + block_tables = copy_to_worker_0(self.disco_session, block_tables) + + out = self.mod["decode"]( + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + block_tables, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + if is_prefill and cache.pending_copy_from_to: + block_mapping = tvm.nd.array( + np.array(cache.pending_copy_from_to, dtype="int64") + ) + + if self.disco_session: + block_mapping = broadcast_from_worker_0( + self.disco_session, + copy_to_worker_0(self.disco_session, block_mapping), + block_mapping.shape, + "int64", + ) + + self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) + cache.pending_copy_from_to = [] + + try: + next_tokens = sample(logits, sampling_params, self.vocab_size) + assert next_tokens is not None + outputs = [] + for i, (sequence_id, new_token) in enumerate( + zip(sequence_ids, next_tokens) + ): + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + except RuntimeError: + # Fallback to per-token sampling in case some logits values are corrupted. 
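+            # (Clarifying note, an editor-added comment:) each sequence is retried one at a
+            # time with sample(..., check_safety=True), which returns None if that sequence's
+            # probabilities still contain inf/nan or negative values; such sequences get an
+            # error result below instead of a sampled token.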
+ outputs = [] + err_msg = ( + "Error from sampling: probability tensor contains either `inf`, `nan`" + " or element < 0" + ) + + for i, (sequence_id, logits_per_token, sampling_param) in enumerate( + zip(sequence_ids, torch.from_dlpack(logits), sampling_params) + ): + maybe_new_token = sample( + torch.unsqueeze(logits_per_token, 0), + [sampling_param], + self.vocab_size, + check_safety=True, + ) + + if maybe_new_token is not None: + new_token = maybe_new_token[0] + if ( + not new_token + in requests[i].sampling_params.appeared_tokens_freq + ): + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[], + error=err_msg, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[], + error=err_msg, + ) + ) + + return outputs + + +def init_tvm_model( + model_artifact_config: ModelArtifactConfig, engine_config: MLCServeEngineConfig +) -> Tuple[TextGenerator, CacheManager]: + dev = tvm.device("cuda", 0) + + model = Model(model_artifact_config, dev) + + if model_artifact_config.num_shards > 1: + model.disco_session.sync_worker_0() + + num_kv_heads = ( + model_artifact_config.num_key_value_heads // model_artifact_config.num_shards + ) + head_size = ( + model_artifact_config.hidden_size // model_artifact_config.num_attention_heads + ) + + if engine_config.max_num_batched_tokens > 0: + LOG.info("Running memory profiling.") + num_blocks = get_num_cache_blocks( + model, + [engine_config.max_input_len] * engine_config.max_num_sequences, + model_artifact_config.num_hidden_layers, + num_kv_heads, + head_size, + ) + else: + num_blocks = 500 + + num_cache_slots = num_blocks * CacheManager.block_size + + if num_cache_slots <= engine_config.max_num_batched_tokens: + raise RuntimeError( + f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" + f" only {num_blocks} cache blocks can be allocated. The number of" + f" available cache slots is {num_cache_slots}, not enough for" + f" {engine_config.max_num_batched_tokens} tokens. Try reducing" + " --max_input_len or --max_num_sequences." + ) + + LOG.info(f"Using {num_blocks} cache blocks.") + + if model.disco_session: + init_cache_func = model.disco_session.get_global_func( + "tvm.contrib.vllm.allocate_kv_cache" + ) + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + cache_blocks = init_cache_func( + head_size, + model_artifact_config.num_hidden_layers, + num_kv_heads, + CacheManager.block_size, + num_blocks, + ) + + cache_manager = CacheManager( + cache_blocks, + num_blocks, + model_artifact_config.sliding_window, + ) + + LOG.info("Allocated KV cache blocks.") + + # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
+ return model, cache_manager # type: ignore From 7a84f15aaa1b332aba7c852253656c6d598b604e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 20:22:02 +0000 Subject: [PATCH 02/73] fix --- serve/mlc_serve/model/tvm_model.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index c639bbf459..76e4451dc4 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -111,22 +111,16 @@ def _prepare_inputs( is_prefill, ) - for inp in [ - input_ids, - positions, - seq_lens, - slot_mapping, - indices_within_window, - block_tables, - ]: - if inp: - inp = tvm.nd.from_dlpack(inp) + if block_tables is not None: + block_tables = tvm.nd.from_dlpack(block_tables) + if indices_within_window is not None: + indices_within_window = tvm.nd.from_dlpack(indices_within_window) return ( - input_ids, - positions, - seq_lens, - slot_mapping, + tvm.nd.from_dlpack(input_ids), + tvm.nd.from_dlpack(positions), + tvm.nd.from_dlpack(seq_lens), + tvm.nd.from_dlpack(slot_mapping), indices_within_window, block_tables, ) From f454b7b4952d9c92aaabde1c2888685fdcac3bbe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 20:51:17 +0000 Subject: [PATCH 03/73] Remove engine config change for now --- serve/mlc_serve/engine/base.py | 1 - serve/mlc_serve/model/paged_cache_model.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 9fed40d046..bf678268fe 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -30,7 +30,6 @@ class MLCServeEngineConfig: min_decode_steps: int = 32 max_decode_steps: int = 48 init_timeout: int = 120 - model_type: str = "tvm" # "tvm", "torch" @classmethod def _from_json(config_cls, json_obj: Dict[Any, Any]): diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index fc1de355da..a7f11750eb 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -52,10 +52,8 @@ def __init__( ): model_artifact_config = get_model_artifact_config(model_artifact_path) - if engine_config.model_type == "tvm": - model, cache_manager = init_tvm_model(model_artifact_config, engine_config) - else: - raise RuntimeError(f"Unknown model type {engine_config.model_type}") + # TODO(masahi): Make the model type configurable. + model, cache_manager = init_tvm_model(model_artifact_config, engine_config) self.engine_config = engine_config self.model_artifact_config = model_artifact_config From afde741390309c7919cac8dbcfc825902999778e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 05:59:35 +0000 Subject: [PATCH 04/73] make mypy happy with TextGenerator impl by Model --- serve/mlc_serve/engine/model_module.py | 2 +- serve/mlc_serve/model/tvm_model.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 71d10c8e07..79b77e93a3 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -118,7 +118,7 @@ class TextGenerator(Protocol): def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], - kv_cache: KVCache, + kv_cache, ) -> List[TextGenerationResult]: """ A unified entrypoint for text generation. 
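The model_module.py tweak above is what this commit's subject refers to: TextGenerator is a typing.Protocol, so Model never inherits from it and mypy checks conformance structurally, which only works once the protocol accepts a Sequence of requests and leaves kv_cache untyped to match Model.generate. Below is a minimal, self-contained sketch of that structural-typing pattern; the placeholder classes and the run() helper are illustrative stand-ins, not the repo's real definitions.

    from typing import List, Protocol, Sequence, Union

    # Bare stand-ins for the request/result dataclasses in engine/model_module.py.
    class PrefillRequest: ...
    class DecodeRequest: ...
    class TextGenerationResult: ...

    class TextGenerator(Protocol):
        def generate(
            self,
            requests: Sequence[Union[PrefillRequest, DecodeRequest]],
            kv_cache,
        ) -> List[TextGenerationResult]:
            ...

    class Model:  # note: no explicit inheritance from TextGenerator
        def generate(
            self,
            requests: Sequence[Union[PrefillRequest, DecodeRequest]],
            kv_cache,
        ) -> List[TextGenerationResult]:
            return [TextGenerationResult() for _ in requests]

    def run(generator: TextGenerator, cache) -> List[TextGenerationResult]:
        return generator.generate([PrefillRequest()], cache)

    print(len(run(Model(), cache=None)))  # mypy accepts Model purely by its matching signature

Annotating the protocol with Sequence rather than List also avoids an invariance pitfall: a concrete List[PrefillRequest] argument does not satisfy List[Union[PrefillRequest, DecodeRequest]], but it does satisfy the covariant Sequence[Union[...]].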
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 76e4451dc4..cb5683a5c9 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Union, Tuple +from typing import List, Union, Tuple, Sequence import structlog import numpy as np @@ -203,7 +203,7 @@ def profile_memory_usage(self, seq_lens): def generate( self, - requests: Union[List[PrefillRequest], List[DecodeRequest]], + requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache, ) -> List[TextGenerationResult]: if len(requests) == 0: @@ -490,5 +490,4 @@ def init_tvm_model( LOG.info("Allocated KV cache blocks.") - # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. - return model, cache_manager # type: ignore + return model, cache_manager From c49ef45f6dde01dfbd19e78b5d1925f06e8120b5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 20:50:09 +0000 Subject: [PATCH 05/73] stub --- serve/mlc_serve/model/torch_model.py | 256 +++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 serve/mlc_serve/model/torch_model.py diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py new file mode 100644 index 0000000000..521633ac15 --- /dev/null +++ b/serve/mlc_serve/model/torch_model.py @@ -0,0 +1,256 @@ +import math +import os +from typing import List, Union, Tuple + +import structlog +import numpy as np +import torch +import tvm + +from .base import ModelArtifactConfig +from .paged_cache_manager import KVCache, CacheManager +from .model_common import ( + sample, + prepare_inputs, + get_num_cache_blocks, +) + +from ..engine import ( + SequenceId, + PROMPT_SEQEUNCE_INDEX, + get_prompt_sequence_id, + MLCServeEngineConfig, +) +from ..engine.model_module import ( + DecodeRequest, + PrefillRequest, + TextGenerationResult, + TextGenerator, +) + +LOG = structlog.stdlib.get_logger(__name__) + + +class Model: + def __init__( + self, + config, + ): + self.vocab_size = config.vocab_size + self.sliding_window = config.sliding_window + self.num_shards = config.num_shards + + if self.sliding_window: + self.block_sliding_window = self.sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + def get_used_memory(self): + return 0 + + def profile_memory_usage(self, seq_lens): + return self.get_used_memory() + + def generate( + self, + requests: Union[List[PrefillRequest], List[DecodeRequest]], + cache: KVCache, + ) -> List[TextGenerationResult]: + if len(requests) == 0: + return [] + + is_prefill = isinstance(requests[0], PrefillRequest) + + all_token_ids = [] + sampling_params = [] + sequence_ids = [] + prompt_lens = [] + num_sequences = [] + + for request in requests: + if isinstance(request, PrefillRequest): + sequence_ids.append(get_prompt_sequence_id(request.request_id)) + num_sequences.append(request.num_sequence) + else: + sequence_ids.append(request.sequence_id) + prompt_lens.append(request.prompt_token_counts) + + all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + cache.slot_mappings, + cache.decode_block_tables, + self.sliding_window, + is_prefill, + ) + + input_shape = input_ids.shape + + if is_prefill: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") 
+ else: + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + logits = None + + try: + next_tokens = sample(logits, sampling_params, self.vocab_size) + assert next_tokens is not None + outputs = [] + for i, (sequence_id, new_token) in enumerate( + zip(sequence_ids, next_tokens) + ): + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + except RuntimeError: + # Fallback to per-token sampling in case some logits values are corrupted. + outputs = [] + err_msg = ( + "Error from sampling: probability tensor contains either `inf`, `nan`" + " or element < 0" + ) + + for i, (sequence_id, logits_per_token, sampling_param) in enumerate( + zip(sequence_ids, torch.from_dlpack(logits), sampling_params) + ): + maybe_new_token = sample( + torch.unsqueeze(logits_per_token, 0), + [sampling_param], + self.vocab_size, + check_safety=True, + ) + + if maybe_new_token is not None: + new_token = maybe_new_token[0] + if ( + not new_token + in requests[i].sampling_params.appeared_tokens_freq + ): + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[], + error=err_msg, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[], + error=err_msg, + ) + ) + + return outputs + + +def init_torch_model( + model_artifact_config: ModelArtifactConfig, engine_config: MLCServeEngineConfig +) -> Tuple[TextGenerator, CacheManager]: + model = Model(model_artifact_config) + + num_kv_heads = ( + model_artifact_config.num_key_value_heads // model_artifact_config.num_shards + ) + head_size = ( + model_artifact_config.hidden_size // model_artifact_config.num_attention_heads + ) + + if engine_config.max_num_batched_tokens > 0: + LOG.info("Running memory profiling.") + num_blocks = get_num_cache_blocks( + model, + [engine_config.max_input_len] * engine_config.max_num_sequences, + model_artifact_config.num_hidden_layers, + num_kv_heads, + head_size, + ) + else: + num_blocks = 500 + + num_cache_slots = num_blocks * CacheManager.block_size + + if num_cache_slots <= engine_config.max_num_batched_tokens: + raise RuntimeError( + f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" + f" only {num_blocks} cache blocks can be allocated. 
The number of" + f" available cache slots is {num_cache_slots}, not enough for" + f" {engine_config.max_num_batched_tokens} tokens. Try reducing" + " --max_input_len or --max_num_sequences." + ) + + LOG.info(f"Using {num_blocks} cache blocks.") + + cache_blocks = None + + cache_manager = CacheManager( + cache_blocks, + num_blocks, + model_artifact_config.sliding_window, + ) + + LOG.info("Allocated KV cache blocks.") + + # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. + return model, cache_manager # type: ignore From d9ac72f4a564f357cbb480db43b2a2e097c07415 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 22:16:08 +0000 Subject: [PATCH 06/73] wip --- serve/mlc_serve/engine/base.py | 1 + serve/mlc_serve/model/paged_cache_model.py | 14 +++++--- serve/mlc_serve/model/torch_model.py | 38 +++++++++++++++++----- 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index bf678268fe..9fed40d046 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -30,6 +30,7 @@ class MLCServeEngineConfig: min_decode_steps: int = 32 max_decode_steps: int = 48 init_timeout: int = 120 + model_type: str = "tvm" # "tvm", "torch" @classmethod def _from_json(config_cls, json_obj: Dict[Any, Any]): diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index a7f11750eb..88541ef454 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -6,6 +6,7 @@ from .paged_cache_manager import CacheManager from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer from .tvm_model import init_tvm_model +from .torch_model import init_torch_model from ..engine import MLCServeEngineConfig from ..engine.model_module import ( @@ -50,10 +51,15 @@ def __init__( model_artifact_path: Path, engine_config: MLCServeEngineConfig, ): - model_artifact_config = get_model_artifact_config(model_artifact_path) - - # TODO(masahi): Make the model type configurable. 
- model, cache_manager = init_tvm_model(model_artifact_config, engine_config) + if engine_config.model_type == "tvm": + model_artifact_config = get_model_artifact_config(model_artifact_path) + model, cache_manager = init_tvm_model(model_artifact_config, engine_config) + elif engine_config.model_type == "torch": + model, cache_manager, model_artifact_config = init_torch_model( + model_artifact_path, engine_config + ) + else: + raise RuntimeError(f"Unknown model type {engine_config.model_type}") self.engine_config = engine_config self.model_artifact_config = model_artifact_config diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 521633ac15..c32dc6ccee 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -5,7 +5,8 @@ import structlog import numpy as np import torch -import tvm + +from vllm.model_executor.models.llama import LlamaForCausalLM from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager @@ -36,6 +37,8 @@ def __init__( self, config, ): + self.pt_model = LlamaForCausalLM(config) + self.vocab_size = config.vocab_size self.sliding_window = config.sliding_window self.num_shards = config.num_shards @@ -206,15 +209,20 @@ def generate( def init_torch_model( - model_artifact_config: ModelArtifactConfig, engine_config: MLCServeEngineConfig -) -> Tuple[TextGenerator, CacheManager]: - model = Model(model_artifact_config) + model_path, engine_config: MLCServeEngineConfig +) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: + from transformers import AutoConfig + hf_config = AutoConfig.from_pretrained(model_path) + + model = Model(hf_config) + # TODO + num_shards = 1 num_kv_heads = ( - model_artifact_config.num_key_value_heads // model_artifact_config.num_shards + hf_config.num_key_value_heads // num_shards ) head_size = ( - model_artifact_config.hidden_size // model_artifact_config.num_attention_heads + hf_config.hidden_size // hf_config.num_attention_heads ) if engine_config.max_num_batched_tokens > 0: @@ -222,7 +230,7 @@ def init_torch_model( num_blocks = get_num_cache_blocks( model, [engine_config.max_input_len] * engine_config.max_num_sequences, - model_artifact_config.num_hidden_layers, + hf_config.num_hidden_layers, num_kv_heads, head_size, ) @@ -247,10 +255,22 @@ def init_torch_model( cache_manager = CacheManager( cache_blocks, num_blocks, - model_artifact_config.sliding_window, + hf_config.sliding_window, ) LOG.info("Allocated KV cache blocks.") + artifact_config = ModelArtifactConfig( + model_artifact_path=model_path, + num_shards=1, + quantization=None, + max_context_length=hf_config.max_position_embeddings, # TODO, + vocab_size=hf_config.vocab_size, + sliding_window=hf_config.sliding_window, + num_key_value_heads=hf_config.num_key_value_heads, + num_attention_heads=hf_config.nnum_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ) # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
- return model, cache_manager # type: ignore + return model, cache_manager, artifact_config # type: ignore From acbf825ad81230535ecb18a1954226f9a9797495 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 23:50:45 +0000 Subject: [PATCH 07/73] wip --- serve/mlc_serve/model/paged_cache_model.py | 3 +- serve/mlc_serve/model/tokenizer.py | 5 +- serve/mlc_serve/model/torch_model.py | 78 +++++++++++++++++----- serve/tests/test_engine.py | 7 ++ 4 files changed, 73 insertions(+), 20 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 88541ef454..df5dbf6256 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -54,10 +54,12 @@ def __init__( if engine_config.model_type == "tvm": model_artifact_config = get_model_artifact_config(model_artifact_path) model, cache_manager = init_tvm_model(model_artifact_config, engine_config) + tokenizer_module = HfTokenizerModule(model_artifact_path.joinpath("model")) elif engine_config.model_type == "torch": model, cache_manager, model_artifact_config = init_torch_model( model_artifact_path, engine_config ) + tokenizer_module = HfTokenizerModule(model_artifact_path) else: raise RuntimeError(f"Unknown model type {engine_config.model_type}") @@ -66,7 +68,6 @@ def __init__( self.text_generator = PagedCacheModelTextGenerator(model) self.cache_manager = cache_manager - tokenizer_module = HfTokenizerModule(model_artifact_path) self.tokenizer = tokenizer_module.tokenizer self.conversation_template = tokenizer_module.conversation_template diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index 23e0c044a5..ca7dce48e1 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -46,10 +46,9 @@ def apply(self, messages: list[ChatMessage]) -> str: class HfTokenizerModule: - def __init__(self, model_artifact_path: Path): + def __init__(self, tokenizer_path: Path): hf_tokenizer = AutoTokenizer.from_pretrained( - model_artifact_path.joinpath("model"), - trust_remote_code=False, + tokenizer_path, trust_remote_code=False ) self.tokenizer = Tokenizer(hf_tokenizer) self.conversation_template = ConversationTemplate(hf_tokenizer) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index c32dc6ccee..0524fbf905 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -35,10 +35,10 @@ class Model: def __init__( self, + pt_model, config, ): - self.pt_model = LlamaForCausalLM(config) - + self.pt_model = pt_model self.vocab_size = config.vocab_size self.sliding_window = config.sliding_window self.num_shards = config.num_shards @@ -208,13 +208,46 @@ def generate( return outputs +def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): + element_size = 2 + x = 16 // element_size + + key_block_shape = (num_heads, head_size // x, block_size, x) + value_block_shape = (num_heads, head_size, block_size) + + gpu_cache = [] + for _ in range(num_layers): + key_blocks = torch.empty( + size=(num_gpu_blocks, *key_block_shape), + dtype=torch.float16, + device="cuda", + ) + value_blocks = torch.empty( + size=(num_gpu_blocks, *value_block_shape), + dtype=torch.float16, + device="cuda", + ) + gpu_cache.append((key_blocks, value_blocks)) + return gpu_cache + + def init_torch_model( model_path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: from transformers 
import AutoConfig + from vllm.model_executor.utils import set_random_seed + from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel, + ) + + torch.distributed.init_process_group( + backend="nccl", world_size=1, rank=0, init_method="tcp://localhost:59157" + ) + initialize_model_parallel(1, 1) + + print("model_path", model_path) hf_config = AutoConfig.from_pretrained(model_path) - model = Model(hf_config) # TODO num_shards = 1 @@ -225,6 +258,28 @@ def init_torch_model( hf_config.hidden_size // hf_config.num_attention_heads ) + if not hasattr(hf_config, "sliding_window"): + hf_config.sliding_window = None + + hf_config.num_shards = num_shards + + artifact_config = ModelArtifactConfig( + model_artifact_path=model_path, + num_shards=1, + quantization=None, + max_context_length=hf_config.max_position_embeddings, # TODO, + vocab_size=hf_config.vocab_size, + sliding_window=hf_config.sliding_window, + num_key_value_heads=num_kv_heads, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ) + + pt_model = LlamaForCausalLM(hf_config) + pt_model.load_weights(model_path, None, "auto", None) + model = Model(pt_model, hf_config) + if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") num_blocks = get_num_cache_blocks( @@ -250,7 +305,10 @@ def init_torch_model( LOG.info(f"Using {num_blocks} cache blocks.") - cache_blocks = None + cache_blocks = init_cache_blocks( + head_size, hf_config.num_hidden_layers, + hf_config.num_attention_heads, CacheManager.block_size, num_blocks + ) cache_manager = CacheManager( cache_blocks, @@ -260,17 +318,5 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") - artifact_config = ModelArtifactConfig( - model_artifact_path=model_path, - num_shards=1, - quantization=None, - max_context_length=hf_config.max_position_embeddings, # TODO, - vocab_size=hf_config.vocab_size, - sliding_window=hf_config.sliding_window, - num_key_value_heads=hf_config.num_key_value_heads, - num_attention_heads=hf_config.nnum_attention_heads, - num_hidden_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ) # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
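For orientation, the float16 blocks allocated by init_cache_blocks above have a fixed byte cost per block, which is what bounds how many blocks fit in the GPU memory left over after weights and activations. A minimal standalone sketch of that arithmetic (the layer/head/block numbers below are illustrative assumptions, not values taken from this patch):

    def kv_cache_block_bytes(num_layers, num_kv_heads, head_size, block_size, elem_bytes=2):
        # Each block holds block_size token slots; every layer stores a key and a
        # value vector of num_kv_heads * head_size fp16 elements per slot.
        per_slot_per_layer = 2 * num_kv_heads * head_size * elem_bytes
        return num_layers * block_size * per_slot_per_layer

    # e.g. a Llama-7B-like configuration with 16-token blocks: 8 MiB per cache block.
    assert kv_cache_block_bytes(32, 32, 128, 16) == 8 * 1024 * 1024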
return model, cache_manager, artifact_config # type: ignore diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 4355fec6cd..1842a2044a 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -1,3 +1,4 @@ +import os import argparse import json import random @@ -17,6 +18,11 @@ def _test(args: argparse.Namespace): + model_type = "tvm" + + if not os.path.exists(args.model_artifact_path.joinpath("mlc-model-config.json")): + model_type = "torch" + engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, @@ -24,6 +30,7 @@ def _test(args: argparse.Namespace): "max_input_len": args.max_input_len, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, + "model_type": model_type, } ) From fef750fa1bc597b19f3ebfd5d654b2e07c7d586b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 00:18:10 +0000 Subject: [PATCH 08/73] wip --- serve/mlc_serve/model/torch_model.py | 150 ++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 0524fbf905..2031b27fe1 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -7,12 +7,15 @@ import torch from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.sequence import SequenceGroupMetadata, SequenceData +from vllm.model_executor import InputMetadata +from vllm.sampling_params import SamplingParams from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager from .model_common import ( sample, - prepare_inputs, + # prepare_inputs, get_num_cache_blocks, ) @@ -32,6 +35,107 @@ LOG = structlog.stdlib.get_logger(__name__) +def prepare_inputs( + seq_group_metadata_list: List[SequenceGroupMetadata], + block_size, +) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + + prompt_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + if not seq_group_metadata.is_prompt: + continue + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(len(prompt_tokens))) + + if seq_group_metadata.block_tables is None: + slot_mapping.extend([0] * prompt_len) + continue + + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(prompt_len): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + max_context_len = 0 + max_num_blocks_per_seq = 0 + context_lens: List[int] = [] + generation_block_tables: List[List[int]] = [] + for seq_group_metadata in seq_group_metadata_list: + if seq_group_metadata.is_prompt: + continue + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + 
+ context_len = seq_data.get_len() + position = context_len - 1 + input_positions.append(position) + + block_table = seq_group_metadata.block_tables[seq_id] + generation_block_tables.append(list(block_table)) + + max_context_len = max(max_context_len, context_len) + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + context_lens.append(context_len) + + block_number = block_table[position // block_size] + block_offset = position % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + tokens_tensor = torch.cuda.LongTensor(input_tokens) + positions_tensor = torch.cuda.LongTensor(input_positions) + slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) + context_lens_tensor = torch.cuda.IntTensor(context_lens) + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) + for block_table in generation_block_tables + ] + block_tables_tensor = torch.cuda.IntTensor(padded_block_tables) + + seq_data = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + slot_mapping=slot_mapping_tensor, + context_lens=context_lens_tensor, + max_context_len=max_context_len, + block_tables=block_tables_tensor, + ) + return tokens_tensor, positions_tensor, input_metadata + + class Model: def __init__( self, @@ -42,17 +146,51 @@ def __init__( self.vocab_size = config.vocab_size self.sliding_window = config.sliding_window self.num_shards = config.num_shards + self.num_hidden_layers = config.num_hidden_layers if self.sliding_window: self.block_sliding_window = self.sliding_window // CacheManager.block_size else: self.block_sliding_window = None - def get_used_memory(self): - return 0 - def profile_memory_usage(self, seq_lens): - return self.get_used_memory() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + seqs: List[SequenceGroupMetadata] = [] + sampling_params = SamplingParams(top_p=0.99) + + for i, seq_len in enumerate(seq_lens): + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + request_id=str(i), + is_prompt=True, + seq_data={i: seq_data}, + sampling_params=sampling_params, + block_tables=None, + ) + seqs.append(seq) + + input_ids, positions, input_metadata = prepare_inputs( + seqs, CacheManager.block_size, + ) + + kv_caches = [(None, None)] * self.num_hidden_layers + + with torch.no_grad(): + self.pt_model.forward( + input_ids, positions, kv_caches, input_metadata, + cache_events=None, + ) + + torch.cuda.synchronize() + + peak_memory = torch.cuda.max_memory_allocated() + print("peak memory", peak_memory / 1e9) + + torch.cuda.empty_cache() + + return peak_memory def generate( self, @@ -318,5 +456,7 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") + assert False + # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
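The slot_mapping built in prepare_inputs above is the core bookkeeping of the paged KV cache: a token's logical position within its sequence is translated into a physical cache slot through that sequence's block table. A tiny self-contained illustration (the block-table contents are made up for the example):

    def slot_for_position(block_table, position, block_size=16):
        # physical slot = physical block id * block_size + offset inside the block
        block_number = block_table[position // block_size]
        block_offset = position % block_size
        return block_number * block_size + block_offset

    # A sequence whose logical blocks 0..2 live in physical blocks 7, 3 and 9:
    assert slot_for_position([7, 3, 9], 0) == 112   # 7 * 16 + 0
    assert slot_for_position([7, 3, 9], 17) == 49   # 3 * 16 + 1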
return model, cache_manager, artifact_config # type: ignore From 25a567ed345a2972130f92be351140e2f092eb7e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 00:51:06 +0000 Subject: [PATCH 09/73] PT model memory profiling works --- serve/mlc_serve/model/paged_cache_model.py | 2 +- serve/mlc_serve/model/torch_model.py | 44 +++++++++++----------- serve/mlc_serve/model/tvm_model.py | 2 + serve/tests/test_engine.py | 2 +- 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index df5dbf6256..68270ef2c5 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -5,8 +5,8 @@ from .base import get_model_artifact_config from .paged_cache_manager import CacheManager from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer -from .tvm_model import init_tvm_model from .torch_model import init_torch_model +from .tvm_model import init_tvm_model from ..engine import MLCServeEngineConfig from ..engine.model_module import ( diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 2031b27fe1..2b57dbfd3a 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -6,10 +6,14 @@ import numpy as np import torch +from transformers import AutoConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.sequence import SequenceGroupMetadata, SequenceData from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams +from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel, +) from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager @@ -172,14 +176,18 @@ def profile_memory_usage(self, seq_lens): seqs.append(seq) input_ids, positions, input_metadata = prepare_inputs( - seqs, CacheManager.block_size, + seqs, + CacheManager.block_size, ) kv_caches = [(None, None)] * self.num_hidden_layers with torch.no_grad(): self.pt_model.forward( - input_ids, positions, kv_caches, input_metadata, + input_ids, + positions, + kv_caches, + input_metadata, cache_events=None, ) @@ -372,12 +380,6 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc def init_torch_model( model_path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: - from transformers import AutoConfig - from vllm.model_executor.utils import set_random_seed - from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel, - ) - torch.distributed.init_process_group( backend="nccl", world_size=1, rank=0, init_method="tcp://localhost:59157" ) @@ -389,12 +391,8 @@ def init_torch_model( # TODO num_shards = 1 - num_kv_heads = ( - hf_config.num_key_value_heads // num_shards - ) - head_size = ( - hf_config.hidden_size // hf_config.num_attention_heads - ) + num_kv_heads = hf_config.num_key_value_heads // num_shards + head_size = hf_config.hidden_size // hf_config.num_attention_heads if not hasattr(hf_config, "sliding_window"): hf_config.sliding_window = None @@ -405,7 +403,7 @@ def init_torch_model( model_artifact_path=model_path, num_shards=1, quantization=None, - max_context_length=hf_config.max_position_embeddings, # TODO, + max_context_length=hf_config.max_position_embeddings, # TODO, vocab_size=hf_config.vocab_size, sliding_window=hf_config.sliding_window, num_key_value_heads=num_kv_heads, @@ 
-414,8 +412,11 @@ def init_torch_model( hidden_size=hf_config.hidden_size, ) - pt_model = LlamaForCausalLM(hf_config) - pt_model.load_weights(model_path, None, "auto", None) + with torch.device("cuda"): + torch.set_default_dtype(torch.float16) + pt_model = LlamaForCausalLM(hf_config) + pt_model.load_weights(model_path, None, "auto", None) + model = Model(pt_model, hf_config) if engine_config.max_num_batched_tokens > 0: @@ -444,8 +445,11 @@ def init_torch_model( LOG.info(f"Using {num_blocks} cache blocks.") cache_blocks = init_cache_blocks( - head_size, hf_config.num_hidden_layers, - hf_config.num_attention_heads, CacheManager.block_size, num_blocks + head_size, + hf_config.num_hidden_layers, + hf_config.num_attention_heads, + CacheManager.block_size, + num_blocks, ) cache_manager = CacheManager( @@ -456,7 +460,5 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") - assert False - # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. return model, cache_manager, artifact_config # type: ignore diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cb5683a5c9..825e06e2b4 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -179,6 +179,8 @@ def get_used_memory(self): math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params ) + print("peak memory:", (peak_memory + param_bytes) / 1e9) + return peak_memory + param_bytes def profile_memory_usage(self, seq_lens): diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 1842a2044a..4f05db3e22 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -20,7 +20,7 @@ def _test(args: argparse.Namespace): model_type = "tvm" - if not os.path.exists(args.model_artifact_path.joinpath("mlc-model-config.json")): + if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): model_type = "torch" engine_config = get_engine_config( From 3d06f682637587fa4ff86c946132d773b8711b69 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 01:32:11 +0000 Subject: [PATCH 10/73] get rid of vllm prepare_inputs --- serve/mlc_serve/model/torch_model.py | 262 ++++++--------------------- 1 file changed, 52 insertions(+), 210 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 2b57dbfd3a..3bf6dd3af1 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,9 +1,6 @@ -import math -import os from typing import List, Union, Tuple import structlog -import numpy as np import torch from transformers import AutoConfig @@ -18,8 +15,7 @@ from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager from .model_common import ( - sample, - # prepare_inputs, + prepare_inputs, get_num_cache_blocks, ) @@ -39,107 +35,6 @@ LOG = structlog.stdlib.get_logger(__name__) -def prepare_inputs( - seq_group_metadata_list: List[SequenceGroupMetadata], - block_size, -) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - - prompt_lens: List[int] = [] - for seq_group_metadata in seq_group_metadata_list: - if not seq_group_metadata.is_prompt: - continue - - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - seq_id = 
seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) - - input_tokens.extend(prompt_tokens) - input_positions.extend(range(len(prompt_tokens))) - - if seq_group_metadata.block_tables is None: - slot_mapping.extend([0] * prompt_len) - continue - - block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - max_context_len = 0 - max_num_blocks_per_seq = 0 - context_lens: List[int] = [] - generation_block_tables: List[List[int]] = [] - for seq_group_metadata in seq_group_metadata_list: - if seq_group_metadata.is_prompt: - continue - - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) - - context_len = seq_data.get_len() - position = context_len - 1 - input_positions.append(position) - - block_table = seq_group_metadata.block_tables[seq_id] - generation_block_tables.append(list(block_table)) - - max_context_len = max(max_context_len, context_len) - max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) - context_lens.append(context_len) - - block_number = block_table[position // block_size] - block_offset = position % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) - - tokens_tensor = torch.cuda.LongTensor(input_tokens) - positions_tensor = torch.cuda.LongTensor(input_positions) - slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) - context_lens_tensor = torch.cuda.IntTensor(context_lens) - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) - for block_table in generation_block_tables - ] - block_tables_tensor = torch.cuda.IntTensor(padded_block_tables) - - seq_data = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - slot_mapping=slot_mapping_tensor, - context_lens=context_lens_tensor, - max_context_len=max_context_len, - block_tables=block_tables_tensor, - ) - return tokens_tensor, positions_tensor, input_metadata - - class Model: def __init__( self, @@ -161,23 +56,35 @@ def profile_memory_usage(self, seq_lens): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - seqs: List[SequenceGroupMetadata] = [] sampling_params = SamplingParams(top_p=0.99) - for i, seq_len in enumerate(seq_lens): - seq_data = SequenceData([0] * seq_len) - seq = SequenceGroupMetadata( - request_id=str(i), - is_prompt=True, - seq_data={i: seq_data}, - sampling_params=sampling_params, - block_tables=None, - ) - seqs.append(seq) + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + seq_data = {} + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] - input_ids, positions, input_metadata = prepare_inputs( - seqs, - CacheManager.block_size, + for i, seq_len in enumerate(seq_lens): + seq_groups.append(([i], sampling_params)) + prompt_tokens = [0] * seq_len + 
seq_data[i] = SequenceData(prompt_tokens) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(seq_len)) + slot_mapping.extend([0] * seq_len) + + input_ids = torch.cuda.LongTensor(input_tokens) + positions = torch.cuda.LongTensor(input_positions) + slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=seq_lens, + slot_mapping=slot_mapping_tensor, + context_lens=torch.cuda.IntTensor([]), + max_context_len=0, + block_tables=torch.cuda.IntTensor([]), ) kv_caches = [(None, None)] * self.num_hidden_layers @@ -251,107 +158,39 @@ def generate( else: torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + # TODO(masahi): Do sampling outside of model + next_tokens = None + torch.cuda.synchronize() torch.cuda.nvtx.range_pop() - logits = None - - try: - next_tokens = sample(logits, sampling_params, self.vocab_size) - assert next_tokens is not None - outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) - ): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: + assert next_tokens is not None + + outputs = [] + + for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): outputs.append( TextGenerationResult( - sequence_id=sequence_id, + sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, ) ) - - return outputs - except RuntimeError: - # Fallback to per-token sampling in case some logits values are corrupted. 
- outputs = [] - err_msg = ( - "Error from sampling: probability tensor contains either `inf`, `nan`" - " or element < 0" - ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token = sample( - torch.unsqueeze(logits_per_token, 0), - [sampling_param], - self.vocab_size, - check_safety=True, + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) ) - if maybe_new_token is not None: - new_token = maybe_new_token[0] - if ( - not new_token - in requests[i].sampling_params.appeared_tokens_freq - ): - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[], - error=err_msg, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - ) - ) - - return outputs + return outputs def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): @@ -460,5 +299,8 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") + import ipdb + + ipdb.set_trace() # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
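The profile_memory_usage method introduced earlier in this series follows a standard CUDA profiling idiom: empty the allocator cache, reset the peak-memory counter, run one worst-case prefill forward pass, then read back the high-water mark. Reduced to its bare shape (forward_fn is a placeholder callable, not an API from this repository):

    import torch

    def measure_peak_gpu_bytes(forward_fn, *args, **kwargs):
        torch.cuda.empty_cache()                  # drop cached allocations from earlier runs
        torch.cuda.reset_peak_memory_stats()      # restart the high-water mark from "now"
        with torch.no_grad():
            forward_fn(*args, **kwargs)
        torch.cuda.synchronize()                  # ensure the kernels have actually executed
        return torch.cuda.max_memory_allocated()  # peak bytes reached during the pass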
return model, cache_manager, artifact_config # type: ignore From 3cafc8bbb7dc303c1bc0a47d117bf03e21dd9e1c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 02:38:54 +0000 Subject: [PATCH 11/73] wip --- serve/mlc_serve/model/torch_model.py | 53 ++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 3bf6dd3af1..ce29907be0 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -5,7 +5,7 @@ from transformers import AutoConfig from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.sequence import SequenceGroupMetadata, SequenceData +from vllm.sequence import SequenceData from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams from vllm.model_executor.parallel_utils.parallel_state import ( @@ -24,6 +24,7 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, + SamplingParams as MLCSamplingParams, ) from ..engine.model_module import ( DecodeRequest, @@ -35,6 +36,16 @@ LOG = structlog.stdlib.get_logger(__name__) +def convert_sampling_params(mlc_params: MLCSamplingParams) -> SamplingParams: + return SamplingParams( + presence_penalty=mlc_params.presence_penalty, + frequency_penalty=mlc_params.frequency_penalty, + temperature=mlc_params.temperature, + top_p=mlc_params.top_p, + top_k=mlc_params.top_k, + ) + + class Model: def __init__( self, @@ -122,6 +133,8 @@ def generate( sequence_ids = [] prompt_lens = [] num_sequences = [] + seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] + seq_data = {} for request in requests: if isinstance(request, PrefillRequest): @@ -134,12 +147,17 @@ def generate( all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) + seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) + seq_groups.append( + ([sequence_ids[-1]], convert_sampling_params(request.sampling_params)) + ) + ( input_ids, positions, seq_lens, slot_mapping, - indices_within_window, + _, block_tables, ) = prepare_inputs( sequence_ids, @@ -151,6 +169,24 @@ def generate( is_prefill, ) + if block_tables is None: + block_tables = torch.cuda.IntTensor([]) + context_lens = torch.cuda.IntTensor([]) + max_context_len = 0 + else: + context_lens = seq_lens, + max_context_len = torch.max(seq_lens) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + slot_mapping=slot_mapping, + context_lens=context_lens, + max_context_len=max_context_len, + block_tables=block_tables, + ) + input_shape = input_ids.shape if is_prefill: @@ -159,7 +195,15 @@ def generate( torch.cuda.nvtx.range_push(f"forward decode {input_shape}") # TODO(masahi): Do sampling outside of model - next_tokens = None + with torch.no_grad(): + next_tokens = self.pt_model.forward( + input_ids, + positions, + cache.cache_blocks, + input_metadata, + cache_events=None, + ) + print(next_tokens) torch.cuda.synchronize() torch.cuda.nvtx.range_pop() @@ -299,8 +343,5 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") - import ipdb - - ipdb.set_trace() # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
return model, cache_manager, artifact_config # type: ignore From 34f77ef8a8969a4343b48aa8c3272c6576462036 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 02:50:04 +0000 Subject: [PATCH 12/73] model runs but nan output --- serve/mlc_serve/model/model_common.py | 5 +++-- serve/mlc_serve/model/torch_model.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index e72dd0dbfe..215345c218 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -174,6 +174,7 @@ def prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, + torch_ids_type=torch.int ): block_tables = [] seq_lens = [] @@ -216,8 +217,8 @@ def prepare_inputs( def to_torch(arr, torch_dtype): return torch.tensor(arr, dtype=torch_dtype, device="cuda") - input_ids = to_torch(input_ids, torch.int) - positions = to_torch(positions, torch.int) + input_ids = to_torch(input_ids, torch_ids_type) + positions = to_torch(positions, torch_ids_type) seq_lens = to_torch(seq_lens, torch.int) slot_mapping = to_torch(slot_mapping, torch.int) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index ce29907be0..08a6782e0a 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -167,6 +167,7 @@ def generate( cache.decode_block_tables, self.sliding_window, is_prefill, + torch.long, ) if block_tables is None: From afb4d4f830ca0445c88100e6675b1c5d55e62b0d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 06:21:22 +0000 Subject: [PATCH 13/73] mypy improvement --- serve/mlc_serve/model/torch_model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 08a6782e0a..2b6bd1d2ce 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,4 +1,4 @@ -from typing import List, Union, Tuple +from typing import List, Union, Tuple, Sequence import structlog import torch @@ -120,7 +120,7 @@ def profile_memory_usage(self, seq_lens): def generate( self, - requests: Union[List[PrefillRequest], List[DecodeRequest]], + requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache, ) -> List[TextGenerationResult]: if len(requests) == 0: @@ -344,5 +344,4 @@ def init_torch_model( LOG.info("Allocated KV cache blocks.") - # TODO(masahi): Make mypy understand that model confirms to TextGenerator Protocol. 
- return model, cache_manager, artifact_config # type: ignore + return model, cache_manager, artifact_config From e7212a5b24a9edb61117de58db0584fc55db08f2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 10:08:28 +0000 Subject: [PATCH 14/73] runs e2e but the result is garbage --- serve/mlc_serve/model/torch_model.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 2b6bd1d2ce..91a6ac7c96 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -140,6 +140,7 @@ def generate( if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) num_sequences.append(request.num_sequence) + prompt_lens.append(len(request.token_ids)) else: sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) @@ -170,13 +171,18 @@ def generate( torch.long, ) + input_shape = input_ids.shape + if block_tables is None: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") block_tables = torch.cuda.IntTensor([]) context_lens = torch.cuda.IntTensor([]) max_context_len = 0 else: - context_lens = seq_lens, + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + context_lens = seq_lens max_context_len = torch.max(seq_lens) + prompt_lens = [] input_metadata = InputMetadata( seq_groups=seq_groups, @@ -188,23 +194,19 @@ def generate( block_tables=block_tables, ) - input_shape = input_ids.shape - - if is_prefill: - torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") - else: - torch.cuda.nvtx.range_push(f"forward decode {input_shape}") - # TODO(masahi): Do sampling outside of model with torch.no_grad(): - next_tokens = self.pt_model.forward( + outs = self.pt_model.forward( input_ids, positions, cache.cache_blocks, input_metadata, cache_events=None, ) - print(next_tokens) + + next_tokens = [] + for samples in outs: + next_tokens.append(samples[0].output_token) torch.cuda.synchronize() torch.cuda.nvtx.range_pop() From f27e3b348f6c83231726dbc42142fc7c81a3bd2e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 20:37:59 +0000 Subject: [PATCH 15/73] working --- serve/mlc_serve/model/model_common.py | 11 ++++++++++- serve/mlc_serve/model/torch_model.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 215345c218..c65c0da16e 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -166,6 +166,10 @@ def _is_safe_to_sample(prob_like): return res +def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]: + return x + [0] * ((-len(x)) % multiple_of) + + def prepare_inputs( sequence_ids, all_token_ids, @@ -174,7 +178,8 @@ def prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, - torch_ids_type=torch.int + torch_ids_type=torch.int, + align=None, ): block_tables = [] seq_lens = [] @@ -217,6 +222,10 @@ def prepare_inputs( def to_torch(arr, torch_dtype): return torch.tensor(arr, dtype=torch_dtype, device="cuda") + if align: + input_ids = _pad_to_alignment(input_ids, multiple_of=align) + positions = _pad_to_alignment(positions, multiple_of=align) + input_ids = to_torch(input_ids, torch_ids_type) positions = to_torch(positions, torch_ids_type) seq_lens = to_torch(seq_lens, torch.int) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 
91a6ac7c96..b8bbf1e0cc 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -169,6 +169,7 @@ def generate( self.sliding_window, is_prefill, torch.long, + align=8, ) input_shape = input_ids.shape From 2316e37b9bff9db0adce008d5c02c2cf31c187a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 20:56:54 +0000 Subject: [PATCH 16/73] minor --- serve/mlc_serve/model/torch_model.py | 2 +- serve/mlc_serve/model/tvm_model.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index b8bbf1e0cc..2518f8dc7a 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -202,7 +202,7 @@ def generate( positions, cache.cache_blocks, input_metadata, - cache_events=None, + cache_events=None, # TODO: what to do about this? ) next_tokens = [] diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 825e06e2b4..cb5683a5c9 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -179,8 +179,6 @@ def get_used_memory(self): math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params ) - print("peak memory:", (peak_memory + param_bytes) / 1e9) - return peak_memory + param_bytes def profile_memory_usage(self, seq_lens): From 9b985e8c330ad3f943777b9089e3a6404e210da3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 12 Jan 2024 23:56:08 +0000 Subject: [PATCH 17/73] do sampling by mlc function --- serve/mlc_serve/model/torch_model.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 2518f8dc7a..f08477bda9 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -4,6 +4,8 @@ import torch from transformers import AutoConfig + +from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.sequence import SequenceData from vllm.model_executor import InputMetadata @@ -15,6 +17,7 @@ from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager from .model_common import ( + sample, prepare_inputs, get_num_cache_blocks, ) @@ -195,24 +198,27 @@ def generate( block_tables=block_tables, ) - # TODO(masahi): Do sampling outside of model with torch.no_grad(): - outs = self.pt_model.forward( + hidden_states = self.pt_model.model( input_ids, positions, cache.cache_blocks, input_metadata, - cache_events=None, # TODO: what to do about this? + # No need for this until parallel sampling is supported. 
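The align=8 argument added above goes through _pad_to_alignment, which pads the flattened token and position lists to a multiple of eight; the motivation (keeping prefill shapes friendly to the fused fp16 kernels) is an assumption here, but the padding arithmetic itself is easy to verify in isolation:

    def pad_to_alignment(x, multiple_of):
        # (-len(x)) % multiple_of is the smallest k >= 0 with (len(x) + k) % multiple_of == 0
        return x + [0] * ((-len(x)) % multiple_of)

    assert len(pad_to_alignment(list(range(13)), 8)) == 16
    assert pad_to_alignment([1, 2, 3], 4) == [1, 2, 3, 0]
    assert pad_to_alignment([], 8) == []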
+ cache_events=None, ) - next_tokens = [] - for samples in outs: - next_tokens.append(samples[0].output_token) + logits = get_logits( + self.pt_model.lm_head.weight, + hidden_states, + input_metadata, + self.vocab_size, + ) - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() - assert next_tokens is not None + next_tokens = sample(logits, sampling_params, self.vocab_size) outputs = [] From 4d73e638b8f1a685b805d5073e183157440e27a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 08:18:51 +0000 Subject: [PATCH 18/73] merge fix --- serve/mlc_serve/model/model_common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index c65c0da16e..8f1f716d7d 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -71,6 +71,7 @@ def _is_safe_to_sample(prob_like): == 0 ) + torch.cuda.nvtx.range_push(f"sample {logits.shape}") logits = torch.from_dlpack(logits) num_seq = len(sampling_params) @@ -86,6 +87,7 @@ def _is_safe_to_sample(prob_like): res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy() if logits_greedy.shape[0] == num_seq: + torch.cuda.nvtx.range_pop() return res_greedy temperatures = [] @@ -150,11 +152,13 @@ def _is_safe_to_sample(prob_like): probs = torch.softmax(logits_random, dim=-1) if check_safety and not _is_safe_to_sample(probs): + torch.cuda.nvtx.range_pop() return None res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] if logits_random.shape[0] == num_seq: + torch.cuda.nvtx.range_pop() return res_random res = np.empty((num_seq,), dtype=np.int32) @@ -163,6 +167,7 @@ def _is_safe_to_sample(prob_like): if logits_greedy.shape[0] > 0: res[mask_greedy] = res_greedy + torch.cuda.nvtx.range_pop() return res From 15a0d3b9a468b9309e587ecc2effd4c804c34f34 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 09:05:05 +0000 Subject: [PATCH 19/73] wip parallel sampling --- serve/mlc_serve/model/torch_model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index f08477bda9..7bd9c0b6a2 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,4 +1,5 @@ from typing import List, Union, Tuple, Sequence +from collections import defaultdict import structlog import torch @@ -136,25 +137,33 @@ def generate( sequence_ids = [] prompt_lens = [] num_sequences = [] - seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] seq_data = {} + seq_group_sequence_ids = defaultdict(list) + seq_group_sampling_params = {} for request in requests: if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) + seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) + seq_group_sampling_params[request.request_id] = request.sampling_params else: sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) + req_id = request.sequence_id.request_id + seq_group_sequence_ids[req_id].append(request.sequence_id) + seq_group_sampling_params[req_id] = request.sampling_params all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) - seq_groups.append( - ([sequence_ids[-1]], 
convert_sampling_params(request.sampling_params)) - ) + + seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] + + for req_id, seq_ids in seq_group_sequence_ids.items(): + seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) ( input_ids, @@ -218,6 +227,8 @@ def generate( torch.cuda.synchronize() torch.cuda.nvtx.range_pop() + print("logits.shape", logits.shape) + next_tokens = sample(logits, sampling_params, self.vocab_size) outputs = [] From 959019da01d1af42272c5d0808d950a4ae523fb2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 09:10:35 +0000 Subject: [PATCH 20/73] fix test --- serve/tests/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 4f05db3e22..f23e5f04aa 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -36,7 +36,7 @@ def _test(args: argparse.Namespace): if args.use_staging_engine: engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(args.model_artifact_path), + tokenizer_module=HfTokenizerModule(args.model_artifact_path.joinpath("model")), model_module_loader=PagedCacheModelModule, model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, From b6050d9cc69ce0962df4b01dbb0903bbcd09dac7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Jan 2024 08:05:13 +0000 Subject: [PATCH 21/73] wip --- serve/mlc_serve/model/torch_model.py | 36 ++++++++++++++++++++-------- serve/mlc_serve/model/tvm_model.py | 10 ++++++++ serve/tests/test_engine.py | 8 +++---- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 7bd9c0b6a2..361c3e5e49 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -147,13 +147,13 @@ def generate( num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) - seq_group_sampling_params[request.request_id] = request.sampling_params + seq_group_sampling_params[request.request_id] = convert_sampling_params(request.sampling_params) else: sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) req_id = request.sequence_id.request_id seq_group_sequence_ids[req_id].append(request.sequence_id) - seq_group_sampling_params[req_id] = request.sampling_params + seq_group_sampling_params[req_id] = convert_sampling_params(request.sampling_params) all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) @@ -208,6 +208,18 @@ def generate( ) with torch.no_grad(): + # outs = self.pt_model.forward( + # input_ids, + # positions, + # cache.cache_blocks, + # input_metadata, + # cache_events=None, # TODO: what to do about this? 
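The torch.cuda.nvtx.range_push / range_pop pairs added around sample() above mark named spans that appear in Nsight Systems timelines; pushes and pops must stay balanced on every early-return path, which is why each return in sample() now pops the range first. The same idiom in a minimal, self-contained form (the traced function is just an example):

    import torch

    def traced_softmax(x):
        torch.cuda.nvtx.range_push("softmax")  # opens a named span on the profiler timeline
        try:
            return torch.softmax(x, dim=-1)
        finally:
            torch.cuda.nvtx.range_pop()        # closes the span on every exit path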
+ # ) + + # next_tokens = [] + # for samples in outs: + # next_tokens.append(samples[0].output_token) + hidden_states = self.pt_model.model( input_ids, positions, @@ -217,20 +229,24 @@ def generate( cache_events=None, ) - logits = get_logits( - self.pt_model.lm_head.weight, - hidden_states, - input_metadata, - self.vocab_size, - ) + if hidden_states.shape[0] != len( + input_metadata.prompt_lens + ) and hidden_states.shape[0] != len(input_metadata.context_lens): + logits = get_logits( + self.pt_model.lm_head.weight, + hidden_states, + input_metadata, + self.vocab_size, + ) + + next_tokens = sample(logits, sampling_params, self.vocab_size) torch.cuda.synchronize() torch.cuda.nvtx.range_pop() print("logits.shape", logits.shape) - next_tokens = sample(logits, sampling_params, self.vocab_size) - + print("next tokens", next_tokens) outputs = [] for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cb5683a5c9..d90a5b695b 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -245,6 +245,10 @@ def generate( is_prefill, ) + print("input_ids", input_ids) + print("positions", positions) + print("block_tables", block_tables) + input_shape = input_ids.shape if self.disco_session: @@ -328,7 +332,13 @@ def generate( cache.pending_copy_from_to = [] try: + # if is_prefill: + # import numpy as np + # print("save logits") + # np.save("logits_tvm.npy", logits.numpy()) + next_tokens = sample(logits, sampling_params, self.vocab_size) + print("next tokens", next_tokens) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index f23e5f04aa..6cbc3ff039 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -66,7 +66,7 @@ def _test(args: argparse.Namespace): sampling_params_choices = [sampling_params_random] elif args.use_random_sampling: # This tests different sampling types in the same batch - sampling_params_choices = [sampling_params_random, sampling_params_greedy] + sampling_params_choices = [sampling_params_random] else: sampling_params_choices = [sampling_params_greedy] @@ -76,9 +76,9 @@ def _test(args: argparse.Namespace): else: prompts = [ "Hello, my name is", - "The capital of France is", - "The president of the United States is a powerful man. But he can also be", - "The future of AI is full of promise. But we need to carefully", + # "The capital of France is", + # "The president of the United States is a powerful man. But he can also be", + # "The future of AI is full of promise. 
But we need to carefully", ] for i, prompt in enumerate(prompts): From ff8eb279430323b00e4a416500affcef92e39a07 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Jan 2024 08:49:17 +0000 Subject: [PATCH 22/73] fix --- serve/tests/test_engine.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 6cbc3ff039..e3644c34c1 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -35,8 +35,13 @@ def _test(args: argparse.Namespace): ) if args.use_staging_engine: + if model_type == "tvm": + tokenizer_path = args.model_artifact_path.joinpath("model") + else: + tokenizer_path = args.model_artifact_path + engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(args.model_artifact_path.joinpath("model")), + tokenizer_module=HfTokenizerModule(tokenizer_path), model_module_loader=PagedCacheModelModule, model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, From 8696df5fee6b7801c6ac62dee0bacd29ac69ef29 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 16 Jan 2024 05:54:50 +0000 Subject: [PATCH 23/73] wip --- serve/mlc_serve/model/torch_model.py | 79 ++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 361c3e5e49..aaa77c073c 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -14,6 +14,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, ) +from vllm.engine.ray_utils import RayWorker, ray from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager @@ -297,6 +298,77 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc return gpu_cache +def run_workers( + workers, + method: str, + *args, + get_all_outputs: bool = False, + **kwargs, +): + """Runs the given method on all workers.""" + from functools import partial + + all_outputs = [] + for worker in workers: + executor = partial(worker.execute_method.remote, method) + + output = executor(*args, **kwargs) + all_outputs.append(output) + + all_outputs = ray.get(all_outputs) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + + return output + + +def init_workers_ray(placement_group: "PlacementGroup"): + import copy + from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + from ray.air.util.torch_dist import init_torch_dist_process_group + + workers: List[Worker] = [] + for bundle in placement_group.bundle_specs: + if not bundle.get("GPU", 0): + continue + worker = ray.remote( + num_cpus=0, + num_gpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True), + **ray_remote_kwargs, + )(RayWorker).remote(False) + + workers.append(worker) + + # Initialize torch distributed process group for the workers. 
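The tokenizer-path fix above keys off the same heuristic test_engine.py uses to pick model_type: a TVM artifact directory ships a build config JSON and keeps the HF tokenizer files under a model/ subdirectory, while the torch path points straight at an HF checkpoint. A compact sketch of that selection logic (the helper names are hypothetical):

    from pathlib import Path

    def detect_model_type(model_artifact_path: Path) -> str:
        # Mirrors the check in test_engine.py: only TVM artifacts carry build_config.json.
        return "tvm" if (model_artifact_path / "build_config.json").exists() else "torch"

    def tokenizer_dir(model_artifact_path: Path, model_type: str) -> Path:
        # TVM artifacts nest the original HF files under "model"; torch uses the path as-is.
        return model_artifact_path / "model" if model_type == "tvm" else model_artifact_path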
+ init_torch_dist_process_group(self.workers, backend="nccl") + model_config = copy.deepcopy(self.model_config) + parallel_config = copy.deepcopy(self.parallel_config) + scheduler_config = copy.deepcopy(self.scheduler_config) + run_workers("init_worker", + get_all_outputs=True, + worker_init_fn=lambda: Worker( + model_config, + parallel_config, + scheduler_config, + None, + None, + )) + run_workers( + "init_model", + get_all_outputs=True, + ) + + def init_torch_model( model_path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: @@ -311,6 +383,13 @@ def init_torch_model( # TODO num_shards = 1 + if num_shards > 1: + from vllm.config import ParallelConfig + from vllm.engine.ray_utils import initialize_cluster + parallel_config = ParallelConfig(1, num_shards, True) + _, placement_group = initialize_cluster(parallel_config) + init_workers_ray(placement_group) + num_kv_heads = hf_config.num_key_value_heads // num_shards head_size = hf_config.hidden_size // hf_config.num_attention_heads From 0af3a7030bf470d787cced61d2b2bb9368643096 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 16 Jan 2024 19:25:13 +0000 Subject: [PATCH 24/73] wip --- serve/mlc_serve/model/torch_model.py | 142 ++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 22 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index aaa77c073c..d762c43af2 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -6,6 +6,9 @@ from transformers import AutoConfig +import ray +from ray.air.util.torch_dist import TorchDistributedWorker + from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.sequence import SequenceData @@ -14,7 +17,6 @@ from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, ) -from vllm.engine.ray_utils import RayWorker, ray from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager @@ -298,6 +300,117 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc return gpu_cache +def _init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +class RayWorker(TorchDistributedWorker): + def __init__(self) -> None: + self.model = None + self.cache_manager = None + + def init_worker(self, hf_config, model_path, engine_config): + # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + # # Env vars will be set by Ray. 
+ # self.rank = self.rank if self.rank is not None else int( + # os.getenv("RANK", "-1")) + # local_rank = int(os.getenv("LOCAL_RANK", "0")) + # self.device = torch.device(f"cuda:{local_rank}") + # if self.rank < 0: + # raise ValueError("Invalid or unspecified rank.") + # torch.cuda.set_device(self.device) + + # # Initialize the distributed environment. + # _init_distributed_environment(self.parallel_config, self.rank, + # self.distributed_init_method) + + # # Initialize the model. + # set_random_seed(self.model_config.seed) + + with torch.device("cuda"): + torch.set_default_dtype(torch.float16) + pt_model = LlamaForCausalLM(hf_config) + pt_model.load_weights(model_path, None, "auto", None) + + self.model = Model(pt_model, hf_config) + + head_size = hf_config.hidden_size // hf_config.num_attention_heads + num_kv_heads = hf_config.num_key_value_heads // num_shards + + if engine_config.max_num_batched_tokens > 0: + LOG.info("Running memory profiling.") + num_blocks = get_num_cache_blocks( + self.model, + [engine_config.max_input_len] * engine_config.max_num_sequences, + hf_config.num_hidden_layers, + num_kv_heads, + head_size, + ) + else: + num_blocks = 500 + + num_cache_slots = num_blocks * CacheManager.block_size + + if num_cache_slots <= engine_config.max_num_batched_tokens: + raise RuntimeError( + f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" + f" only {num_blocks} cache blocks can be allocated. The number of" + f" available cache slots is {num_cache_slots}, not enough for" + f" {engine_config.max_num_batched_tokens} tokens. Try reducing" + " --max_input_len or --max_num_sequences." + ) + + LOG.info(f"Using {num_blocks} cache blocks.") + + cache_blocks = init_cache_blocks( + head_size, + hf_config.num_hidden_layers, + hf_config.num_attention_heads, + CacheManager.block_size, + num_blocks, + ) + + self.cache_manager = CacheManager( + cache_blocks, + num_blocks, + hf_config.sliding_window, + ) + + def __getattr__(self, name): + return getattr(self.worker, name) + + def execute_method(self, method, *args, **kwargs): + executor = getattr(self, method) + return executor(*args, **kwargs) + + def run_workers( workers, method: str, @@ -328,9 +441,8 @@ def run_workers( return output -def init_workers_ray(placement_group: "PlacementGroup"): +def init_workers_ray(placement_group: "PlacementGroup", parallel_config): import copy - from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.air.util.torch_dist import init_torch_dist_process_group @@ -345,28 +457,14 @@ def init_workers_ray(placement_group: "PlacementGroup"): placement_group=placement_group, placement_group_capture_child_tasks=True), **ray_remote_kwargs, - )(RayWorker).remote(False) + )(RayWorker).remote() workers.append(worker) # Initialize torch distributed process group for the workers. 
- init_torch_dist_process_group(self.workers, backend="nccl") - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - run_workers("init_worker", - get_all_outputs=True, - worker_init_fn=lambda: Worker( - model_config, - parallel_config, - scheduler_config, - None, - None, - )) - run_workers( - "init_model", - get_all_outputs=True, - ) + init_torch_dist_process_group(workers, backend="nccl") + parallel_config = copy.deepcopy(parallel_config) + run_workers("init_worker") # TODO def init_torch_model( @@ -388,7 +486,7 @@ def init_torch_model( from vllm.engine.ray_utils import initialize_cluster parallel_config = ParallelConfig(1, num_shards, True) _, placement_group = initialize_cluster(parallel_config) - init_workers_ray(placement_group) + init_workers_ray(placement_group, parallel_config) num_kv_heads = hf_config.num_key_value_heads // num_shards head_size = hf_config.hidden_size // hf_config.num_attention_heads From 90ffccdd9d20f790bd3a8464c8c6066df44664bb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Jan 2024 10:13:16 +0000 Subject: [PATCH 25/73] wip --- serve/mlc_serve/engine/model_module.py | 2 +- serve/mlc_serve/model/torch_model.py | 377 +++++++++---------------- 2 files changed, 140 insertions(+), 239 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 79b77e93a3..0dc4dd26a0 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -62,7 +62,7 @@ def get_cache(self) -> KVCache: The returned value should be passed to Executor.generate_text. """ - def allocate(self, request_id: RequestId, num_tokens: int, num_sequnces: int): + def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): """ Allocate cache space for request, raise error if there is no space. 
""" diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index d762c43af2..6321a2c446 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,3 +1,4 @@ +import time from typing import List, Union, Tuple, Sequence from collections import defaultdict @@ -6,17 +7,17 @@ from transformers import AutoConfig -import ray -from ray.air.util.torch_dist import TorchDistributedWorker - from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.sequence import SequenceData from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams -from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel, -) + +import multiprocessing +import rpyc +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer +from concurrent.futures import ThreadPoolExecutor from .base import ModelArtifactConfig from .paged_cache_manager import KVCache, CacheManager @@ -53,6 +54,102 @@ def convert_sampling_params(mlc_params: MLCSamplingParams) -> SamplingParams: ) +def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): + element_size = 2 + x = 16 // element_size + + key_block_shape = (num_heads, head_size // x, block_size, x) + value_block_shape = (num_heads, head_size, block_size) + + gpu_cache = [] + for _ in range(num_layers): + key_blocks = torch.empty( + size=(num_gpu_blocks, *key_block_shape), + dtype=torch.float16, + device="cuda", + ) + value_blocks = torch.empty( + size=(num_gpu_blocks, *value_block_shape), + dtype=torch.float16, + device="cuda", + ) + gpu_cache.append((key_blocks, value_blocks)) + return gpu_cache + + +class ModelRpcServer(rpyc.Service): + def exposed_init_model(self, tp_rank: int): + pass + + def profile_memory_usage(self, seq_lens): + return 0 + + def exposed_generate( + self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache + ) -> List[TextGenerationResult]: + return [] + + +def start_model_process(port): + def _init_service(port): + t = ThreadedServer( + ModelRpcServer(), + port=port, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + t.start() + + proc = multiprocessing.Process(target=_init_service, args=(port,)) + proc.start() + time.sleep(1) + + repeat_count = 0 + while repeat_count < 20: + try: + con = rpyc.connect( + "localhost", + port, + config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + break + except ConnectionRefusedError: + time.sleep(1) + repeat_count += 1 + if repeat_count == 20: + raise RuntimeError("init rpc env error!") + + assert proc.is_alive() + return con.root, proc + + +class ModelRpcClient: + def __init__(self, num_shards): + # TODO: Init torch.distributed + + with ThreadPoolExecutor(num_shards) as executor: + ports = [3000 + i for i in range(num_shards)] + rets = executor.map(start_model_process, ports) + + self.model_servers = [x[0] for x in rets] + self.procs = [x[1] for x in rets] + + def init_model(i): + return self.model_servers[i].init_model(i) + + _ = [obtain(x) for x in executor.map(init_model, range(num_shards))] + + def _func( + requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache + ) -> List[TextGenerationResult]: + def generate(i): + return self.model_servers[i].generate(requests, cache) + + res = [obtain(x) for x in executor.map(generate, range(num_shards)] + return obtain(res[0].value) + + self.generate = _func + + 
class Model: def __init__( self, @@ -150,13 +247,17 @@ def generate( num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) - seq_group_sampling_params[request.request_id] = convert_sampling_params(request.sampling_params) + seq_group_sampling_params[request.request_id] = convert_sampling_params( + request.sampling_params + ) else: sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) req_id = request.sequence_id.request_id seq_group_sequence_ids[req_id].append(request.sequence_id) - seq_group_sampling_params[req_id] = convert_sampling_params(request.sampling_params) + seq_group_sampling_params[req_id] = convert_sampling_params( + request.sampling_params + ) all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) @@ -277,97 +378,49 @@ def generate( return outputs -def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): - element_size = 2 - x = 16 // element_size - - key_block_shape = (num_heads, head_size // x, block_size, x) - value_block_shape = (num_heads, head_size, block_size) - - gpu_cache = [] - for _ in range(num_layers): - key_blocks = torch.empty( - size=(num_gpu_blocks, *key_block_shape), - dtype=torch.float16, - device="cuda", - ) - value_blocks = torch.empty( - size=(num_gpu_blocks, *value_block_shape), - dtype=torch.float16, - device="cuda", - ) - gpu_cache.append((key_blocks, value_blocks)) - return gpu_cache - - -def _init_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, -) -> None: - """Initialize the distributed environment.""" - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - if torch_world_size != parallel_config.world_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch world " - "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - torch.distributed.init_process_group( - backend="nccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - initialize_model_parallel(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) +def init_torch_model( + model_path, engine_config: MLCServeEngineConfig +) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: + hf_config = AutoConfig.from_pretrained(model_path) + # TODO + num_shards = 1 -class RayWorker(TorchDistributedWorker): - def __init__(self) -> None: - self.model = None - self.cache_manager = None + num_kv_heads = hf_config.num_key_value_heads // num_shards + head_size = hf_config.hidden_size // hf_config.num_attention_heads - def init_worker(self, hf_config, model_path, engine_config): - # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - # # Env vars will be set by Ray. 
- # self.rank = self.rank if self.rank is not None else int( - # os.getenv("RANK", "-1")) - # local_rank = int(os.getenv("LOCAL_RANK", "0")) - # self.device = torch.device(f"cuda:{local_rank}") - # if self.rank < 0: - # raise ValueError("Invalid or unspecified rank.") - # torch.cuda.set_device(self.device) + if not hasattr(hf_config, "sliding_window"): + hf_config.sliding_window = None - # # Initialize the distributed environment. - # _init_distributed_environment(self.parallel_config, self.rank, - # self.distributed_init_method) + hf_config.num_shards = num_shards - # # Initialize the model. - # set_random_seed(self.model_config.seed) + artifact_config = ModelArtifactConfig( + model_artifact_path=model_path, + num_shards=1, + quantization=None, + max_context_length=hf_config.max_position_embeddings, # TODO, + vocab_size=hf_config.vocab_size, + sliding_window=hf_config.sliding_window, + num_key_value_heads=num_kv_heads, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ) + if num_shareds > 1: + model = ModelRpcClient() + else: with torch.device("cuda"): torch.set_default_dtype(torch.float16) pt_model = LlamaForCausalLM(hf_config) pt_model.load_weights(model_path, None, "auto", None) - self.model = Model(pt_model, hf_config) - - head_size = hf_config.hidden_size // hf_config.num_attention_heads - num_kv_heads = hf_config.num_key_value_heads // num_shards + model = Model(pt_model, hf_config) if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") num_blocks = get_num_cache_blocks( - self.model, + model, [engine_config.max_input_len] * engine_config.max_num_sequences, hf_config.num_hidden_layers, num_kv_heads, @@ -397,164 +450,12 @@ def init_worker(self, hf_config, model_path, engine_config): num_blocks, ) - self.cache_manager = CacheManager( + cache_manager = CacheManager( cache_blocks, num_blocks, hf_config.sliding_window, ) - def __getattr__(self, name): - return getattr(self.worker, name) - - def execute_method(self, method, *args, **kwargs): - executor = getattr(self, method) - return executor(*args, **kwargs) - - -def run_workers( - workers, - method: str, - *args, - get_all_outputs: bool = False, - **kwargs, -): - """Runs the given method on all workers.""" - from functools import partial - - all_outputs = [] - for worker in workers: - executor = partial(worker.execute_method.remote, method) - - output = executor(*args, **kwargs) - all_outputs.append(output) - - all_outputs = ray.get(all_outputs) - - if get_all_outputs: - return all_outputs - - # Make sure all workers have the same results. - output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - - return output - - -def init_workers_ray(placement_group: "PlacementGroup", parallel_config): - import copy - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - from ray.air.util.torch_dist import init_torch_dist_process_group - - workers: List[Worker] = [] - for bundle in placement_group.bundle_specs: - if not bundle.get("GPU", 0): - continue - worker = ray.remote( - num_cpus=0, - num_gpus=1, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True), - **ray_remote_kwargs, - )(RayWorker).remote() - - workers.append(worker) - - # Initialize torch distributed process group for the workers. 
- init_torch_dist_process_group(workers, backend="nccl") - parallel_config = copy.deepcopy(parallel_config) - run_workers("init_worker") # TODO - - -def init_torch_model( - model_path, engine_config: MLCServeEngineConfig -) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: - torch.distributed.init_process_group( - backend="nccl", world_size=1, rank=0, init_method="tcp://localhost:59157" - ) - initialize_model_parallel(1, 1) - - print("model_path", model_path) - hf_config = AutoConfig.from_pretrained(model_path) - - # TODO - num_shards = 1 - - if num_shards > 1: - from vllm.config import ParallelConfig - from vllm.engine.ray_utils import initialize_cluster - parallel_config = ParallelConfig(1, num_shards, True) - _, placement_group = initialize_cluster(parallel_config) - init_workers_ray(placement_group, parallel_config) - - num_kv_heads = hf_config.num_key_value_heads // num_shards - head_size = hf_config.hidden_size // hf_config.num_attention_heads - - if not hasattr(hf_config, "sliding_window"): - hf_config.sliding_window = None - - hf_config.num_shards = num_shards - - artifact_config = ModelArtifactConfig( - model_artifact_path=model_path, - num_shards=1, - quantization=None, - max_context_length=hf_config.max_position_embeddings, # TODO, - vocab_size=hf_config.vocab_size, - sliding_window=hf_config.sliding_window, - num_key_value_heads=num_kv_heads, - num_attention_heads=hf_config.num_attention_heads, - num_hidden_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ) - - with torch.device("cuda"): - torch.set_default_dtype(torch.float16) - pt_model = LlamaForCausalLM(hf_config) - pt_model.load_weights(model_path, None, "auto", None) - - model = Model(pt_model, hf_config) - - if engine_config.max_num_batched_tokens > 0: - LOG.info("Running memory profiling.") - num_blocks = get_num_cache_blocks( - model, - [engine_config.max_input_len] * engine_config.max_num_sequences, - hf_config.num_hidden_layers, - num_kv_heads, - head_size, - ) - else: - num_blocks = 500 - - num_cache_slots = num_blocks * CacheManager.block_size - - if num_cache_slots <= engine_config.max_num_batched_tokens: - raise RuntimeError( - f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" - f" only {num_blocks} cache blocks can be allocated. The number of" - f" available cache slots is {num_cache_slots}, not enough for" - f" {engine_config.max_num_batched_tokens} tokens. Try reducing" - " --max_input_len or --max_num_sequences." 
- ) - - LOG.info(f"Using {num_blocks} cache blocks.") - - cache_blocks = init_cache_blocks( - head_size, - hf_config.num_hidden_layers, - hf_config.num_attention_heads, - CacheManager.block_size, - num_blocks, - ) - - cache_manager = CacheManager( - cache_blocks, - num_blocks, - hf_config.sliding_window, - ) - - LOG.info("Allocated KV cache blocks.") + LOG.info("Allocated KV cache blocks.") return model, cache_manager, artifact_config From 32686d81b6f362a982a80dd56e197327f6f20d6d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Jan 2024 10:31:58 +0000 Subject: [PATCH 26/73] attach cache_blocks to model --- serve/mlc_serve/model/paged_cache_manager.py | 5 +---- serve/mlc_serve/model/tvm_model.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7cd24f5182..5b1bb8498f 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -104,10 +104,8 @@ def replace_head_prompt_block_with(self, new_block): class KVCache: def __init__( self, - cache_blocks, block_size, ): - self.cache_blocks = cache_blocks self.block_size = block_size # SequenceId -> list[int] @@ -135,13 +133,12 @@ def get_cache_block_size(num_layers, num_heads, head_size): def __init__( self, - cache_blocks, # This can be any type num_blocks: int, sliding_window: Optional[int] = None, ): self.num_blocks = num_blocks self.free_blocks = list(range(num_blocks)) - self.kv_cache = KVCache(cache_blocks, self.block_size) + self.kv_cache = KVCache(self.block_size) self.token_counts = dict[SequenceId, int]() if sliding_window: diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index d90a5b695b..f6c51875f9 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -152,6 +152,8 @@ def __init__( "tvm.contrib.vllm.copy_blocks" ) + self.cache_blocks = None + def get_used_memory(self): if self.disco_session: params = self.params.debug_get_from_remote(0) @@ -270,7 +272,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, indices_within_window, self.params, @@ -280,7 +282,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, self.params, ) @@ -301,7 +303,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, block_tables, self.params, @@ -328,7 +330,7 @@ def generate( "int64", ) - self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) + self.copy_cache_blocks_func(self.cache_blocks, block_mapping) cache.pending_copy_from_to = [] try: @@ -492,8 +494,9 @@ def init_tvm_model( num_blocks, ) + model.cache_blocks = cache_blocks + cache_manager = CacheManager( - cache_blocks, num_blocks, model_artifact_config.sliding_window, ) From de2631b468aa258b87ec99e56bacb42639989598 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Jan 2024 18:57:31 +0000 Subject: [PATCH 27/73] change get_num_cache_blocks signature --- serve/mlc_serve/model/model_common.py | 4 +--- serve/mlc_serve/model/tvm_model.py | 9 ++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 8f1f716d7d..7b0b9c383e 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -19,14 +19,12 @@ def get_gpu_memory(gpu: int = 0) -> int: def 
get_num_cache_blocks( - model, - seq_lens, + used_memory_bytes, num_layers, num_kv_heads, head_size, gpu_memory_utilization=0.9, # the default used by vllm ): - used_memory_bytes = model.profile_memory_usage(seq_lens) cache_block_size = CacheManager.get_cache_block_size( num_layers, num_kv_heads, head_size ) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index f6c51875f9..7becbbb0a8 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -456,9 +456,10 @@ def init_tvm_model( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") + seq_lens = [engine_config.max_input_len] * engine_config.max_num_sequences + used_memory_bytes = model.profile_memory_usage(seq_lens) num_blocks = get_num_cache_blocks( - model, - [engine_config.max_input_len] * engine_config.max_num_sequences, + used_memory_bytes, model_artifact_config.num_hidden_layers, num_kv_heads, head_size, @@ -486,7 +487,7 @@ def init_tvm_model( else: init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - cache_blocks = init_cache_func( + model.cache_blocks = init_cache_func( head_size, model_artifact_config.num_hidden_layers, num_kv_heads, @@ -494,8 +495,6 @@ def init_tvm_model( num_blocks, ) - model.cache_blocks = cache_blocks - cache_manager = CacheManager( num_blocks, model_artifact_config.sliding_window, From 618ca62a49dc4449c50ba3a3ef14075f8e6e84f5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Jan 2024 19:55:07 +0000 Subject: [PATCH 28/73] wip --- serve/mlc_serve/model/torch_model.py | 309 ++++++++++++++++++++++++--- 1 file changed, 280 insertions(+), 29 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 6321a2c446..a97abe3952 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,4 +1,5 @@ import time +import os from typing import List, Union, Tuple, Sequence from collections import defaultdict @@ -13,7 +14,8 @@ from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams -import multiprocessing +import torch.multiprocessing as multiprocessing + import rpyc from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer @@ -78,27 +80,270 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc class ModelRpcServer(rpyc.Service): - def exposed_init_model(self, tp_rank: int): - pass + def exposed_init_model( + self, tp_rank, num_shards, model_path, hf_config, engine_config + ): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + torch.cuda.set_device(tp_rank) + torch.distributed.init_process_group( + backend="nccl", + world_size=num_shards, + rank=tp_rank, + # init_method=f"tcp://127.0.0.1:{self.nccl_port}", + ) + + # A small all_reduce for warmup. 
+ torch.distributed.all_reduce(torch.zeros(1).cuda()) + + with torch.device("cuda"): + torch.set_default_dtype(torch.float16) + pt_model = LlamaForCausalLM(hf_config) + pt_model.load_weights(model_path, None, "auto", None) + + num_kv_heads = hf_config.num_key_value_heads // num_shards + head_size = hf_config.hidden_size // hf_config.num_attention_heads + + if engine_config.max_num_batched_tokens > 0: + LOG.info("Running memory profiling.") + seq_lens = ( + [engine_config.max_input_len] * engine_config.max_num_sequences, + ) + used_memory_bytes = self.profile_memory_usage(seq_lens) + + num_blocks = get_num_cache_blocks( + used_memory_bytes, + hf_config.num_hidden_layers, + num_kv_heads, + head_size, + ) + else: + num_blocks = 500 + + self.cache_blocks = init_cache_blocks( + head_size, + hf_config.num_hidden_layers, + hf_config.num_attention_heads, + CacheManager.block_size, + num_blocks, + ) + + return num_blocks def profile_memory_usage(self, seq_lens): - return 0 + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + sampling_params = SamplingParams(top_p=0.99) + + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + seq_data = {} + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + + for i, seq_len in enumerate(seq_lens): + seq_groups.append(([i], sampling_params)) + prompt_tokens = [0] * seq_len + seq_data[i] = SequenceData(prompt_tokens) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(seq_len)) + slot_mapping.extend([0] * seq_len) + + input_ids = torch.cuda.LongTensor(input_tokens) + positions = torch.cuda.LongTensor(input_positions) + slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=seq_lens, + slot_mapping=slot_mapping_tensor, + context_lens=torch.cuda.IntTensor([]), + max_context_len=0, + block_tables=torch.cuda.IntTensor([]), + ) + + kv_caches = [(None, None)] * self.num_hidden_layers + + with torch.no_grad(): + self.pt_model.forward( + input_ids, + positions, + kv_caches, + input_metadata, + cache_events=None, + ) + + torch.cuda.synchronize() + + peak_memory = torch.cuda.max_memory_allocated() + print("peak memory", peak_memory / 1e9) + + torch.cuda.empty_cache() + + return peak_memory def exposed_generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache ) -> List[TextGenerationResult]: - return [] + requests = obtain(requests) + cache = obtain(cache) + if len(requests) == 0: + return [] -def start_model_process(port): - def _init_service(port): - t = ThreadedServer( - ModelRpcServer(), - port=port, - protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + is_prefill = isinstance(requests[0], PrefillRequest) + + all_token_ids = [] + sampling_params = [] + sequence_ids = [] + prompt_lens = [] + num_sequences = [] + seq_data = {} + seq_group_sequence_ids = defaultdict(list) + seq_group_sampling_params = {} + + for request in requests: + if isinstance(request, PrefillRequest): + sequence_ids.append(get_prompt_sequence_id(request.request_id)) + num_sequences.append(request.num_sequence) + prompt_lens.append(len(request.token_ids)) + seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) + seq_group_sampling_params[request.request_id] = convert_sampling_params( + request.sampling_params + ) + else: + sequence_ids.append(request.sequence_id) + prompt_lens.append(request.prompt_token_counts) + req_id = request.sequence_id.request_id + 
seq_group_sequence_ids[req_id].append(request.sequence_id) + seq_group_sampling_params[req_id] = convert_sampling_params( + request.sampling_params + ) + + all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) + + seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) + + seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] + + for req_id, seq_ids in seq_group_sequence_ids.items(): + seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + _, + block_tables, + ) = prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + cache.slot_mappings, + cache.decode_block_tables, + self.sliding_window, + is_prefill, + torch.long, + align=8, + ) + + input_shape = input_ids.shape + + if block_tables is None: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") + block_tables = torch.cuda.IntTensor([]) + context_lens = torch.cuda.IntTensor([]) + max_context_len = 0 + else: + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + context_lens = seq_lens + max_context_len = torch.max(seq_lens) + prompt_lens = [] + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + slot_mapping=slot_mapping, + context_lens=context_lens, + max_context_len=max_context_len, + block_tables=block_tables, ) - t.start() + with torch.no_grad(): + hidden_states = self.pt_model.model( + input_ids, + positions, + self.cache_blocks, + input_metadata, + # No need for this until parallel sampling is supported. + cache_events=None, + ) + + if hidden_states.shape[0] != len( + input_metadata.prompt_lens + ) and hidden_states.shape[0] != len(input_metadata.context_lens): + logits = get_logits( + self.pt_model.lm_head.weight, + hidden_states, + input_metadata, + self.vocab_size, + ) + + next_tokens = sample(logits, sampling_params, self.vocab_size) + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + outputs = [] + + for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + + +def _init_service(port): + t = ThreadedServer( + ModelRpcServer(), + port=port, + protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, + ) + t.start() + + +def start_model_process(port): + multiprocessing.set_start_method('spawn', force=True) proc = multiprocessing.Process(target=_init_service, args=(port,)) proc.start() time.sleep(1) @@ -123,20 +368,21 @@ def _init_service(port): class ModelRpcClient: - def __init__(self, num_shards): - # TODO: Init torch.distributed - + def __init__(self, num_shards, model_path, hf_config, engine_config): with ThreadPoolExecutor(num_shards) as executor: ports = [3000 + i for i in range(num_shards)] rets = executor.map(start_model_process, ports) + print("started processes") self.model_servers = [x[0] for x in rets] self.procs = [x[1] for x in rets] def init_model(i): - return 
self.model_servers[i].init_model(i) + return self.model_servers[i].init_model(i, num_shards, model_path, hf_config, engine_config) + + rets = [obtain(x) for x in executor.map(init_model, range(num_shards))] - _ = [obtain(x) for x in executor.map(init_model, range(num_shards))] + self.num_blocks = rets[0] def _func( requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache @@ -144,11 +390,14 @@ def _func( def generate(i): return self.model_servers[i].generate(requests, cache) - res = [obtain(x) for x in executor.map(generate, range(num_shards)] + res = [obtain(x) for x in executor.map(generate, range(num_shards))] return obtain(res[0].value) self.generate = _func + def get_num_cache_blocks(self): + return self.num_blocks + class Model: def __init__( @@ -167,6 +416,8 @@ def __init__( else: self.block_sliding_window = None + self.cache_blocks = None + def profile_memory_usage(self, seq_lens): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -327,7 +578,7 @@ def generate( hidden_states = self.pt_model.model( input_ids, positions, - cache.cache_blocks, + self.cache_blocks, input_metadata, # No need for this until parallel sampling is supported. cache_events=None, @@ -384,7 +635,7 @@ def init_torch_model( hf_config = AutoConfig.from_pretrained(model_path) # TODO - num_shards = 1 + num_shards = 2 num_kv_heads = hf_config.num_key_value_heads // num_shards head_size = hf_config.hidden_size // hf_config.num_attention_heads @@ -407,8 +658,9 @@ def init_torch_model( hidden_size=hf_config.hidden_size, ) - if num_shareds > 1: - model = ModelRpcClient() + if num_shards > 1: + model = ModelRpcClient(num_shards, model_path, hf_config, engine_config) + num_blocks = model.get_num_cache_blocks() else: with torch.device("cuda"): torch.set_default_dtype(torch.float16) @@ -442,7 +694,7 @@ def init_torch_model( LOG.info(f"Using {num_blocks} cache blocks.") - cache_blocks = init_cache_blocks( + model.cache_blocks = init_cache_blocks( head_size, hf_config.num_hidden_layers, hf_config.num_attention_heads, @@ -450,12 +702,11 @@ def init_torch_model( num_blocks, ) - cache_manager = CacheManager( - cache_blocks, - num_blocks, - hf_config.sliding_window, - ) - LOG.info("Allocated KV cache blocks.") + cache_manager = CacheManager( + num_blocks, + hf_config.sliding_window, + ) + return model, cache_manager, artifact_config From 9ce2f47d66724f7ed1c172c8f1b00d1eeebcd93a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 06:36:27 +0000 Subject: [PATCH 29/73] wip --- serve/mlc_serve/model/torch_model.py | 52 ++++++++++++++++++++++++---- serve/tests/test_engine.py | 3 ++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index a97abe3952..343216830d 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -13,6 +13,9 @@ from vllm.sequence import SequenceData from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams +from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel, +) import torch.multiprocessing as multiprocessing @@ -83,6 +86,22 @@ class ModelRpcServer(rpyc.Service): def exposed_init_model( self, tp_rank, num_shards, model_path, hf_config, engine_config ): + hf_config = obtain(hf_config) + engine_config = obtain(engine_config) + model_path = obtain(model_path) + + self.vocab_size = hf_config.vocab_size + self.sliding_window = hf_config.sliding_window + self.num_shards = 
hf_config.num_shards + self.num_hidden_layers = hf_config.num_hidden_layers + + if self.sliding_window: + self.block_sliding_window = self.sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + self.cache_blocks = None + # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow # as the number of all_reduce calls increases. This env var disables @@ -90,21 +109,33 @@ def exposed_init_model( # Related issue: # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + print("exposed init_model", tp_rank) + print("set device", tp_rank) torch.cuda.set_device(tp_rank) + print("done", tp_rank) + + os.environ["MASTER_ADDR"] = str("127.0.0.1") + os.environ["MASTER_PORT"] = str(4000) + + print("init_process_group", tp_rank, flush=True) torch.distributed.init_process_group( backend="nccl", world_size=num_shards, rank=tp_rank, - # init_method=f"tcp://127.0.0.1:{self.nccl_port}", + # init_method=f"tcp://127.0.0.1:{nccl_port}", ) + initialize_model_parallel(num_shards, 1) # A small all_reduce for warmup. + print("warm up", tp_rank) torch.distributed.all_reduce(torch.zeros(1).cuda()) with torch.device("cuda"): + print("loading model") torch.set_default_dtype(torch.float16) pt_model = LlamaForCausalLM(hf_config) + print("loading weights") pt_model.load_weights(model_path, None, "auto", None) num_kv_heads = hf_config.num_key_value_heads // num_shards @@ -113,7 +144,7 @@ def exposed_init_model( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") seq_lens = ( - [engine_config.max_input_len] * engine_config.max_num_sequences, + [engine_config.max_input_len] * engine_config.max_num_sequences ) used_memory_bytes = self.profile_memory_usage(seq_lens) @@ -343,9 +374,10 @@ def _init_service(port): def start_model_process(port): - multiprocessing.set_start_method('spawn', force=True) proc = multiprocessing.Process(target=_init_service, args=(port,)) proc.start() + # proc = multiprocessing.spawn(_init_service, (port,), nprocs + time.sleep(1) repeat_count = 0 @@ -364,15 +396,15 @@ def start_model_process(port): raise RuntimeError("init rpc env error!") assert proc.is_alive() + print("start process") return con.root, proc class ModelRpcClient: def __init__(self, num_shards, model_path, hf_config, engine_config): with ThreadPoolExecutor(num_shards) as executor: - ports = [3000 + i for i in range(num_shards)] + ports = [3010 + i for i in range(num_shards)] # TODO rets = executor.map(start_model_process, ports) - print("started processes") self.model_servers = [x[0] for x in rets] self.procs = [x[1] for x in rets] @@ -662,6 +694,11 @@ def init_torch_model( model = ModelRpcClient(num_shards, model_path, hf_config, engine_config) num_blocks = model.get_num_cache_blocks() else: + torch.distributed.init_process_group( + backend="nccl", world_size=1, rank=0, init_method="tcp://localhost:59157" + ) + initialize_model_parallel(1, 1) + with torch.device("cuda"): torch.set_default_dtype(torch.float16) pt_model = LlamaForCausalLM(hf_config) @@ -671,9 +708,10 @@ def init_torch_model( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") + seq_lens = [engine_config.max_input_len] * engine_config.max_num_sequences + used_memory_bytes = model.profile_memory_usage(seq_lens) num_blocks = get_num_cache_blocks( - model, - [engine_config.max_input_len] * 
engine_config.max_num_sequences, + used_memory_bytes, hf_config.num_hidden_layers, num_kv_heads, head_size, diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index e3644c34c1..6290fe1f53 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -138,6 +138,9 @@ def _test(args: argparse.Namespace): if __name__ == "__main__": + import torch + torch.multiprocessing.set_start_method('spawn') + parser = get_default_mlc_serve_argparser("test engine") parser.add_argument("--long-prompt", action="store_true") parser.add_argument("--use-random-sampling", action="store_true") From c14c0e964131a2bcc3b3565598d7354da388f88a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 08:21:37 +0000 Subject: [PATCH 30/73] wip --- serve/mlc_serve/model/torch_model.py | 152 +++++++++++++-------------- serve/tests/test_engine.py | 9 +- 2 files changed, 77 insertions(+), 84 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 343216830d..eabc236d3d 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -82,6 +82,62 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc return gpu_cache +def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + sampling_params = SamplingParams(top_p=0.99) + + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + seq_data = {} + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + + for i, seq_len in enumerate(seq_lens): + seq_groups.append(([i], sampling_params)) + prompt_tokens = [0] * seq_len + seq_data[i] = SequenceData(prompt_tokens) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(seq_len)) + slot_mapping.extend([0] * seq_len) + + input_ids = torch.cuda.LongTensor(input_tokens) + positions = torch.cuda.LongTensor(input_positions) + slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=seq_lens, + slot_mapping=slot_mapping_tensor, + context_lens=torch.cuda.IntTensor([]), + max_context_len=0, + block_tables=torch.cuda.IntTensor([]), + ) + + kv_caches = [(None, None)] * num_hidden_layers + + with torch.no_grad(): + pt_model.forward( + input_ids, + positions, + kv_caches, + input_metadata, + cache_events=None, + ) + + torch.cuda.synchronize() + + peak_memory = torch.cuda.max_memory_allocated() + print("peak memory", peak_memory / 1e9) + + torch.cuda.empty_cache() + + return peak_memory + + class ModelRpcServer(rpyc.Service): def exposed_init_model( self, tp_rank, num_shards, model_path, hf_config, engine_config @@ -109,34 +165,26 @@ def exposed_init_model( # Related issue: # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - print("exposed init_model", tp_rank) - print("set device", tp_rank) torch.cuda.set_device(tp_rank) - print("done", tp_rank) os.environ["MASTER_ADDR"] = str("127.0.0.1") - os.environ["MASTER_PORT"] = str(4000) + os.environ["MASTER_PORT"] = str(4000) # TODO port - print("init_process_group", tp_rank, flush=True) torch.distributed.init_process_group( backend="nccl", world_size=num_shards, rank=tp_rank, - # init_method=f"tcp://127.0.0.1:{nccl_port}", ) - initialize_model_parallel(num_shards, 1) + initialize_model_parallel(num_shards) # A small 
all_reduce for warmup. - print("warm up", tp_rank) torch.distributed.all_reduce(torch.zeros(1).cuda()) with torch.device("cuda"): - print("loading model") torch.set_default_dtype(torch.float16) - pt_model = LlamaForCausalLM(hf_config) - print("loading weights") - pt_model.load_weights(model_path, None, "auto", None) + self.pt_model = LlamaForCausalLM(hf_config) + self.pt_model.load_weights(model_path, None, "auto", None) num_kv_heads = hf_config.num_key_value_heads // num_shards head_size = hf_config.hidden_size // hf_config.num_attention_heads @@ -146,7 +194,7 @@ def exposed_init_model( seq_lens = ( [engine_config.max_input_len] * engine_config.max_num_sequences ) - used_memory_bytes = self.profile_memory_usage(seq_lens) + used_memory_bytes = profile_memory_usage(self.pt_model, seq_lens, self.num_hidden_layers) num_blocks = get_num_cache_blocks( used_memory_bytes, @@ -157,71 +205,18 @@ def exposed_init_model( else: num_blocks = 500 + LOG.info(f"Using {num_blocks} cache blocks.") + self.cache_blocks = init_cache_blocks( head_size, hf_config.num_hidden_layers, - hf_config.num_attention_heads, + num_kv_heads, CacheManager.block_size, num_blocks, ) return num_blocks - def profile_memory_usage(self, seq_lens): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - sampling_params = SamplingParams(top_p=0.99) - - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - seq_data = {} - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - - for i, seq_len in enumerate(seq_lens): - seq_groups.append(([i], sampling_params)) - prompt_tokens = [0] * seq_len - seq_data[i] = SequenceData(prompt_tokens) - - input_tokens.extend(prompt_tokens) - input_positions.extend(range(seq_len)) - slot_mapping.extend([0] * seq_len) - - input_ids = torch.cuda.LongTensor(input_tokens) - positions = torch.cuda.LongTensor(input_positions) - slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) - - input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=seq_lens, - slot_mapping=slot_mapping_tensor, - context_lens=torch.cuda.IntTensor([]), - max_context_len=0, - block_tables=torch.cuda.IntTensor([]), - ) - - kv_caches = [(None, None)] * self.num_hidden_layers - - with torch.no_grad(): - self.pt_model.forward( - input_ids, - positions, - kv_caches, - input_metadata, - cache_events=None, - ) - - torch.cuda.synchronize() - - peak_memory = torch.cuda.max_memory_allocated() - print("peak memory", peak_memory / 1e9) - - torch.cuda.empty_cache() - - return peak_memory - def exposed_generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache ) -> List[TextGenerationResult]: @@ -376,7 +371,6 @@ def _init_service(port): def start_model_process(port): proc = multiprocessing.Process(target=_init_service, args=(port,)) proc.start() - # proc = multiprocessing.spawn(_init_service, (port,), nprocs time.sleep(1) @@ -403,7 +397,7 @@ def start_model_process(port): class ModelRpcClient: def __init__(self, num_shards, model_path, hf_config, engine_config): with ThreadPoolExecutor(num_shards) as executor: - ports = [3010 + i for i in range(num_shards)] # TODO + ports = [3000 + i for i in range(num_shards)] # TODO port rets = executor.map(start_model_process, ports) self.model_servers = [x[0] for x in rets] @@ -416,16 +410,17 @@ def init_model(i): self.num_blocks = rets[0] - def _func( - requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache - ) -> List[TextGenerationResult]: - def generate(i): - return 
self.model_servers[i].generate(requests, cache) + def _func( + requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache + ) -> List[TextGenerationResult]: + def generate(i): + return self.model_servers[i].generate(requests, cache) + with ThreadPoolExecutor(num_shards) as executor: res = [obtain(x) for x in executor.map(generate, range(num_shards))] - return obtain(res[0].value) + return obtain(res[0]) - self.generate = _func + self.generate = _func def get_num_cache_blocks(self): return self.num_blocks @@ -691,6 +686,7 @@ def init_torch_model( ) if num_shards > 1: + torch.multiprocessing.set_start_method('spawn') model = ModelRpcClient(num_shards, model_path, hf_config, engine_config) num_blocks = model.get_num_cache_blocks() else: diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 6290fe1f53..c2e63fa56e 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -81,9 +81,9 @@ def _test(args: argparse.Namespace): else: prompts = [ "Hello, my name is", - # "The capital of France is", - # "The president of the United States is a powerful man. But he can also be", - # "The future of AI is full of promise. But we need to carefully", + "The capital of France is", + "The president of the United States is a powerful man. But he can also be", + "The future of AI is full of promise. But we need to carefully", ] for i, prompt in enumerate(prompts): @@ -138,9 +138,6 @@ def _test(args: argparse.Namespace): if __name__ == "__main__": - import torch - torch.multiprocessing.set_start_method('spawn') - parser = get_default_mlc_serve_argparser("test engine") parser.add_argument("--long-prompt", action="store_true") parser.add_argument("--use-random-sampling", action="store_true") From 43284406d51239c186eeb09f1e14739091669e66 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 10:46:47 +0000 Subject: [PATCH 31/73] refactor --- serve/mlc_serve/model/torch_model.py | 648 +++++++++------------------ 1 file changed, 216 insertions(+), 432 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index eabc236d3d..eb8e78f44c 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -131,13 +131,195 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): torch.cuda.synchronize() peak_memory = torch.cuda.max_memory_allocated() - print("peak memory", peak_memory / 1e9) torch.cuda.empty_cache() return peak_memory +def profile_and_init_cache( + pt_model, num_hidden_layers, hf_config, engine_config, num_shards +): + num_kv_heads = hf_config.num_key_value_heads // num_shards + head_size = hf_config.hidden_size // hf_config.num_attention_heads + + if engine_config.max_num_batched_tokens > 0: + LOG.info("Running memory profiling.") + seq_lens = [engine_config.max_input_len] * engine_config.max_num_sequences + used_memory_bytes = profile_memory_usage(pt_model, seq_lens, num_hidden_layers) + + num_blocks = get_num_cache_blocks( + used_memory_bytes, + hf_config.num_hidden_layers, + num_kv_heads, + head_size, + ) + else: + num_blocks = 500 + + cache_blocks = init_cache_blocks( + head_size, + hf_config.num_hidden_layers, + num_kv_heads, + CacheManager.block_size, + num_blocks, + ) + + return cache_blocks, num_blocks + + +def load_model(hf_config, model_path): + with torch.device("cuda"): + torch.set_default_dtype(torch.float16) + model = LlamaForCausalLM(hf_config) + model.load_weights(model_path, None, "auto", None) + return model + + +def generate( + 
requests: Sequence[Union[PrefillRequest, DecodeRequest]], + cache: KVCache, + pt_model, + cache_blocks, + sliding_window, + vocab_size, +) -> List[TextGenerationResult]: + if len(requests) == 0: + return [] + + is_prefill = isinstance(requests[0], PrefillRequest) + + all_token_ids = [] + sampling_params = [] + sequence_ids = [] + prompt_lens = [] + num_sequences = [] + seq_data = {} + seq_group_sequence_ids = defaultdict(list) + seq_group_sampling_params = {} + + for request in requests: + if isinstance(request, PrefillRequest): + sequence_ids.append(get_prompt_sequence_id(request.request_id)) + num_sequences.append(request.num_sequence) + prompt_lens.append(len(request.token_ids)) + seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) + seq_group_sampling_params[request.request_id] = convert_sampling_params( + request.sampling_params + ) + else: + sequence_ids.append(request.sequence_id) + prompt_lens.append(request.prompt_token_counts) + req_id = request.sequence_id.request_id + seq_group_sequence_ids[req_id].append(request.sequence_id) + seq_group_sampling_params[req_id] = convert_sampling_params( + request.sampling_params + ) + + all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) + + seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) + + seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] + + for req_id, seq_ids in seq_group_sequence_ids.items(): + seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + _, + block_tables, + ) = prepare_inputs( + sequence_ids, + all_token_ids, + prompt_lens, + cache.slot_mappings, + cache.decode_block_tables, + sliding_window, + is_prefill, + torch.long, + align=8, + ) + + input_shape = input_ids.shape + + if block_tables is None: + torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") + block_tables = torch.cuda.IntTensor([]) + context_lens = torch.cuda.IntTensor([]) + max_context_len = 0 + else: + torch.cuda.nvtx.range_push(f"forward decode {input_shape}") + context_lens = seq_lens + max_context_len = torch.max(seq_lens) + prompt_lens = [] + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + slot_mapping=slot_mapping, + context_lens=context_lens, + max_context_len=max_context_len, + block_tables=block_tables, + ) + + with torch.no_grad(): + hidden_states = pt_model.model( + input_ids, + positions, + cache_blocks, + input_metadata, + # No need for this until parallel sampling is supported. 
+ cache_events=None, + ) + + if hidden_states.shape[0] != len( + input_metadata.prompt_lens + ) and hidden_states.shape[0] != len(input_metadata.context_lens): + logits = get_logits( + pt_model.lm_head.weight, + hidden_states, + input_metadata, + vocab_size, + ) + + next_tokens = sample(logits, sampling_params, vocab_size) + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + outputs = [] + + for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + for seq_id in range(num_sequences[i]): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + + class ModelRpcServer(rpyc.Service): def exposed_init_model( self, tp_rank, num_shards, model_path, hf_config, engine_config @@ -148,15 +330,6 @@ def exposed_init_model( self.vocab_size = hf_config.vocab_size self.sliding_window = hf_config.sliding_window - self.num_shards = hf_config.num_shards - self.num_hidden_layers = hf_config.num_hidden_layers - - if self.sliding_window: - self.block_sliding_window = self.sliding_window // CacheManager.block_size - else: - self.block_sliding_window = None - - self.cache_blocks = None # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow @@ -181,40 +354,18 @@ def exposed_init_model( # A small all_reduce for warmup. 
torch.distributed.all_reduce(torch.zeros(1).cuda()) - with torch.device("cuda"): - torch.set_default_dtype(torch.float16) - self.pt_model = LlamaForCausalLM(hf_config) - self.pt_model.load_weights(model_path, None, "auto", None) + self.pt_model = load_model(hf_config, model_path) - num_kv_heads = hf_config.num_key_value_heads // num_shards - head_size = hf_config.hidden_size // hf_config.num_attention_heads - - if engine_config.max_num_batched_tokens > 0: - LOG.info("Running memory profiling.") - seq_lens = ( - [engine_config.max_input_len] * engine_config.max_num_sequences - ) - used_memory_bytes = profile_memory_usage(self.pt_model, seq_lens, self.num_hidden_layers) - - num_blocks = get_num_cache_blocks( - used_memory_bytes, - hf_config.num_hidden_layers, - num_kv_heads, - head_size, - ) - else: - num_blocks = 500 - - LOG.info(f"Using {num_blocks} cache blocks.") - - self.cache_blocks = init_cache_blocks( - head_size, + self.cache_blocks, num_blocks = profile_and_init_cache( + self.pt_model, hf_config.num_hidden_layers, - num_kv_heads, - CacheManager.block_size, - num_blocks, + hf_config, + engine_config, + num_shards, ) + LOG.info(f"Using {num_blocks} cache blocks.") + return num_blocks def exposed_generate( @@ -222,142 +373,15 @@ def exposed_generate( ) -> List[TextGenerationResult]: requests = obtain(requests) cache = obtain(cache) - - if len(requests) == 0: - return [] - - is_prefill = isinstance(requests[0], PrefillRequest) - - all_token_ids = [] - sampling_params = [] - sequence_ids = [] - prompt_lens = [] - num_sequences = [] - seq_data = {} - seq_group_sequence_ids = defaultdict(list) - seq_group_sampling_params = {} - - for request in requests: - if isinstance(request, PrefillRequest): - sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) - prompt_lens.append(len(request.token_ids)) - seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) - seq_group_sampling_params[request.request_id] = convert_sampling_params( - request.sampling_params - ) - else: - sequence_ids.append(request.sequence_id) - prompt_lens.append(request.prompt_token_counts) - req_id = request.sequence_id.request_id - seq_group_sequence_ids[req_id].append(request.sequence_id) - seq_group_sampling_params[req_id] = convert_sampling_params( - request.sampling_params - ) - - all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) - - seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) - - seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] - - for req_id, seq_ids in seq_group_sequence_ids.items(): - seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - _, - block_tables, - ) = prepare_inputs( - sequence_ids, - all_token_ids, - prompt_lens, - cache.slot_mappings, - cache.decode_block_tables, + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, self.sliding_window, - is_prefill, - torch.long, - align=8, + self.vocab_size, ) - input_shape = input_ids.shape - - if block_tables is None: - torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") - block_tables = torch.cuda.IntTensor([]) - context_lens = torch.cuda.IntTensor([]) - max_context_len = 0 - else: - torch.cuda.nvtx.range_push(f"forward decode {input_shape}") - context_lens = seq_lens - max_context_len = torch.max(seq_lens) - prompt_lens = [] - - input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - 
prompt_lens=prompt_lens, - slot_mapping=slot_mapping, - context_lens=context_lens, - max_context_len=max_context_len, - block_tables=block_tables, - ) - - with torch.no_grad(): - hidden_states = self.pt_model.model( - input_ids, - positions, - self.cache_blocks, - input_metadata, - # No need for this until parallel sampling is supported. - cache_events=None, - ) - - if hidden_states.shape[0] != len( - input_metadata.prompt_lens - ) and hidden_states.shape[0] != len(input_metadata.context_lens): - logits = get_logits( - self.pt_model.lm_head.weight, - hidden_states, - input_metadata, - self.vocab_size, - ) - - next_tokens = sample(logits, sampling_params, self.vocab_size) - - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - outputs = [] - - for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - ) - ) - - return outputs - def _init_service(port): t = ThreadedServer( @@ -390,7 +414,6 @@ def start_model_process(port): raise RuntimeError("init rpc env error!") assert proc.is_alive() - print("start process") return con.root, proc @@ -404,7 +427,9 @@ def __init__(self, num_shards, model_path, hf_config, engine_config): self.procs = [x[1] for x in rets] def init_model(i): - return self.model_servers[i].init_model(i, num_shards, model_path, hf_config, engine_config) + return self.model_servers[i].init_model( + i, num_shards, model_path, hf_config, engine_config + ) rets = [obtain(x) for x in executor.map(init_model, range(num_shards))] @@ -422,9 +447,6 @@ def generate(i): self.generate = _func - def get_num_cache_blocks(self): - return self.num_blocks - class Model: def __init__( @@ -435,226 +457,22 @@ def __init__( self.pt_model = pt_model self.vocab_size = config.vocab_size self.sliding_window = config.sliding_window - self.num_shards = config.num_shards - self.num_hidden_layers = config.num_hidden_layers - - if self.sliding_window: - self.block_sliding_window = self.sliding_window // CacheManager.block_size - else: - self.block_sliding_window = None - self.cache_blocks = None - def profile_memory_usage(self, seq_lens): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - sampling_params = SamplingParams(top_p=0.99) - - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - seq_data = {} - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - - for i, seq_len in enumerate(seq_lens): - seq_groups.append(([i], sampling_params)) - prompt_tokens = [0] * seq_len - seq_data[i] = SequenceData(prompt_tokens) - - input_tokens.extend(prompt_tokens) - input_positions.extend(range(seq_len)) - slot_mapping.extend([0] * seq_len) - - input_ids = torch.cuda.LongTensor(input_tokens) - positions = torch.cuda.LongTensor(input_positions) - slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) - - input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=seq_lens, - slot_mapping=slot_mapping_tensor, - 
context_lens=torch.cuda.IntTensor([]), - max_context_len=0, - block_tables=torch.cuda.IntTensor([]), - ) - - kv_caches = [(None, None)] * self.num_hidden_layers - - with torch.no_grad(): - self.pt_model.forward( - input_ids, - positions, - kv_caches, - input_metadata, - cache_events=None, - ) - - torch.cuda.synchronize() - - peak_memory = torch.cuda.max_memory_allocated() - print("peak memory", peak_memory / 1e9) - - torch.cuda.empty_cache() - - return peak_memory - def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache, ) -> List[TextGenerationResult]: - if len(requests) == 0: - return [] - - is_prefill = isinstance(requests[0], PrefillRequest) - - all_token_ids = [] - sampling_params = [] - sequence_ids = [] - prompt_lens = [] - num_sequences = [] - seq_data = {} - seq_group_sequence_ids = defaultdict(list) - seq_group_sampling_params = {} - - for request in requests: - if isinstance(request, PrefillRequest): - sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) - prompt_lens.append(len(request.token_ids)) - seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) - seq_group_sampling_params[request.request_id] = convert_sampling_params( - request.sampling_params - ) - else: - sequence_ids.append(request.sequence_id) - prompt_lens.append(request.prompt_token_counts) - req_id = request.sequence_id.request_id - seq_group_sequence_ids[req_id].append(request.sequence_id) - seq_group_sampling_params[req_id] = convert_sampling_params( - request.sampling_params - ) - - all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) - - seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) - - seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] - - for req_id, seq_ids in seq_group_sequence_ids.items(): - seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) - - ( - input_ids, - positions, - seq_lens, - slot_mapping, - _, - block_tables, - ) = prepare_inputs( - sequence_ids, - all_token_ids, - prompt_lens, - cache.slot_mappings, - cache.decode_block_tables, + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, self.sliding_window, - is_prefill, - torch.long, - align=8, - ) - - input_shape = input_ids.shape - - if block_tables is None: - torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") - block_tables = torch.cuda.IntTensor([]) - context_lens = torch.cuda.IntTensor([]) - max_context_len = 0 - else: - torch.cuda.nvtx.range_push(f"forward decode {input_shape}") - context_lens = seq_lens - max_context_len = torch.max(seq_lens) - prompt_lens = [] - - input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - slot_mapping=slot_mapping, - context_lens=context_lens, - max_context_len=max_context_len, - block_tables=block_tables, + self.vocab_size, ) - with torch.no_grad(): - # outs = self.pt_model.forward( - # input_ids, - # positions, - # cache.cache_blocks, - # input_metadata, - # cache_events=None, # TODO: what to do about this? - # ) - - # next_tokens = [] - # for samples in outs: - # next_tokens.append(samples[0].output_token) - - hidden_states = self.pt_model.model( - input_ids, - positions, - self.cache_blocks, - input_metadata, - # No need for this until parallel sampling is supported. 
- cache_events=None, - ) - - if hidden_states.shape[0] != len( - input_metadata.prompt_lens - ) and hidden_states.shape[0] != len(input_metadata.context_lens): - logits = get_logits( - self.pt_model.lm_head.weight, - hidden_states, - input_metadata, - self.vocab_size, - ) - - next_tokens = sample(logits, sampling_params, self.vocab_size) - - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - print("logits.shape", logits.shape) - - print("next tokens", next_tokens) - outputs = [] - - for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - ) - ) - - return outputs - def init_torch_model( model_path, engine_config: MLCServeEngineConfig @@ -665,13 +483,10 @@ def init_torch_model( num_shards = 2 num_kv_heads = hf_config.num_key_value_heads // num_shards - head_size = hf_config.hidden_size // hf_config.num_attention_heads if not hasattr(hf_config, "sliding_window"): hf_config.sliding_window = None - hf_config.num_shards = num_shards - artifact_config = ModelArtifactConfig( model_artifact_path=model_path, num_shards=1, @@ -686,57 +501,26 @@ def init_torch_model( ) if num_shards > 1: - torch.multiprocessing.set_start_method('spawn') + torch.multiprocessing.set_start_method("spawn") model = ModelRpcClient(num_shards, model_path, hf_config, engine_config) - num_blocks = model.get_num_cache_blocks() + num_blocks = model.num_blocks else: torch.distributed.init_process_group( - backend="nccl", world_size=1, rank=0, init_method="tcp://localhost:59157" + backend="nccl", + world_size=1, + rank=0, + init_method="tcp://localhost:59157", # port ) initialize_model_parallel(1, 1) - with torch.device("cuda"): - torch.set_default_dtype(torch.float16) - pt_model = LlamaForCausalLM(hf_config) - pt_model.load_weights(model_path, None, "auto", None) - + pt_model = load_model(hf_config, model_path) model = Model(pt_model, hf_config) - if engine_config.max_num_batched_tokens > 0: - LOG.info("Running memory profiling.") - seq_lens = [engine_config.max_input_len] * engine_config.max_num_sequences - used_memory_bytes = model.profile_memory_usage(seq_lens) - num_blocks = get_num_cache_blocks( - used_memory_bytes, - hf_config.num_hidden_layers, - num_kv_heads, - head_size, - ) - else: - num_blocks = 500 - - num_cache_slots = num_blocks * CacheManager.block_size - - if num_cache_slots <= engine_config.max_num_batched_tokens: - raise RuntimeError( - f"max_num_batched_tokens = {engine_config.max_num_batched_tokens} but" - f" only {num_blocks} cache blocks can be allocated. The number of" - f" available cache slots is {num_cache_slots}, not enough for" - f" {engine_config.max_num_batched_tokens} tokens. Try reducing" - " --max_input_len or --max_num_sequences." 
- ) - - LOG.info(f"Using {num_blocks} cache blocks.") - - model.cache_blocks = init_cache_blocks( - head_size, - hf_config.num_hidden_layers, - hf_config.num_attention_heads, - CacheManager.block_size, - num_blocks, + model.cache_blocks, num_blocks = profile_and_init_cache( + pt_model, hf_config.num_hidden_layers, hf_config, engine_config, num_shards ) - LOG.info("Allocated KV cache blocks.") + LOG.info("Allocated KV cache blocks.") cache_manager = CacheManager( num_blocks, From 9fb6358e65fb10ff05e0341ce7d6b0abf80d2bbb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 20:43:00 +0000 Subject: [PATCH 32/73] update for qwen --- serve/mlc_serve/engine/engine_common.py | 2 +- serve/mlc_serve/model/tokenizer.py | 2 +- serve/mlc_serve/model/torch_model.py | 21 +++++++++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1bc252b48c..e8b3e801f2 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -105,7 +105,7 @@ def detokenize_incrementally( prefix_begin_offset = generation_sequence.prefix_begin_offset prefix_end_offset = generation_sequence.prefix_end_offset - assert tokenizer.is_fast + # assert tokenizer.is_fast prefix_text = tokenizer.convert_tokens_to_string( output_tokens[prefix_begin_offset:prefix_end_offset] diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index ca7dce48e1..bdde9ca6cb 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -48,7 +48,7 @@ def apply(self, messages: list[ChatMessage]) -> str: class HfTokenizerModule: def __init__(self, tokenizer_path: Path): hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, trust_remote_code=False + tokenizer_path, trust_remote_code=True ) self.tokenizer = Tokenizer(hf_tokenizer) self.conversation_template = ConversationTemplate(hf_tokenizer) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index eb8e78f44c..8186c4aaeb 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.sequence import SequenceData from vllm.model_executor import InputMetadata from vllm.sampling_params import SamplingParams @@ -157,6 +158,8 @@ def profile_and_init_cache( else: num_blocks = 500 + LOG.info(f"Using {num_blocks} cache blocks.") + cache_blocks = init_cache_blocks( head_size, hf_config.num_hidden_layers, @@ -172,6 +175,7 @@ def load_model(hf_config, model_path): with torch.device("cuda"): torch.set_default_dtype(torch.float16) model = LlamaForCausalLM(hf_config) + # model = QWenLMHeadModel(hf_config) # requires tiktoken package model.load_weights(model_path, None, "auto", None) return model @@ -269,6 +273,14 @@ def generate( ) with torch.no_grad(): + # hidden_states = pt_model.transformer( + # input_ids, + # positions, + # cache_blocks, + # input_metadata, + # # No need for this until parallel sampling is supported. 
+ # cache_events=None, + # ) hidden_states = pt_model.model( input_ids, positions, @@ -364,8 +376,6 @@ def exposed_init_model( num_shards, ) - LOG.info(f"Using {num_blocks} cache blocks.") - return num_blocks def exposed_generate( @@ -477,16 +487,19 @@ def generate( def init_torch_model( model_path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: - hf_config = AutoConfig.from_pretrained(model_path) + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) # TODO num_shards = 2 - num_kv_heads = hf_config.num_key_value_heads // num_shards + if not hasattr(hf_config, "num_key_value_heads") and hasattr(hf_config, "num_attention_heads"): + hf_config.num_key_value_heads = hf_config.num_attention_heads if not hasattr(hf_config, "sliding_window"): hf_config.sliding_window = None + num_kv_heads = hf_config.num_key_value_heads // num_shards + artifact_config = ModelArtifactConfig( model_artifact_path=model_path, num_shards=1, From 7d89811e7b3f359c392d998591f96d0a16a7e2ec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 20:49:21 +0000 Subject: [PATCH 33/73] mergei fix --- serve/mlc_serve/model/torch_model.py | 3 +-- serve/mlc_serve/model/tvm_model.py | 2 +- serve/mlc_serve/utils.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 8186c4aaeb..f3fce59037 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -146,9 +146,8 @@ def profile_and_init_cache( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") - seq_lens = [engine_config.max_input_len] * engine_config.max_num_sequences + seq_lens = [1] * engine_config.max_num_batched_tokens used_memory_bytes = profile_memory_usage(pt_model, seq_lens, num_hidden_layers) - num_blocks = get_num_cache_blocks( used_memory_bytes, hf_config.num_hidden_layers, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 40eaa1944f..f6255acdc5 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -456,7 +456,7 @@ def init_tvm_model( if engine_config.max_num_batched_tokens > 0: LOG.info("Running memory profiling.") - seq_lens = [1] * engine_config.max_num_batched_tokens, + seq_lens = [1] * engine_config.max_num_batched_tokens used_memory_bytes = model.profile_memory_usage(seq_lens) num_blocks = get_num_cache_blocks( used_memory_bytes, diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index ca689a0b60..d293bf4058 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -24,6 +24,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--artifact-path", type=str, default="dist") parser.add_argument("--use-sync-engine", action="store_true") parser.add_argument("--max-num-batched-tokens", type=int, default=4096) + parser.add_argument("--num-sequences-to-sample", type=int, default=1) parser.add_argument("--min-decode-steps", type=int, default=32) parser.add_argument("--max-decode-steps", type=int, default=56) parser.add_argument("--debug-logging", action="store_true") From 0bbc41a0615349a6ed78dd260e5baf036de9c069 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 20:53:10 +0000 Subject: [PATCH 34/73] clean --- serve/mlc_serve/model/tvm_model.py | 11 +---------- serve/tests/test_engine.py | 2 +- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git 
a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index f6255acdc5..487617af82 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -247,10 +247,6 @@ def generate( is_prefill, ) - print("input_ids", input_ids) - print("positions", positions) - print("block_tables", block_tables) - input_shape = input_ids.shape if self.disco_session: @@ -334,13 +330,8 @@ def generate( cache.pending_copy_from_to = [] try: - # if is_prefill: - # import numpy as np - # print("save logits") - # np.save("logits_tvm.npy", logits.numpy()) - next_tokens = sample(logits, sampling_params, self.vocab_size) - print("next tokens", next_tokens) + assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index d88a896858..5582d6c815 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -70,7 +70,7 @@ def _test(args: argparse.Namespace): sampling_params_choices = [sampling_params_random] elif args.use_random_sampling: # This tests different sampling types in the same batch - sampling_params_choices = [sampling_params_random] + sampling_params_choices = [sampling_params_random, sampling_params_greedy] else: sampling_params_choices = [sampling_params_greedy] From 08a63cac465ff8d40bbe11a458807e565c4f8996 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Jan 2024 22:58:54 +0000 Subject: [PATCH 35/73] KV cache refactor to decouple cache blocks and metadata about them --- serve/mlc_serve/model/paged_cache_manager.py | 71 +++++++++++--------- serve/mlc_serve/model/tvm_model.py | 17 ++--- serve/mlc_serve/utils.py | 1 + 3 files changed, 48 insertions(+), 41 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7cd24f5182..7193c0eadb 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -101,13 +101,11 @@ def replace_head_prompt_block_with(self, new_block): self.prompt_shared = False -class KVCache: +class KVCacheInfo: def __init__( self, - cache_blocks, block_size, ): - self.cache_blocks = cache_blocks self.block_size = block_size # SequenceId -> list[int] @@ -135,18 +133,17 @@ def get_cache_block_size(num_layers, num_heads, head_size): def __init__( self, - cache_blocks, # This can be any type num_blocks: int, sliding_window: Optional[int] = None, ): self.num_blocks = num_blocks self.free_blocks = list(range(num_blocks)) - self.kv_cache = KVCache(cache_blocks, self.block_size) + self.kv_cache_info = KVCacheInfo(self.block_size) self.token_counts = dict[SequenceId, int]() if sliding_window: - assert sliding_window % self.kv_cache.block_size == 0 - self.block_sliding_window = sliding_window // self.kv_cache.block_size + assert sliding_window % self.kv_cache_info.block_size == 0 + self.block_sliding_window = sliding_window // self.kv_cache_info.block_size else: self.block_sliding_window = None @@ -160,20 +157,20 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]): num_needed_block = min(num_needed_block, self.block_sliding_window) if size == 0: - if id in self.kv_cache.prompt_block_tables: - self.free_blocks.extend(self.kv_cache.prompt_block_tables[id]) - del self.kv_cache.prompt_block_tables[id] - elif id in self.kv_cache.decode_block_tables: + if id in self.kv_cache_info.prompt_block_tables: + self.free_blocks.extend(self.kv_cache_info.prompt_block_tables[id]) + del 
self.kv_cache_info.prompt_block_tables[id] + elif id in self.kv_cache_info.decode_block_tables: self.free_blocks.extend( - self.kv_cache.decode_block_tables[id].decode_blocks + self.kv_cache_info.decode_block_tables[id].decode_blocks ) - del self.kv_cache.decode_block_tables[id] + del self.kv_cache_info.decode_block_tables[id] - if id in self.kv_cache.slot_mappings: - del self.kv_cache.slot_mappings[id] + if id in self.kv_cache_info.slot_mappings: + del self.kv_cache_info.slot_mappings[id] - elif id in self.kv_cache.decode_block_tables: - decode_block_table = self.kv_cache.decode_block_tables[id] + elif id in self.kv_cache_info.decode_block_tables: + decode_block_table = self.kv_cache_info.decode_block_tables[id] if len(decode_block_table) < num_needed_block: # Need to allocate a new block for this request @@ -218,40 +215,42 @@ def get_block_circular_index(token_pos): block_offset = pos % self.block_size slot = block_number * self.block_size + block_offset - self.kv_cache.slot_mappings[id].append(slot) + self.kv_cache_info.slot_mappings[id].append(slot) - elif id not in self.kv_cache.prompt_block_tables: + elif id not in self.kv_cache_info.prompt_block_tables: assert ( len(self.free_blocks) >= num_needed_block ), "Not enough free blocks." for _ in range(num_needed_block): - self.kv_cache.prompt_block_tables[id].append(self.free_blocks.pop()) + self.kv_cache_info.prompt_block_tables[id].append( + self.free_blocks.pop() + ) for block_idx in range(math.floor(size / self.block_size)): if self.block_sliding_window: block_idx %= self.block_sliding_window - block_number = self.kv_cache.prompt_block_tables[id][block_idx] + block_number = self.kv_cache_info.prompt_block_tables[id][block_idx] slots = [ block_number * self.block_size + block_offset for block_offset in range(self.block_size) ] - self.kv_cache.slot_mappings[id] += slots + self.kv_cache_info.slot_mappings[id] += slots - for i in range(len(self.kv_cache.slot_mappings[id]), size): + for i in range(len(self.kv_cache_info.slot_mappings[id]), size): block_idx = i // self.block_size if self.block_sliding_window: block_idx %= self.block_sliding_window - block_number = self.kv_cache.prompt_block_tables[id][block_idx] + block_number = self.kv_cache_info.prompt_block_tables[id][block_idx] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - self.kv_cache.slot_mappings[id].append(slot) + self.kv_cache_info.slot_mappings[id].append(slot) def get_cache(self): - return self.kv_cache + return self.kv_cache_info def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): """ @@ -267,7 +266,7 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): if self.sliding_window: last_block_partially_shared &= num_tokens < self.sliding_window - prompt_blocks = self.kv_cache.prompt_block_tables[prompt_seq_id] + prompt_blocks = self.kv_cache_info.prompt_block_tables[prompt_seq_id] assert prompt_blocks prompt_shared = num_sequences > 1 @@ -277,7 +276,9 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): self.token_counts[decode_seq_id] = num_tokens if not last_block_partially_shared: - self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable( + self.kv_cache_info.decode_block_tables[ + decode_seq_id + ] = DecodeBlockTable( prompt_blocks, num_tokens, self.block_size, @@ -286,8 +287,10 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): ) else: if i < num_sequences: - # Need to copy the last block in 
self.kv_cache.block_tables[prompt_seq_id] - self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable( + # Need to copy the last block in self.kv_cache_info.block_tables[prompt_seq_id] + self.kv_cache_info.decode_block_tables[ + decode_seq_id + ] = DecodeBlockTable( prompt_blocks[:-1], num_tokens, self.block_size, @@ -295,16 +298,18 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int): prompt_shared, ) last_block_copy = self.free_blocks.pop() - self.kv_cache.decode_block_tables[decode_seq_id].append( + self.kv_cache_info.decode_block_tables[decode_seq_id].append( last_block_copy ) - self.kv_cache.pending_copy_from_to.extend( + self.kv_cache_info.pending_copy_from_to.extend( [prompt_blocks[-1], last_block_copy] ) else: # The last sequence can directly overwrite the last block without copying it, # since other sequences have its own copy of the last block. - self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable( + self.kv_cache_info.decode_block_tables[ + decode_seq_id + ] = DecodeBlockTable( prompt_blocks, num_tokens, self.block_size, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index e854206877..cfb528ac43 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -10,7 +10,7 @@ from tvm.runtime import disco as di from .base import ModelArtifactConfig -from .paged_cache_manager import KVCache, CacheManager +from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( sample, prepare_inputs, @@ -152,6 +152,8 @@ def __init__( "tvm.contrib.vllm.copy_blocks" ) + self.cache_blocks = None + def get_used_memory(self): if self.disco_session: params = self.params.debug_get_from_remote(0) @@ -204,7 +206,7 @@ def profile_memory_usage(self, seq_lens): def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], - cache: KVCache, + cache: KVCacheInfo, ) -> List[TextGenerationResult]: if len(requests) == 0: return [] @@ -266,7 +268,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, indices_within_window, self.params, @@ -276,7 +278,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, self.params, ) @@ -297,7 +299,7 @@ def generate( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, block_tables, self.params, @@ -324,7 +326,7 @@ def generate( "int64", ) - self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) + self.copy_cache_blocks_func(self.cache_blocks, block_mapping) cache.pending_copy_from_to = [] try: @@ -474,7 +476,7 @@ def init_tvm_model( else: init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - cache_blocks = init_cache_func( + model.cache_blocks = init_cache_func( head_size, model_artifact_config.num_hidden_layers, num_kv_heads, @@ -483,7 +485,6 @@ def init_tvm_model( ) cache_manager = CacheManager( - cache_blocks, num_blocks, model_artifact_config.sliding_window, ) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index ca689a0b60..d293bf4058 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -24,6 +24,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--artifact-path", type=str, default="dist") parser.add_argument("--use-sync-engine", action="store_true") parser.add_argument("--max-num-batched-tokens", type=int, default=4096) + 
parser.add_argument("--num-sequences-to-sample", type=int, default=1) parser.add_argument("--min-decode-steps", type=int, default=32) parser.add_argument("--max-decode-steps", type=int, default=56) parser.add_argument("--debug-logging", action="store_true") From e098d0b8b90f34992093ec499e465dfb26281095 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 08:23:16 +0900 Subject: [PATCH 36/73] update for KV refactor --- serve/mlc_serve/model/torch_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index f3fce59037..695c86dc69 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -26,7 +26,7 @@ from concurrent.futures import ThreadPoolExecutor from .base import ModelArtifactConfig -from .paged_cache_manager import KVCache, CacheManager +from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( sample, prepare_inputs, @@ -181,7 +181,7 @@ def load_model(hf_config, model_path): def generate( requests: Sequence[Union[PrefillRequest, DecodeRequest]], - cache: KVCache, + cache_info: KVCacheInfo, pt_model, cache_blocks, sliding_window, @@ -240,8 +240,8 @@ def generate( sequence_ids, all_token_ids, prompt_lens, - cache.slot_mappings, - cache.decode_block_tables, + cache_info.slot_mappings, + cache_info.decode_block_tables, sliding_window, is_prefill, torch.long, @@ -378,7 +378,7 @@ def exposed_init_model( return num_blocks def exposed_generate( - self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache + self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo ) -> List[TextGenerationResult]: requests = obtain(requests) cache = obtain(cache) @@ -471,7 +471,7 @@ def __init__( def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], - cache: KVCache, + cache: KVCacheInfo, ) -> List[TextGenerationResult]: return generate( requests, From 02b7c1ba06a30e1d6ac11beedff78fba3520f56c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 03:00:50 +0000 Subject: [PATCH 37/73] updated for the latest vllm --- serve/mlc_serve/model/model_common.py | 79 ++++++++++------ serve/mlc_serve/model/torch_model.py | 125 ++++++++++++-------------- 2 files changed, 109 insertions(+), 95 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 7b0b9c383e..9d18bfdfe2 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -169,10 +169,6 @@ def _is_safe_to_sample(prob_like): return res -def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]: - return x + [0] * ((-len(x)) % multiple_of) - - def prepare_inputs( sequence_ids, all_token_ids, @@ -181,25 +177,37 @@ def prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, - torch_ids_type=torch.int, - align=None, + for_vllm=False, ): + if for_vllm: + torch_int_dtype = torch.long + else: + torch_int_dtype = torch.int + block_tables = [] seq_lens = [] input_ids = [] slot_mapping = [] positions = [] - max_num_blocks_per_seq = 0 indices_within_window = [] start_idx = 0 + max_prompt_len = -1 + max_context_len = -1 for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)): if is_prefill: - input_ids += token_ids prompt_len = len(token_ids) seq_lens.append(prompt_len) - positions += range(prompt_len) - slot_mapping += all_slot_mappings[sequence_id] + max_prompt_len = 
max(max_prompt_len, prompt_len) + + if for_vllm: + input_ids.append(token_ids) + positions.append(list(range(prompt_len))) + slot_mapping.append(all_slot_mappings[sequence_id]) + else: + input_ids += token_ids + positions += range(prompt_len) + slot_mapping += all_slot_mappings[sequence_id] if sliding_window: indices_within_window += range( @@ -209,30 +217,50 @@ def prepare_inputs( start_idx += prompt_len else: - input_ids.append(token_ids[-1]) seq_len = prompt_lens[i] + len(token_ids) - positions.append(seq_len - 1) + + if for_vllm: + input_ids.append([token_ids[-1]]) + positions.append([seq_len - 1]) + slot_mapping.append([all_slot_mappings[sequence_id][-1]]) + else: + input_ids.append(token_ids[-1]) + positions.append(seq_len - 1) + slot_mapping.append(all_slot_mappings[sequence_id][-1]) + block_table = all_decode_block_tables[sequence_id] - max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) block_tables.append(block_table.get_blocks()) - slot_mapping.append(all_slot_mappings[sequence_id][-1]) if sliding_window: seq_lens.append(min(seq_len, sliding_window)) else: seq_lens.append(seq_len) + max_context_len = max(max_context_len, seq_lens[-1]) + def to_torch(arr, torch_dtype): return torch.tensor(arr, dtype=torch_dtype, device="cuda") - if align: - input_ids = _pad_to_alignment(input_ids, multiple_of=align) - positions = _pad_to_alignment(positions, multiple_of=align) + def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return x + [pad] * (max_len - len(x)) - input_ids = to_torch(input_ids, torch_ids_type) - positions = to_torch(positions, torch_ids_type) + def _do_pad( + x: List[List[int]], + max_len: int, + pad: int, + ) -> List[List[int]]: + return [_pad_to_max(x_i, max_len, pad) for x_i in x] + + if for_vllm and is_prefill: + input_ids = _do_pad(input_ids, max_prompt_len, 0) + positions = _do_pad(positions, max_prompt_len, 0) + slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1) + + input_ids = to_torch(input_ids, torch_int_dtype) + positions = to_torch(positions, torch_int_dtype) seq_lens = to_torch(seq_lens, torch.int) - slot_mapping = to_torch(slot_mapping, torch.int) + slot_mapping = to_torch(slot_mapping, torch_int_dtype) if is_prefill and sliding_window: indices_within_window = to_torch(indices_within_window, torch.int) @@ -240,14 +268,11 @@ def to_torch(arr, torch_dtype): indices_within_window = None if not is_prefill: + max_block_table_len = ( + max_context_len + CacheManager.block_size - 1 + ) // CacheManager.block_size - def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) - - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) - for block_table in block_tables - ] + padded_block_tables = _do_pad(block_tables, max_block_table_len, 0) block_tables = to_torch(padded_block_tables, torch.int) else: block_tables = None diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 695c86dc69..f72c69df1b 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,7 +1,6 @@ import time import os from typing import List, Union, Tuple, Sequence -from collections import defaultdict import structlog import torch @@ -11,8 +10,7 @@ from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.sequence import SequenceData -from vllm.model_executor import 
InputMetadata +from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, @@ -87,35 +85,32 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - sampling_params = SamplingParams(top_p=0.99) - - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - seq_data = {} - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] for i, seq_len in enumerate(seq_lens): - seq_groups.append(([i], sampling_params)) prompt_tokens = [0] * seq_len - seq_data[i] = SequenceData(prompt_tokens) - input_tokens.extend(prompt_tokens) - input_positions.extend(range(seq_len)) - slot_mapping.extend([0] * seq_len) + input_tokens.append(prompt_tokens) + input_positions.append(list(range(seq_len))) + slot_mapping.append([0] * seq_len) input_ids = torch.cuda.LongTensor(input_tokens) positions = torch.cuda.LongTensor(input_positions) - slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping) + slot_mapping_tensor = torch.cuda.LongTensor(slot_mapping) + prompt_lens_tensor = torch.cuda.LongTensor(seq_lens) input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=seq_lens, + is_prompt=True, slot_mapping=slot_mapping_tensor, - context_lens=torch.cuda.IntTensor([]), + prompt_lens=prompt_lens_tensor, + max_seq_len=None, + start_loc=None, max_context_len=0, + context_lens=torch.cuda.IntTensor([]), block_tables=torch.cuda.IntTensor([]), + use_cuda_graph=False, ) kv_caches = [(None, None)] * num_hidden_layers @@ -126,7 +121,6 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): positions, kv_caches, input_metadata, - cache_events=None, ) torch.cuda.synchronize() @@ -197,37 +191,28 @@ def generate( sequence_ids = [] prompt_lens = [] num_sequences = [] - seq_data = {} - seq_group_sequence_ids = defaultdict(list) - seq_group_sampling_params = {} for request in requests: if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) - seq_group_sequence_ids[request.request_id].append(sequence_ids[-1]) - seq_group_sampling_params[request.request_id] = convert_sampling_params( - request.sampling_params - ) else: sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) - req_id = request.sequence_id.request_id - seq_group_sequence_ids[req_id].append(request.sequence_id) - seq_group_sampling_params[req_id] = convert_sampling_params( - request.sampling_params - ) all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) - seq_data[sequence_ids[-1]] = SequenceData(request.token_ids) + selected_token_indices: List[int] = [] - seq_groups: List[Tuple[List[SequenceId], SamplingParams]] = [] + if is_prefill: + max_prompt_len = max(prompt_lens) + seq_start = 0 - for req_id, seq_ids in seq_group_sequence_ids.items(): - seq_groups.append((seq_ids, seq_group_sampling_params[req_id])) + for prompt_len in prompt_lens: + selected_token_indices.append(seq_start + prompt_len - 1) + seq_start += max_prompt_len ( input_ids, @@ -244,8 +229,7 @@ def generate( cache_info.decode_block_tables, sliding_window, is_prefill, - torch.long, - align=8, + 
for_vllm=True, ) input_shape = input_ids.shape @@ -261,43 +245,44 @@ def generate( max_context_len = torch.max(seq_lens) prompt_lens = [] + prompt_lens = torch.cuda.LongTensor(prompt_lens) + input_metadata = InputMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, + is_prompt=is_prefill, slot_mapping=slot_mapping, - context_lens=context_lens, + prompt_lens=prompt_lens, + max_seq_len=None, + start_loc=None, max_context_len=max_context_len, + context_lens=context_lens, block_tables=block_tables, + use_cuda_graph=False, + ) + + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=prompt_lens, + selected_token_indices=torch.tensor( + selected_token_indices, dtype=torch.long, device="cuda" + ), + categorized_sample_indices=None, ) with torch.no_grad(): - # hidden_states = pt_model.transformer( - # input_ids, - # positions, - # cache_blocks, - # input_metadata, - # # No need for this until parallel sampling is supported. - # cache_events=None, - # ) hidden_states = pt_model.model( input_ids, positions, cache_blocks, input_metadata, - # No need for this until parallel sampling is supported. - cache_events=None, ) - if hidden_states.shape[0] != len( - input_metadata.prompt_lens - ) and hidden_states.shape[0] != len(input_metadata.context_lens): - logits = get_logits( - pt_model.lm_head.weight, - hidden_states, - input_metadata, - vocab_size, - ) + logits = get_logits( + pt_model.lm_head.weight, + hidden_states, + sampling_metadata, + vocab_size, + ) next_tokens = sample(logits, sampling_params, vocab_size) @@ -378,7 +363,9 @@ def exposed_init_model( return num_blocks def exposed_generate( - self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo + self, + requests: Sequence[Union[PrefillRequest, DecodeRequest]], + cache: KVCacheInfo, ) -> List[TextGenerationResult]: requests = obtain(requests) cache = obtain(cache) @@ -429,7 +416,7 @@ def start_model_process(port): class ModelRpcClient: def __init__(self, num_shards, model_path, hf_config, engine_config): with ThreadPoolExecutor(num_shards) as executor: - ports = [3000 + i for i in range(num_shards)] # TODO port + ports = [3010 + i for i in range(num_shards)] # TODO port rets = executor.map(start_model_process, ports) self.model_servers = [x[0] for x in rets] @@ -445,7 +432,7 @@ def init_model(i): self.num_blocks = rets[0] def _func( - requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCache + requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo ) -> List[TextGenerationResult]: def generate(i): return self.model_servers[i].generate(requests, cache) @@ -489,9 +476,11 @@ def init_torch_model( hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) # TODO - num_shards = 2 + num_shards = 1 - if not hasattr(hf_config, "num_key_value_heads") and hasattr(hf_config, "num_attention_heads"): + if not hasattr(hf_config, "num_key_value_heads") and hasattr( + hf_config, "num_attention_heads" + ): hf_config.num_key_value_heads = hf_config.num_attention_heads if not hasattr(hf_config, "sliding_window"): @@ -521,7 +510,7 @@ def init_torch_model( backend="nccl", world_size=1, rank=0, - init_method="tcp://localhost:59157", # port + init_method="tcp://localhost:59157", # TODO port ) initialize_model_parallel(1, 1) From c470c36aa3b03a934c0fcf528e3bf3e9707b815d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 20:45:16 +0000 Subject: [PATCH 38/73] qwen and phi supported --- 
serve/mlc_serve/model/torch_model.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index f72c69df1b..008aadd4a9 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import get_logits from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel +from vllm.model_executor.models.phi import PhiForCausalLM from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams from vllm.model_executor.parallel_utils.parallel_state import ( @@ -165,10 +166,20 @@ def profile_and_init_cache( def load_model(hf_config, model_path): + model_map = { + "LlamaForCausalLM": LlamaForCausalLM, + "PhiForCausalLM": PhiForCausalLM, + "QWenLMHeadModel": QWenLMHeadModel, # requires tiktoken package + } + + arch = hf_config.architectures[0] + + if arch not in model_map: + raise RuntimeError(f"Unsupported model: {arch}") + with torch.device("cuda"): torch.set_default_dtype(torch.float16) - model = LlamaForCausalLM(hf_config) - # model = QWenLMHeadModel(hf_config) # requires tiktoken package + model = model_map[arch](hf_config) model.load_weights(model_path, None, "auto", None) return model From a4612da0d4e1ab3ef00ea1a440233d2b096f32f9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 21:27:12 +0000 Subject: [PATCH 39/73] Make num_shards configuable via engine config --- serve/mlc_serve/engine/base.py | 1 + serve/mlc_serve/model/torch_model.py | 53 ++++++++++++++-------------- serve/mlc_serve/utils.py | 1 + serve/tests/test_engine.py | 3 ++ 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 504f131e54..c0a4801cf9 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -24,6 +24,7 @@ class MLCServeEngineConfig: max_decode_steps: int = 48 init_timeout: int = 120 model_type: str = "tvm" # "tvm", "torch" + num_shards: Optional[int] = None # Need to be specified for if model_type is "torch" @classmethod def _from_json(config_cls, json_obj: Dict[Any, Any]): diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 008aadd4a9..89756b0b89 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,6 +1,7 @@ import time import os from typing import List, Union, Tuple, Sequence +from pathlib import Path import structlog import torch @@ -12,7 +13,6 @@ from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.model_executor.models.phi import PhiForCausalLM from vllm.model_executor import InputMetadata, SamplingMetadata -from vllm.sampling_params import SamplingParams from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, ) @@ -37,7 +37,6 @@ PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, - SamplingParams as MLCSamplingParams, ) from ..engine.model_module import ( DecodeRequest, @@ -49,16 +48,6 @@ LOG = structlog.stdlib.get_logger(__name__) -def convert_sampling_params(mlc_params: MLCSamplingParams) -> SamplingParams: - return SamplingParams( - presence_penalty=mlc_params.presence_penalty, - frequency_penalty=mlc_params.frequency_penalty, - temperature=mlc_params.temperature, - top_p=mlc_params.top_p, - top_k=mlc_params.top_k, - ) - - def 
init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_blocks): element_size = 2 x = 16 // element_size @@ -90,7 +79,7 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] - for i, seq_len in enumerate(seq_lens): + for seq_len in seq_lens: prompt_tokens = [0] * seq_len input_tokens.append(prompt_tokens) @@ -329,8 +318,13 @@ def generate( class ModelRpcServer(rpyc.Service): def exposed_init_model( - self, tp_rank, num_shards, model_path, hf_config, engine_config - ): + self, + tp_rank: int, + num_shards: int, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + ) -> int: hf_config = obtain(hf_config) engine_config = obtain(engine_config) model_path = obtain(model_path) @@ -425,7 +419,13 @@ def start_model_process(port): class ModelRpcClient: - def __init__(self, num_shards, model_path, hf_config, engine_config): + def __init__( + self, + num_shards: int, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + ): with ThreadPoolExecutor(num_shards) as executor: ports = [3010 + i for i in range(num_shards)] # TODO port rets = executor.map(start_model_process, ports) @@ -438,9 +438,8 @@ def init_model(i): i, num_shards, model_path, hf_config, engine_config ) - rets = [obtain(x) for x in executor.map(init_model, range(num_shards))] - - self.num_blocks = rets[0] + rets = executor.map(init_model, range(num_shards)) + self.num_blocks = obtain(list(rets)[0]) def _func( requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo @@ -459,7 +458,7 @@ class Model: def __init__( self, pt_model, - config, + config: AutoConfig, ): self.pt_model = pt_model self.vocab_size = config.vocab_size @@ -482,12 +481,14 @@ def generate( def init_torch_model( - model_path, engine_config: MLCServeEngineConfig + model_path: Path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # TODO - num_shards = 1 + if engine_config.num_shards is None: + raise RuntimeError("num_shards needs to be specifed for PyTorch models.") + + num_shards = engine_config.num_shards if not hasattr(hf_config, "num_key_value_heads") and hasattr( hf_config, "num_attention_heads" @@ -500,8 +501,8 @@ def init_torch_model( num_kv_heads = hf_config.num_key_value_heads // num_shards artifact_config = ModelArtifactConfig( - model_artifact_path=model_path, - num_shards=1, + model_artifact_path=str(model_path), + num_shards=num_shards, quantization=None, max_context_length=hf_config.max_position_embeddings, # TODO, vocab_size=hf_config.vocab_size, @@ -528,7 +529,7 @@ def init_torch_model( pt_model = load_model(hf_config, model_path) model = Model(pt_model, hf_config) - model.cache_blocks, num_blocks = profile_and_init_cache( + model.cache_blocks, num_blocks = profile_and_init_cache( # type: ignore pt_model, hf_config.num_hidden_layers, hf_config, engine_config, num_shards ) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index d293bf4058..d8c37ea521 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -29,6 +29,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--max-decode-steps", type=int, default=56) parser.add_argument("--debug-logging", action="store_true") parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-shards", type=int, 
default=1) # Needed for PT models return parser diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 5582d6c815..c5473950e3 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -19,9 +19,11 @@ def _test(args: argparse.Namespace): model_type = "tvm" + num_shards = None if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): model_type = "torch" + num_shards = args.num_shards engine_config = get_engine_config( { @@ -30,6 +32,7 @@ def _test(args: argparse.Namespace): "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, "model_type": model_type, + "num_shards": num_shards, } ) From 4564bd07306281a351618130e57a2db067884e8a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 21:40:08 +0000 Subject: [PATCH 40/73] unify Model and ModelRpcClient classes --- serve/mlc_serve/model/torch_model.py | 137 +++++++++++++++------------ 1 file changed, 74 insertions(+), 63 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 89756b0b89..394ad3ebbe 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -123,14 +123,18 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): def profile_and_init_cache( - pt_model, num_hidden_layers, hf_config, engine_config, num_shards + pt_model, + hf_config, + num_shards, + max_num_batched_tokens, ): num_kv_heads = hf_config.num_key_value_heads // num_shards + num_hidden_layers = hf_config.num_hidden_layers head_size = hf_config.hidden_size // hf_config.num_attention_heads - if engine_config.max_num_batched_tokens > 0: + if max_num_batched_tokens > 0: LOG.info("Running memory profiling.") - seq_lens = [1] * engine_config.max_num_batched_tokens + seq_lens = [1] * max_num_batched_tokens used_memory_bytes = profile_memory_usage(pt_model, seq_lens, num_hidden_layers) num_blocks = get_num_cache_blocks( used_memory_bytes, @@ -151,6 +155,8 @@ def profile_and_init_cache( num_blocks, ) + LOG.info("Allocated KV cache blocks.") + return cache_blocks, num_blocks @@ -359,10 +365,9 @@ def exposed_init_model( self.cache_blocks, num_blocks = profile_and_init_cache( self.pt_model, - hf_config.num_hidden_layers, hf_config, - engine_config, num_shards, + engine_config.max_num_batched_tokens, ) return num_blocks @@ -421,13 +426,16 @@ def start_model_process(port): class ModelRpcClient: def __init__( self, - num_shards: int, model_path: Path, hf_config: AutoConfig, engine_config: MLCServeEngineConfig, ): - with ThreadPoolExecutor(num_shards) as executor: - ports = [3010 + i for i in range(num_shards)] # TODO port + assert engine_config.num_shards is not None + + self.num_shards = engine_config.num_shards + + with ThreadPoolExecutor(self.num_shards) as executor: + ports = [3010 + i for i in range(self.num_shards)] # TODO port rets = executor.map(start_model_process, ports) self.model_servers = [x[0] for x in rets] @@ -435,49 +443,74 @@ def __init__( def init_model(i): return self.model_servers[i].init_model( - i, num_shards, model_path, hf_config, engine_config + i, self.num_shards, model_path, hf_config, engine_config ) - rets = executor.map(init_model, range(num_shards)) + rets = executor.map(init_model, range(self.num_shards)) self.num_blocks = obtain(list(rets)[0]) - def _func( - requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo - ) -> List[TextGenerationResult]: - def generate(i): - return self.model_servers[i].generate(requests, cache) - - 
with ThreadPoolExecutor(num_shards) as executor: - res = [obtain(x) for x in executor.map(generate, range(num_shards))] - return obtain(res[0]) + def generate( + self, + requests: Sequence[Union[PrefillRequest, DecodeRequest]], + cache: KVCacheInfo, + ) -> List[TextGenerationResult]: + def _generate(i): + return self.model_servers[i].generate(requests, cache) - self.generate = _func + with ThreadPoolExecutor(self.num_shards) as executor: + res = [obtain(x) for x in executor.map(_generate, range(self.num_shards))] + return obtain(res[0]) class Model: def __init__( self, - pt_model, - config: AutoConfig, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, ): - self.pt_model = pt_model - self.vocab_size = config.vocab_size - self.sliding_window = config.sliding_window - self.cache_blocks = None + if engine_config.num_shards and engine_config.num_shards > 1: + torch.multiprocessing.set_start_method("spawn") + self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config) + self.num_blocks = self.model_rpc.num_blocks + self.cache_blocks = None + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method="tcp://localhost:59157", # TODO port + ) + initialize_model_parallel(1, 1) + + self.pt_model = load_model(hf_config, model_path) + self.cache_blocks, self.num_blocks = profile_and_init_cache( + self.pt_model, + hf_config, + 1, + engine_config.max_num_batched_tokens, + ) + self.model_rpc = None + + self.vocab_size = hf_config.vocab_size + self.sliding_window = hf_config.sliding_window def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo, ) -> List[TextGenerationResult]: - return generate( - requests, - cache, - self.pt_model, - self.cache_blocks, - self.sliding_window, - self.vocab_size, - ) + if self.model_rpc is None: + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, + self.sliding_window, + self.vocab_size, + ) + + return self.model_rpc.generate(requests, cache) def init_torch_model( @@ -485,11 +518,6 @@ def init_torch_model( ) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - if engine_config.num_shards is None: - raise RuntimeError("num_shards needs to be specifed for PyTorch models.") - - num_shards = engine_config.num_shards - if not hasattr(hf_config, "num_key_value_heads") and hasattr( hf_config, "num_attention_heads" ): @@ -498,7 +526,10 @@ def init_torch_model( if not hasattr(hf_config, "sliding_window"): hf_config.sliding_window = None - num_kv_heads = hf_config.num_key_value_heads // num_shards + if engine_config.num_shards is None: + raise RuntimeError("num_shards needs to be specifed for PyTorch models.") + + num_shards = engine_config.num_shards artifact_config = ModelArtifactConfig( model_artifact_path=str(model_path), @@ -507,36 +538,16 @@ def init_torch_model( max_context_length=hf_config.max_position_embeddings, # TODO, vocab_size=hf_config.vocab_size, sliding_window=hf_config.sliding_window, - num_key_value_heads=num_kv_heads, + num_key_value_heads=hf_config.num_key_value_heads // num_shards, num_attention_heads=hf_config.num_attention_heads, num_hidden_layers=hf_config.num_hidden_layers, hidden_size=hf_config.hidden_size, ) - if num_shards > 1: - torch.multiprocessing.set_start_method("spawn") - model = ModelRpcClient(num_shards, model_path, hf_config, engine_config) - num_blocks = model.num_blocks - else: - 
torch.distributed.init_process_group( - backend="nccl", - world_size=1, - rank=0, - init_method="tcp://localhost:59157", # TODO port - ) - initialize_model_parallel(1, 1) - - pt_model = load_model(hf_config, model_path) - model = Model(pt_model, hf_config) - - model.cache_blocks, num_blocks = profile_and_init_cache( # type: ignore - pt_model, hf_config.num_hidden_layers, hf_config, engine_config, num_shards - ) - - LOG.info("Allocated KV cache blocks.") + model = Model(model_path, hf_config, engine_config) cache_manager = CacheManager( - num_blocks, + model.num_blocks, hf_config.sliding_window, ) From 196026ce84b4c00cbca9eb0ac1e71b471023685d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 20 Jan 2024 22:15:55 +0000 Subject: [PATCH 41/73] support PT model in server --- serve/mlc_serve/model/model_common.py | 18 +++++----- serve/mlc_serve/model/torch_model.py | 4 +-- serve/mlc_serve/run.py | 49 ++++++++++++--------------- serve/tests/test_engine.py | 4 +++ 4 files changed, 36 insertions(+), 39 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9d18bfdfe2..b188a51b06 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -238,25 +238,25 @@ def prepare_inputs( max_context_len = max(max_context_len, seq_lens[-1]) - def to_torch(arr, torch_dtype): - return torch.tensor(arr, dtype=torch_dtype, device="cuda") - - def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: - assert len(x) <= max_len - return x + [pad] * (max_len - len(x)) - def _do_pad( x: List[List[int]], max_len: int, - pad: int, + pad_val: int, ) -> List[List[int]]: - return [_pad_to_max(x_i, max_len, pad) for x_i in x] + def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]: + assert len(x) <= max_len + return x + [pad_val] * (max_len - len(x)) + + return [_pad_to_max(x_i, max_len, pad_val) for x_i in x] if for_vllm and is_prefill: input_ids = _do_pad(input_ids, max_prompt_len, 0) positions = _do_pad(positions, max_prompt_len, 0) slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1) + def to_torch(arr, torch_dtype): + return torch.tensor(arr, dtype=torch_dtype, device="cuda") + input_ids = to_torch(input_ids, torch_int_dtype) positions = to_torch(positions, torch_int_dtype) seq_lens = to_torch(seq_lens, torch.int) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 394ad3ebbe..2dc9db395c 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -455,6 +455,7 @@ def generate( cache: KVCacheInfo, ) -> List[TextGenerationResult]: def _generate(i): + # This calls ModelRpcServer.exposed_generate(...) via RPC. 
return self.model_servers[i].generate(requests, cache) with ThreadPoolExecutor(self.num_shards) as executor: @@ -470,10 +471,9 @@ def __init__( engine_config: MLCServeEngineConfig, ): if engine_config.num_shards and engine_config.num_shards > 1: - torch.multiprocessing.set_start_method("spawn") self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config) self.num_blocks = self.model_rpc.num_blocks - self.cache_blocks = None + self.cache_blocks = None # Owned by each remote shard else: torch.distributed.init_process_group( backend="nccl", diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index 78a25f5e10..102e68a4f1 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -2,7 +2,6 @@ import tempfile import os import uvicorn -from pathlib import Path from .api import create_app from .engine import AsyncEngineConnector, get_engine_config @@ -10,34 +9,15 @@ from .engine.sync_engine import SynchronousInferenceEngine from .model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule from .logging_utils import configure_logging +from .utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args def parse_args(): - # Example - # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --batched - # python tests/python/test_batched.py --local-id vicuna-v1-7b-q4f16_ft - # - # For Disco: - # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --batched --build-model-only --num-shards 2 - # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --batched --convert-weight-only - # /opt/bin/cuda-reserve.py --num-gpus 2 python -m mlc_serve --local-id vicuna-v1-7b-q0f16 --num-shards 2 - # - # Profile the gpu memory usage, and use the maximum number of cache blocks possible: - # /opt/bin/cuda-reserve.py --num-gpus 2 python -m mlc_serve --local-id vicuna-v1-7b-q0f16 --num-shards 2 --max-num-batched-tokens 2560 - - # TODO(@sunggg): replace this with `utils.get_default_mlc_serve_argparser` - # Since this will require the change in ollm side as well, revisit this after octocalm. - args = argparse.ArgumentParser() + args = get_default_mlc_serve_argparser("MLC serve") args.add_argument("--host", type=str, default="127.0.0.1") args.add_argument("--port", type=int, default=8000) - args.add_argument("--local-id", type=str, required=True) - args.add_argument("--artifact-path", type=str, default="dist") - args.add_argument("--use-staging-engine", action="store_true") - args.add_argument("--max-num-batched-tokens", type=int, default=4096) - args.add_argument("--min-decode-steps", type=int, default=12) - args.add_argument("--max-decode-steps", type=int, default=16) - args.add_argument("--debug-logging", action="store_true") parsed = args.parse_args() + postproc_mlc_serve_args(parsed) return parsed @@ -52,10 +32,16 @@ def create_engine( | `ndarray-cache.json` is especially important for Disco. 
|- model/ : stores info from hf model cards such as max context length and tokenizer """ - model_artifact_path = Path(os.path.join(args.artifact_path, args.local_id)) - if not os.path.exists(model_artifact_path): + if not os.path.exists(args.model_artifact_path): raise Exception(f"Invalid local id: {args.local_id}") + model_type = "tvm" + num_shards = None + + if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): + model_type = "torch" + num_shards = args.num_shards + # Set the engine config engine_config = get_engine_config( { @@ -63,23 +49,30 @@ def create_engine( "max_num_batched_tokens": args.max_num_batched_tokens, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, + "model_type": model_type, + "num_shards": num_shards, } ) # TODO(yelite, masahi): Protocol subtyping is not working if args.use_staging_engine: + if model_type == "tvm": + tokenizer_path = args.model_artifact_path.joinpath("model") + else: + tokenizer_path = args.model_artifact_path + return StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(model_artifact_path), + tokenizer_module=HfTokenizerModule(tokenizer_path), model_module_loader=PagedCacheModelModule, # type: ignore model_module_loader_kwargs={ - "model_artifact_path": model_artifact_path, + "model_artifact_path": args.model_artifact_path, "engine_config": engine_config, }, ) else: return SynchronousInferenceEngine( PagedCacheModelModule( - model_artifact_path=model_artifact_path, + model_artifact_path=args.model_artifact_path, engine_config=engine_config, ) # type: ignore ) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index c5473950e3..5fc51cd82a 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -25,6 +25,10 @@ def _test(args: argparse.Namespace): model_type = "torch" num_shards = args.num_shards + if num_shards > 1: + import torch + torch.multiprocessing.set_start_method("spawn") + engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, From b5026546053029eb11f7d25c9495e4a8bd1b76e7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 21 Jan 2024 11:42:36 +0000 Subject: [PATCH 42/73] properly allocate port --- serve/mlc_serve/model/torch_model.py | 46 +++++++++++++++++++++++----- serve/mlc_serve/utils.py | 1 - 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 2dc9db395c..4bc3fb23b4 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,6 +1,7 @@ import time import os -from typing import List, Union, Tuple, Sequence +import socket +from typing import List, Union, Sequence, Tuple from pathlib import Path import structlog @@ -330,6 +331,7 @@ def exposed_init_model( model_path: Path, hf_config: AutoConfig, engine_config: MLCServeEngineConfig, + master_port: int, ) -> int: hf_config = obtain(hf_config) engine_config = obtain(engine_config) @@ -349,7 +351,7 @@ def exposed_init_model( torch.cuda.set_device(tp_rank) os.environ["MASTER_ADDR"] = str("127.0.0.1") - os.environ["MASTER_PORT"] = str(4000) # TODO port + os.environ["MASTER_PORT"] = str(master_port) torch.distributed.init_process_group( backend="nccl", @@ -429,21 +431,29 @@ def __init__( model_path: Path, hf_config: AutoConfig, engine_config: MLCServeEngineConfig, + ports: List[int], ): assert engine_config.num_shards is not None self.num_shards = engine_config.num_shards + rpc_ports = ports[: self.num_shards] + master_port = 
ports[-1] + with ThreadPoolExecutor(self.num_shards) as executor: - ports = [3010 + i for i in range(self.num_shards)] # TODO port - rets = executor.map(start_model_process, ports) + rets = executor.map(start_model_process, rpc_ports) self.model_servers = [x[0] for x in rets] self.procs = [x[1] for x in rets] def init_model(i): return self.model_servers[i].init_model( - i, self.num_shards, model_path, hf_config, engine_config + i, + self.num_shards, + model_path, + hf_config, + engine_config, + master_port, ) rets = executor.map(init_model, range(self.num_shards)) @@ -463,6 +473,23 @@ def _generate(i): return obtain(res[0]) +# Taken from sgl-project/sglang +def alloc_usable_network_port(num): + port_list = [] + for port in range(10000, 65536): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + port_list.append(port) + except socket.error: + pass + + if len(port_list) == num: + break + + return port_list + + class Model: def __init__( self, @@ -471,15 +498,20 @@ def __init__( engine_config: MLCServeEngineConfig, ): if engine_config.num_shards and engine_config.num_shards > 1: - self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config) + num_needed_ports = engine_config.num_shards + 1 + ports = alloc_usable_network_port(num_needed_ports) + assert len(ports) == num_needed_ports, "Not enough ports available." + self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config, ports) self.num_blocks = self.model_rpc.num_blocks self.cache_blocks = None # Owned by each remote shard else: + ports = alloc_usable_network_port(1) + assert len(ports) == 1 torch.distributed.init_process_group( backend="nccl", world_size=1, rank=0, - init_method="tcp://localhost:59157", # TODO port + init_method=f"tcp://localhost:{ports[0]}", ) initialize_model_parallel(1, 1) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index 4dee6f96d6..d8c37ea521 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -23,7 +23,6 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--local-id", type=str, required=True) parser.add_argument("--artifact-path", type=str, default="dist") parser.add_argument("--use-sync-engine", action="store_true") - parser.add_argument("--num-sequences-to-sample", type=int, default=1) parser.add_argument("--max-num-batched-tokens", type=int, default=4096) parser.add_argument("--num-sequences-to-sample", type=int, default=1) parser.add_argument("--min-decode-steps", type=int, default=32) From 686780c2669fa0a2a11ca91fb5f9e29b012251ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Jan 2024 00:25:28 +0000 Subject: [PATCH 43/73] refactor engine creation --- serve/mlc_serve/utils.py | 28 +++++++++++++++++++---- serve/tests/test_engine.py | 47 ++------------------------------------ 2 files changed, 25 insertions(+), 50 deletions(-) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index d8c37ea521..c2024ffe36 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -48,19 +48,37 @@ def postproc_mlc_serve_args(args): def create_mlc_engine(args: argparse.Namespace): + model_type = "tvm" + num_shards = None + + if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): + model_type = "torch" + num_shards = args.num_shards + + if num_shards > 1: + import torch + torch.multiprocessing.set_start_method("spawn") + engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, 
"max_num_batched_tokens": args.max_num_batched_tokens, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, + "model_type": model_type, + "num_shards": num_shards, } ) # TODO(@team): There is a type mismatch in the definition. Let's fix this when have time. if args.use_staging_engine: - engine = StagingInferenceEngine( # type: ignore - tokenizer_module=HfTokenizerModule(args.model_artifact_path), + if model_type == "tvm": + tokenizer_path = args.model_artifact_path.joinpath("model") + else: + tokenizer_path = args.model_artifact_path + + engine = StagingInferenceEngine( + tokenizer_module=HfTokenizerModule(tokenizer_path), model_module_loader=PagedCacheModelModule, # type: ignore model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, @@ -69,10 +87,10 @@ def create_mlc_engine(args: argparse.Namespace): ) engine.start() else: - engine = SynchronousInferenceEngine( # type: ignore - PagedCacheModelModule( # type: ignore + engine = SynchronousInferenceEngine( + PagedCacheModelModule( model_artifact_path=args.model_artifact_path, engine_config=engine_config, - ) + ) # type: ignore ) return engine diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 5fc51cd82a..8d419395fd 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -14,54 +14,11 @@ from mlc_serve.engine.staging_engine import StagingInferenceEngine from mlc_serve.engine.sync_engine import SynchronousInferenceEngine from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule -from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine def _test(args: argparse.Namespace): - model_type = "tvm" - num_shards = None - - if not os.path.exists(args.model_artifact_path.joinpath("build_config.json")): - model_type = "torch" - num_shards = args.num_shards - - if num_shards > 1: - import torch - torch.multiprocessing.set_start_method("spawn") - - engine_config = get_engine_config( - { - "use_staging_engine": args.use_staging_engine, - "max_num_batched_tokens": args.max_num_batched_tokens, - "min_decode_steps": args.min_decode_steps, - "max_decode_steps": args.max_decode_steps, - "model_type": model_type, - "num_shards": num_shards, - } - ) - - if args.use_staging_engine: - if model_type == "tvm": - tokenizer_path = args.model_artifact_path.joinpath("model") - else: - tokenizer_path = args.model_artifact_path - - engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(tokenizer_path), - model_module_loader=PagedCacheModelModule, - model_module_loader_kwargs={ - "model_artifact_path": args.model_artifact_path, - "engine_config": engine_config, - }, - ) - engine.start() - else: - engine = SynchronousInferenceEngine( - PagedCacheModelModule( - model_artifact_path=args.model_artifact_path, - engine_config=engine_config, - ) - ) + engine = create_mlc_engine(args) sampling_params_greedy = SamplingParams( temperature=0.0, From bebd7b25e87dc466850fc7d920ac07b78b2fb3b3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Jan 2024 07:47:05 +0000 Subject: [PATCH 44/73] fix sync point --- serve/mlc_serve/model/torch_model.py | 4 ++-- serve/mlc_serve/utils.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 4bc3fb23b4..811a58ff36 100644 --- a/serve/mlc_serve/model/torch_model.py +++ 
b/serve/mlc_serve/model/torch_model.py @@ -291,11 +291,11 @@ def generate( vocab_size, ) - next_tokens = sample(logits, sampling_params, vocab_size) - torch.cuda.synchronize() torch.cuda.nvtx.range_pop() + next_tokens = sample(logits, sampling_params, vocab_size) + outputs = [] for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index c2024ffe36..69e2b9754b 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -55,6 +55,8 @@ def create_mlc_engine(args: argparse.Namespace): model_type = "torch" num_shards = args.num_shards + assert num_shards is not None, "--num-shards needs to be provided for PT models." + if num_shards > 1: import torch torch.multiprocessing.set_start_method("spawn") From ed46b5eb89ea8a9cfd3da042e4f82209f26e4c05 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Jan 2024 19:54:07 +0000 Subject: [PATCH 45/73] do not create executor at each step --- serve/mlc_serve/model/torch_model.py | 38 +++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 811a58ff36..b8f75e03cd 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -439,25 +439,28 @@ def __init__( rpc_ports = ports[: self.num_shards] master_port = ports[-1] + self.executor = ThreadPoolExecutor(self.num_shards) - with ThreadPoolExecutor(self.num_shards) as executor: - rets = executor.map(start_model_process, rpc_ports) + rets = self.executor.map(start_model_process, rpc_ports) - self.model_servers = [x[0] for x in rets] - self.procs = [x[1] for x in rets] + self.model_servers = [x[0] for x in rets] + self.procs = [x[1] for x in rets] - def init_model(i): - return self.model_servers[i].init_model( - i, - self.num_shards, - model_path, - hf_config, - engine_config, - master_port, - ) + def init_model(i): + return self.model_servers[i].init_model( + i, + self.num_shards, + model_path, + hf_config, + engine_config, + master_port, + ) + + rets = self.executor.map(init_model, range(self.num_shards)) + self.num_blocks = obtain(list(rets)[0]) - rets = executor.map(init_model, range(self.num_shards)) - self.num_blocks = obtain(list(rets)[0]) + def __del__(self): + self.executor.shutdown() def generate( self, @@ -468,9 +471,8 @@ def _generate(i): # This calls ModelRpcServer.exposed_generate(...) via RPC. 
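The change above keeps a single ThreadPoolExecutor alive for the lifetime of ModelRpcClient instead of constructing a fresh pool on every generation step, so the per-step thread start-up cost disappears. A minimal sketch of that reuse pattern, with a hypothetical worker list standing in for the rpyc proxies (none of these names are engine APIs):

    from concurrent.futures import ThreadPoolExecutor

    class ShardFanout:
        """Illustrative only: fan one call out to N shard workers via a reused executor."""

        def __init__(self, workers):
            self.workers = workers
            # Created once; building a pool on every step would pay thread start-up cost each call.
            self.executor = ThreadPoolExecutor(len(workers))

        def call_all(self, method_name, *args):
            def _call(worker):
                return getattr(worker, method_name)(*args)

            # map() blocks until every shard returns; shard 0's result is treated as authoritative.
            return list(self.executor.map(_call, self.workers))[0]

        def shutdown(self):
            self.executor.shutdown()
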
return self.model_servers[i].generate(requests, cache) - with ThreadPoolExecutor(self.num_shards) as executor: - res = [obtain(x) for x in executor.map(_generate, range(self.num_shards))] - return obtain(res[0]) + res = [obtain(x) for x in self.executor.map(_generate, range(self.num_shards))] + return obtain(res[0]) # Taken from sgl-project/sglang From 62918ddf97194316e09dd3949ebc0af80d7573e3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Jan 2024 22:25:44 +0000 Subject: [PATCH 46/73] remove dup obtain calls --- serve/mlc_serve/model/torch_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index b8f75e03cd..fa86560b62 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -221,6 +221,8 @@ def generate( selected_token_indices.append(seq_start + prompt_len - 1) seq_start += max_prompt_len + torch.cuda.nvtx.range_push(f"Prepare input") + ( input_ids, positions, @@ -239,6 +241,8 @@ def generate( for_vllm=True, ) + torch.cuda.nvtx.range_pop() + input_shape = input_ids.shape if block_tables is None: @@ -379,8 +383,10 @@ def exposed_generate( requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo, ) -> List[TextGenerationResult]: + torch.cuda.nvtx.range_push(f"Obtain input") requests = obtain(requests) cache = obtain(cache) + torch.cuda.nvtx.range_pop() return generate( requests, cache, @@ -472,7 +478,7 @@ def _generate(i): return self.model_servers[i].generate(requests, cache) res = [obtain(x) for x in self.executor.map(_generate, range(self.num_shards))] - return obtain(res[0]) + return res[0] # Taken from sgl-project/sglang From e1445173fd2b70b5e7ef630db0602031e1a70c15 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Feb 2024 18:12:46 +0000 Subject: [PATCH 47/73] fix --- serve/mlc_serve/model/torch_model.py | 5 ++++- serve/mlc_serve/model/tvm_model.py | 3 +-- serve/mlc_serve/utils.py | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index fa86560b62..a0965b63bd 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -31,6 +31,7 @@ sample, prepare_inputs, get_num_cache_blocks, + get_logprob_infos, ) from ..engine import ( @@ -298,7 +299,7 @@ def generate( torch.cuda.synchronize() torch.cuda.nvtx.range_pop() - next_tokens = sample(logits, sampling_params, vocab_size) + next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) outputs = [] @@ -313,6 +314,7 @@ def generate( sequence_id=SequenceId(sequence_id.request_id, seq_id), generated_tokens=[new_token], error=None, + logprob_info=get_logprob_infos(i, logprob_infos), ) ) else: @@ -321,6 +323,7 @@ def generate( sequence_id=sequence_id, generated_tokens=[new_token], error=None, + logprob_info=get_logprob_infos(i, logprob_infos), ) ) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 6112a77191..fb5315cd2c 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Optional, Union, Tuple, Sequence +from typing import List, Union, Tuple, Sequence import structlog import numpy as np @@ -20,7 +20,6 @@ from ..engine import ( PROMPT_SEQEUNCE_INDEX, - RawLogprobsInfos, SequenceId, get_prompt_sequence_id, MLCServeEngineConfig, diff --git a/serve/mlc_serve/utils.py 
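The sync-point fix above (sampling moved after torch.cuda.synchronize() and the NVTX pop) is easiest to verify in an Nsight Systems trace. A small, self-contained sketch of the NVTX annotation pattern these patches rely on; the phase functions here are placeholders, not engine APIs:

    import torch

    def annotated_step(prepare_fn, forward_fn, sample_fn):
        torch.cuda.nvtx.range_push("prepare_inputs")
        inputs = prepare_fn()
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push("model_forward")
        logits = forward_fn(inputs)
        torch.cuda.synchronize()  # close the range only after the GPU work has actually finished
        torch.cuda.nvtx.range_pop()

        # Sampling runs outside the forward range, mirroring the reordering above.
        return sample_fn(logits)
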
b/serve/mlc_serve/utils.py index 79d7b1e629..834a03cda6 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -23,6 +23,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--local-id", type=str, required=True) parser.add_argument("--artifact-path", type=str, default="dist") parser.add_argument("--use-sync-engine", action="store_true") + parser.add_argument("--num-sequences-to-sample", type=int, default=1) parser.add_argument("--max-num-batched-tokens", type=int, default=4096) parser.add_argument("--min-decode-steps", type=int, default=32) parser.add_argument("--max-decode-steps", type=int, default=56) From dfbf35959acbd83c877d4f65580bb33dd287d272 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Feb 2024 11:01:23 +0000 Subject: [PATCH 48/73] use sample_from_logits --- serve/mlc_serve/model/torch_model.py | 34 ++-------------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index a0965b63bd..c45ff39c24 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -28,10 +28,10 @@ from .base import ModelArtifactConfig from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( - sample, prepare_inputs, get_num_cache_blocks, get_logprob_infos, + sample_from_logits, ) from ..engine import ( @@ -195,7 +195,6 @@ def generate( is_prefill = isinstance(requests[0], PrefillRequest) all_token_ids = [] - sampling_params = [] sequence_ids = [] prompt_lens = [] num_sequences = [] @@ -210,7 +209,6 @@ def generate( prompt_lens.append(request.prompt_token_counts) all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) selected_token_indices: List[int] = [] @@ -299,35 +297,7 @@ def generate( torch.cuda.synchronize() torch.cuda.nvtx.range_pop() - next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) - - outputs = [] - - for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) - - return outputs + return sample_from_logits(logits, sequence_ids, requests, vocab_size) class ModelRpcServer(rpyc.Service): From 04da3bb647445819e364b4a20db2588c68395b68 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 08:23:19 +0000 Subject: [PATCH 49/73] enable TCP NoDelay option to fix slow socket recv issue --- serve/mlc_serve/model/torch_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index c45ff39c24..dbb555fedf 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -23,6 +23,8 @@ import rpyc from rpyc.utils.classic import obtain from rpyc.utils.server import 
ThreadedServer +from rpyc.core.service import VoidService +from rpyc.core.stream import SocketStream from concurrent.futures import ThreadPoolExecutor from .base import ModelArtifactConfig @@ -379,6 +381,11 @@ def _init_service(port): t.start() +def connect_rpyc(host, port, config={}): + s = SocketStream.connect(host, port, ipv6=False, keepalive=False, nodelay=True) + return rpyc.connect_stream(s, VoidService, config) + + def start_model_process(port): proc = multiprocessing.Process(target=_init_service, args=(port,)) proc.start() @@ -388,7 +395,7 @@ def start_model_process(port): repeat_count = 0 while repeat_count < 20: try: - con = rpyc.connect( + con = connect_rpyc( "localhost", port, config={"allow_pickle": True, "sync_request_timeout": 600}, From 5dfecb2b34e3e17905a52044474468ba436f2923 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 10:08:41 +0000 Subject: [PATCH 50/73] Replace TCP with Unix domain socket --- serve/mlc_serve/model/torch_model.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index dbb555fedf..010c28e56b 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -1,5 +1,6 @@ import time import os +import tempfile import socket from typing import List, Union, Sequence, Tuple from pathlib import Path @@ -25,6 +26,8 @@ from rpyc.utils.server import ThreadedServer from rpyc.core.service import VoidService from rpyc.core.stream import SocketStream +from rpyc.utils.factory import unix_connect + from concurrent.futures import ThreadPoolExecutor from .base import ModelArtifactConfig @@ -372,10 +375,10 @@ def exposed_generate( ) -def _init_service(port): +def _init_service(socket_path): t = ThreadedServer( ModelRpcServer(), - port=port, + socket_path=socket_path, protocol_config={"allow_pickle": True, "sync_request_timeout": 600}, ) t.start() @@ -386,24 +389,24 @@ def connect_rpyc(host, port, config={}): return rpyc.connect_stream(s, VoidService, config) -def start_model_process(port): - proc = multiprocessing.Process(target=_init_service, args=(port,)) +def start_model_process(socket_path): + proc = multiprocessing.Process(target=_init_service, args=(socket_path,)) proc.start() time.sleep(1) repeat_count = 0 + while repeat_count < 20: try: - con = connect_rpyc( - "localhost", - port, - config={"allow_pickle": True, "sync_request_timeout": 600}, + con = unix_connect( + socket_path, config={"allow_pickle": True, "sync_request_timeout": 600} ) break - except ConnectionRefusedError: + except FileNotFoundError: time.sleep(1) repeat_count += 1 + if repeat_count == 20: raise RuntimeError("init rpc env error!") @@ -423,11 +426,11 @@ def __init__( self.num_shards = engine_config.num_shards - rpc_ports = ports[: self.num_shards] master_port = ports[-1] self.executor = ThreadPoolExecutor(self.num_shards) + self.socket_paths = [tempfile.mktemp() for _ in range(self.num_shards)] - rets = self.executor.map(start_model_process, rpc_ports) + rets = self.executor.map(start_model_process, self.socket_paths) self.model_servers = [x[0] for x in rets] self.procs = [x[1] for x in rets] From e4bbad951b901a7e3fe392641f7fcafe91fb1439 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 19:28:06 +0000 Subject: [PATCH 51/73] clean and add note on RPC overhead --- serve/mlc_serve/model/torch_model.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git 
a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 010c28e56b..8d5547d534 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -24,8 +24,6 @@ import rpyc from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer -from rpyc.core.service import VoidService -from rpyc.core.stream import SocketStream from rpyc.utils.factory import unix_connect from concurrent.futures import ThreadPoolExecutor @@ -35,13 +33,10 @@ from .model_common import ( prepare_inputs, get_num_cache_blocks, - get_logprob_infos, sample_from_logits, ) from ..engine import ( - SequenceId, - PROMPT_SEQEUNCE_INDEX, get_prompt_sequence_id, MLCServeEngineConfig, ) @@ -361,6 +356,9 @@ def exposed_generate( requests: Sequence[Union[PrefillRequest, DecodeRequest]], cache: KVCacheInfo, ) -> List[TextGenerationResult]: + # TODO(masahi): Currently, obtaining inputs is the bottleneck. + # We should switch to the architecture used by Disco and vLLM as of + # https://github.com/vllm-project/vllm/pull/2221 torch.cuda.nvtx.range_push(f"Obtain input") requests = obtain(requests) cache = obtain(cache) @@ -384,11 +382,6 @@ def _init_service(socket_path): t.start() -def connect_rpyc(host, port, config={}): - s = SocketStream.connect(host, port, ipv6=False, keepalive=False, nodelay=True) - return rpyc.connect_stream(s, VoidService, config) - - def start_model_process(socket_path): proc = multiprocessing.Process(target=_init_service, args=(socket_path,)) proc.start() @@ -489,7 +482,7 @@ def __init__( engine_config: MLCServeEngineConfig, ): if engine_config.num_shards and engine_config.num_shards > 1: - num_needed_ports = engine_config.num_shards + 1 + num_needed_ports = 1 # For torch distributed master port ports = alloc_usable_network_port(num_needed_ports) assert len(ports) == num_needed_ports, "Not enough ports available." 
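The scan in alloc_usable_network_port only proves that each port was bindable at scan time, so another process can still grab the port before torch.distributed binds it as MASTER_PORT. A hedged alternative sketch (not what the patch does) is to let the kernel hand out an ephemeral port and use that number immediately:

    import socket

    def pick_free_port() -> int:
        # Binding to port 0 lets the kernel choose a currently free ephemeral port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]

    # e.g. os.environ["MASTER_PORT"] = str(pick_free_port())
    # This narrows the race but does not remove it: the socket is closed before the port is reused.
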
self.model_rpc = ModelRpcClient(model_path, hf_config, engine_config, ports) From ee9cdc91148f7ac849d07d221d7ebadfc88807b8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 19:50:03 +0000 Subject: [PATCH 52/73] clean --- serve/tests/test_engine.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 8d419395fd..e64f830f7d 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -1,4 +1,3 @@ -import os import argparse import json import random @@ -9,11 +8,7 @@ DebugOptions, SamplingParams, StoppingCriteria, - get_engine_config, ) -from mlc_serve.engine.staging_engine import StagingInferenceEngine -from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine From 2071749f9014a6a4d7219a3ae6b42cc1e331ba04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 21:09:19 +0000 Subject: [PATCH 53/73] RPC process join works --- .../mlc_serve/engine/staging_engine_worker.py | 1 + serve/mlc_serve/model/torch_model.py | 25 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index da4731a96f..908f452721 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -371,6 +371,7 @@ def handle_command(): while True: cmd = command_queue.get() if isinstance(cmd, ShutdownCommand): + del worker.text_generator break elif isinstance(cmd, AddRequestsCommand): worker.add(cmd.request_states) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 8d5547d534..b6b5ee2b41 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -389,10 +389,11 @@ def start_model_process(socket_path): time.sleep(1) repeat_count = 0 + conn = None while repeat_count < 20: try: - con = unix_connect( + conn = unix_connect( socket_path, config={"allow_pickle": True, "sync_request_timeout": 600} ) break @@ -404,7 +405,7 @@ def start_model_process(socket_path): raise RuntimeError("init rpc env error!") assert proc.is_alive() - return con.root, proc + return conn, proc class ModelRpcClient: @@ -423,10 +424,14 @@ def __init__( self.executor = ThreadPoolExecutor(self.num_shards) self.socket_paths = [tempfile.mktemp() for _ in range(self.num_shards)] - rets = self.executor.map(start_model_process, self.socket_paths) + self.model_servers = [] + self.connections = [] + self.procs = [] - self.model_servers = [x[0] for x in rets] - self.procs = [x[1] for x in rets] + for conn, proc in self.executor.map(start_model_process, self.socket_paths): + self.model_servers.append(conn.root) + self.connections.append(conn) + self.procs.append(proc) def init_model(i): return self.model_servers[i].init_model( @@ -444,6 +449,13 @@ def init_model(i): def __del__(self): self.executor.shutdown() + for conn in self.connections: + conn.close() + + for proc in self.procs: + proc.terminate() + proc.join() + def generate( self, requests: Sequence[Union[PrefillRequest, DecodeRequest]], @@ -511,6 +523,9 @@ def __init__( self.vocab_size = hf_config.vocab_size self.sliding_window = hf_config.sliding_window + def __del__(self): + del self.model_rpc + def generate( self, requests: 
Sequence[Union[PrefillRequest, DecodeRequest]], From 15a90d05a05a857f055fad37e749ae3075d93701 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Feb 2024 21:23:36 +0000 Subject: [PATCH 54/73] fix mypy --- serve/mlc_serve/model/torch_model.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index b6b5ee2b41..743b5332dc 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -2,7 +2,7 @@ import os import tempfile import socket -from typing import List, Union, Sequence, Tuple +from typing import List, Tuple from pathlib import Path import structlog @@ -45,6 +45,7 @@ PrefillRequest, TextGenerationResult, TextGenerator, + RequestsType, ) LOG = structlog.stdlib.get_logger(__name__) @@ -182,7 +183,7 @@ def load_model(hf_config, model_path): def generate( - requests: Sequence[Union[PrefillRequest, DecodeRequest]], + requests: RequestsType, cache_info: KVCacheInfo, pt_model, cache_blocks, @@ -204,9 +205,11 @@ def generate( sequence_ids.append(get_prompt_sequence_id(request.request_id)) num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) - else: + elif isinstance(request, DecodeRequest): sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) + else: + raise RuntimeError(f"Unsupported request type {request}") all_token_ids.append(request.token_ids) @@ -353,7 +356,7 @@ def exposed_init_model( def exposed_generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest]], + requests: RequestsType, cache: KVCacheInfo, ) -> List[TextGenerationResult]: # TODO(masahi): Currently, obtaining inputs is the bottleneck. @@ -458,7 +461,7 @@ def __del__(self): def generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest]], + requests: RequestsType, cache: KVCacheInfo, ) -> List[TextGenerationResult]: def _generate(i): @@ -528,7 +531,7 @@ def __del__(self): def generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest]], + requests: RequestsType, cache: KVCacheInfo, ) -> List[TextGenerationResult]: if self.model_rpc is None: From 90284fa3684a09d8795c14d1983815959f855128 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 08:24:15 +0000 Subject: [PATCH 55/73] merge fix --- serve/mlc_serve/engine/staging_engine.py | 7 ++- .../mlc_serve/engine/staging_engine_worker.py | 2 +- serve/mlc_serve/model/model_common.py | 29 +++++----- serve/mlc_serve/model/sampler.py | 14 ++++- serve/mlc_serve/model/tokenizer.py | 2 +- serve/mlc_serve/model/torch_model.py | 56 +++++++++++++++---- serve/mlc_serve/model/tvm_model.py | 22 ++++++-- 7 files changed, 91 insertions(+), 41 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index fb7e3f6c03..dd74209e86 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -60,9 +60,10 @@ def __init__( # TODO(@team): This is a temporary solution to expose model config to higher API layer. 
# Follow-up with the proper solution - self.model_artifact_config = get_model_artifact_config( - model_module_loader_kwargs["model_artifact_path"] - ) + # TODO fix for PT + # self.model_artifact_config = get_model_artifact_config( + # model_module_loader_kwargs["model_artifact_path"] + # ) self.tokenizer = tokenizer_module.tokenizer self.conversation_template = tokenizer_module.conversation_template diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 73ee093d9a..84e7ce8b28 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -358,7 +358,7 @@ def run_generation_loop_worker( try: model_module = model_module_loader(**model_module_loader_kwargs) - LOG.info("Model is initalized.") + LOG.info("Model is initialized.") worker = GenerationLoopWorker(model_module=model_module) except: LOG.exception("An error raised in model initialization.") diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 76044a04eb..c3dd207bb8 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -83,21 +83,17 @@ def sample_from_logits( requests: Sequence[RequestType], sampling_metadata: SamplingState, vocab_size: int, - copy_stream: torch.cuda.Stream, torch_dtype: torch.dtype, torch_dev: str, past_decode_tokens: List[List[int]], ) -> List[TextGenerationResult]: batch_size = logits.shape[0] assert batch_size == len(requests) + # Convert to torch tensors if logits are in tvm ndarray if isinstance(logits, tvm.nd.NDArray): logits = torch.from_dlpack(logits) - # synchronization point for sampling tensors - # wait until all the tensors are loaded on GPU - torch.cuda.current_stream().wait_stream(copy_stream) - # Logit processing for constraint sampling e.g., JSON Mode for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)): if request.sampling_params.logits_processor is not None: @@ -137,6 +133,7 @@ def sample_from_logits( " or element < 0" ) logits = torch.from_dlpack(logits) + for i in range(batch_size): sequence_id = sequence_ids[i] logits_per_token = logits[i] @@ -145,15 +142,13 @@ def sample_from_logits( # NOTE: Rerun the preparation for simplicity. # Assume this code path is taken rarely and the recomputation overhead is # marginal. 
- with torch.cuda.stream(copy_stream): - new_sampling_metadata = SamplingState.from_sampling_params( - [sampling_param], - [past_decode_tokens_per_request], - torch_dtype, - torch_dev, - vocab_size, - ) - torch.cuda.current_stream().wait_stream(copy_stream) + new_sampling_metadata = SamplingState.from_sampling_params( + [sampling_param], + [past_decode_tokens_per_request], + torch_dtype, + torch_dev, + vocab_size, + ) maybe_sampling_output: Optional[SamplingOutput] = sample( torch.unsqueeze(logits_per_token, 0), new_sampling_metadata, @@ -164,6 +159,7 @@ def sample_from_logits( logprob_info = maybe_sampling_output.logprob_infos[0] # Valid sample request = requests[i] + if maybe_sampling_output is not None: outputs.extend( prepare_textgen_result( @@ -195,6 +191,7 @@ def prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, + block_size, num_decode_query_tokens=1, for_vllm=False, ): @@ -292,8 +289,8 @@ def to_torch(arr, torch_dtype): if not is_prefill: max_block_table_len = ( - max_context_len + CacheManager.block_size - 1 - ) // CacheManager.block_size + max_context_len + block_size - 1 + ) // block_size padded_block_tables = _do_pad(block_tables, max_block_table_len, 0) block_tables = to_torch(padded_block_tables, torch.int) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 0562121b8f..cdbf795454 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -441,15 +441,15 @@ def adjust_logits(logits, sampling_metadata, vocab_size): @dataclass class SamplingOutput: - next_tokens: list[int] - logprob_infos: list[Optional[RawLogprobsInfo]] + next_tokens: np.ndarray + logprob_infos: List[Optional[RawLogprobsInfo]] def sample( logits: torch.Tensor, sampling_metadata: SamplingState, check_safety: bool = False, -) -> SamplingOutput: +) -> Optional[SamplingOutput]: def _is_safe_to_sample(prob_like): return ( torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0)) @@ -466,21 +466,26 @@ def _is_safe_to_sample(prob_like): ) next_tokens = np.empty((batch_size,), dtype=np.int64) + if sampling_metadata.has_greedy: res_greedy = torch.argmax(logits[mask_greedy_t], -1) np_mask_greedy = mask_greedy_t.cpu().numpy() next_tokens[np_mask_greedy] = res_greedy.cpu().numpy() probs_random = None + if sampling_metadata.has_random: probs_random = torch.softmax(logits[mask_random_t], dim=-1) + if check_safety and not _is_safe_to_sample(probs_random): return None + res_random = torch.multinomial(probs_random, 1, True)[:, 0] np_mask_random = mask_random_t.cpu().numpy() next_tokens[np_mask_random] = res_random.cpu().numpy() logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size + if sampling_metadata.has_logprob: # If everything is random sampling, save one extra softmax if not sampling_metadata.has_greedy: @@ -497,6 +502,9 @@ def _is_safe_to_sample(prob_like): mask = sampling_metadata.sampling_tensors.mask_top_logprob top_tokens = all_top_tokens[mask] top_logprobs = all_top_logprobs[mask] + + logprobs = logprobs.to("cpu") + for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices): next_token = next_tokens[batch_idx] assert sampling_metadata.sampling_params[batch_idx].logprobs diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index 37f08f4ae1..a91cd7b525 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -56,7 +56,7 @@ def apply(self, messages: list[ChatMessage]) -> str: class HfTokenizerModule: def __init__(self, 
tokenizer_path: Path): hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, trust_remote_code=True + tokenizer_path, trust_remote_code=True, revision=None, tokenizer_revision=None, ) self.tokenizer = Tokenizer(hf_tokenizer) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 743b5332dc..2bd208f52e 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -2,7 +2,7 @@ import os import tempfile import socket -from typing import List, Tuple +from typing import List, Tuple, Sequence from pathlib import Path import structlog @@ -45,8 +45,10 @@ PrefillRequest, TextGenerationResult, TextGenerator, - RequestsType, + RequestType, ) +from .sampler import SamplingState + LOG = structlog.stdlib.get_logger(__name__) @@ -94,6 +96,10 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): slot_mapping_tensor = torch.cuda.LongTensor(slot_mapping) prompt_lens_tensor = torch.cuda.LongTensor(seq_lens) + peak_memory = torch.cuda.max_memory_allocated() + + print("peak memory before:", peak_memory / 1e9) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping_tensor, @@ -122,6 +128,8 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): torch.cuda.empty_cache() + print("peak memory after:", peak_memory / 1e9) + return peak_memory @@ -135,12 +143,15 @@ def profile_and_init_cache( num_hidden_layers = hf_config.num_hidden_layers head_size = hf_config.hidden_size // hf_config.num_attention_heads + block_size = 16 + if max_num_batched_tokens > 0: LOG.info("Running memory profiling.") seq_lens = [1] * max_num_batched_tokens used_memory_bytes = profile_memory_usage(pt_model, seq_lens, num_hidden_layers) num_blocks = get_num_cache_blocks( used_memory_bytes, + block_size, hf_config.num_hidden_layers, num_kv_heads, head_size, @@ -154,7 +165,7 @@ def profile_and_init_cache( head_size, hf_config.num_hidden_layers, num_kv_heads, - CacheManager.block_size, + block_size, num_blocks, ) @@ -183,7 +194,7 @@ def load_model(hf_config, model_path): def generate( - requests: RequestsType, + requests: Sequence[RequestType], cache_info: KVCacheInfo, pt_model, cache_blocks, @@ -198,20 +209,23 @@ def generate( all_token_ids = [] sequence_ids = [] prompt_lens = [] - num_sequences = [] + sampling_params = [] + past_decode_tokens = [] for request in requests: if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) prompt_lens.append(len(request.token_ids)) + past_decode_tokens.append([vocab_size]) elif isinstance(request, DecodeRequest): sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) + past_decode_tokens.append([vocab_size, *request.token_ids]) else: raise RuntimeError(f"Unsupported request type {request}") all_token_ids.append(request.token_ids) + sampling_params.append(request.sampling_params) selected_token_indices: List[int] = [] @@ -240,6 +254,7 @@ def generate( cache_info.decode_block_tables, sliding_window, is_prefill, + cache_info.block_size, for_vllm=True, ) @@ -300,7 +315,24 @@ def generate( torch.cuda.synchronize() torch.cuda.nvtx.range_pop() - return sample_from_logits(logits, sequence_ids, requests, vocab_size) + sampling_metadata = SamplingState.from_sampling_params( + sampling_params, + past_decode_tokens, + torch.float32, + "cuda", + vocab_size, + ) + + return sample_from_logits( + logits, + sequence_ids, + requests, + sampling_metadata, + vocab_size, + 
torch.float32, + "cuda", + past_decode_tokens, + ) class ModelRpcServer(rpyc.Service): @@ -356,7 +388,7 @@ def exposed_init_model( def exposed_generate( self, - requests: RequestsType, + requests: Sequence[RequestType], cache: KVCacheInfo, ) -> List[TextGenerationResult]: # TODO(masahi): Currently, obtaining inputs is the bottleneck. @@ -461,7 +493,7 @@ def __del__(self): def generate( self, - requests: RequestsType, + requests: Sequence[RequestType], cache: KVCacheInfo, ) -> List[TextGenerationResult]: def _generate(i): @@ -527,11 +559,12 @@ def __init__( self.sliding_window = hf_config.sliding_window def __del__(self): - del self.model_rpc + if self.model_rpc: + del self.model_rpc def generate( self, - requests: RequestsType, + requests: Sequence[RequestType], cache: KVCacheInfo, ) -> List[TextGenerationResult]: if self.model_rpc is None: @@ -582,6 +615,7 @@ def init_torch_model( cache_manager = CacheManager( model.num_blocks, + 16, hf_config.sliding_window, ) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index eb5781483f..418f148fc6 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -33,6 +33,7 @@ ) from .sampler import SamplingState + LOG = structlog.stdlib.get_logger(__name__) @@ -106,6 +107,7 @@ def _prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, + block_size, num_decode_query_tokens=1, ): ( @@ -123,6 +125,7 @@ def _prepare_inputs( all_decode_block_tables, sliding_window, is_prefill, + block_size, num_decode_query_tokens, ) @@ -159,15 +162,13 @@ def __init__( self.num_shards = config.num_shards # TODO(@sunggg): Find a better way - if config.model_type == "llama": - self.torch_dtype = torch.float32 - elif config.model_type == "mistral" or config.model_type == "mixtral": + if config.model_type in ["llama", "mistral", "mixtral"]: self.torch_dtype = torch.float32 else: assert 0, f"{config.model_type} is NOT supported yet" self._copy_stream: torch.cuda.Stream = torch.cuda.Stream() - self.torch_dev: str = "cuda" + self.torch_dev = "cuda" if self.sliding_window: self.block_sliding_window = self.sliding_window // block_size @@ -314,13 +315,17 @@ def generate_multi_query( torch.cuda.nvtx.range_pop() last_query_logits = torch.from_dlpack(logits)[last_query_offsets] + + # synchronization point for sampling tensors + # wait until all the tensors are loaded on GPU + torch.cuda.current_stream().wait_stream(self._copy_stream) + return sample_from_logits( last_query_logits, sequence_ids, requests, sampling_metadata, self.vocab_size, - self._copy_stream, self.torch_dtype, self.torch_dev, past_decode_tokens, @@ -404,10 +409,12 @@ def generate( cache.decode_block_tables, self.sliding_window, is_prefill, + cache.block_size, num_decode_query_tokens, ) input_shape = input_ids.shape + if self.disco_session: input_ids = copy_to_worker_0(self.disco_session, input_ids) positions = copy_to_worker_0(self.disco_session, positions) @@ -498,13 +505,16 @@ def generate( # TODO(masahi, yelite): Proper logic for handling multi-query logits (speculative decoding). 
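Moving torch.cuda.current_stream().wait_stream(self._copy_stream) next to the sampling call makes the overlap pattern explicit: sampling tensors are uploaded on a side stream while the forward pass runs on the default stream, which then waits only right before the tensors are consumed. A minimal sketch of the pattern, independent of the engine code:

    import torch

    copy_stream = torch.cuda.Stream()

    def upload_async(host_tensor):
        # Enqueue the host-to-device copy on the side stream so it can overlap with compute;
        # true overlap also requires pinned (page-locked) host memory.
        with torch.cuda.stream(copy_stream):
            return host_tensor.to("cuda", non_blocking=True)

    def consume(dev_tensor):
        # The consuming stream must wait on the copy stream before reading the tensor.
        torch.cuda.current_stream().wait_stream(copy_stream)
        return dev_tensor.sum()
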
return [] + # synchronization point for sampling tensors + # wait until all the tensors are loaded on GPU + torch.cuda.current_stream().wait_stream(self._copy_stream) + return sample_from_logits( logits, sequence_ids, requests, sampling_metadata, self.vocab_size, - self._copy_stream, self.torch_dtype, self.torch_dev, past_decode_tokens, From 52ad1ad45ef6aed2208317f98d3f8818b5e95c87 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 19:41:07 +0000 Subject: [PATCH 56/73] wip test fix --- serve/tests/unittest/test_engine_with_samplers.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index b589d967a0..9d712bf9b4 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -12,7 +12,7 @@ from mlc_serve.engine.sync_engine import SynchronousInferenceEngine from mlc_serve.model.base import get_model_artifact_config from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule -from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine import random from pydantic import BaseModel from typing import List @@ -436,9 +436,8 @@ def _test_json_mode( max_num_batched_tokens = 2048 # Test staging engines - staging_engine = create_engine( - args.model_artifact_path, max_num_batched_tokens, use_staging_engine=True - ) + staging_engine = create_mlc_engine(args) + _test_max_tokens(staging_engine) _test_ignore_eos(staging_engine) # TODO (@sunggg): There is something stateful. @@ -452,9 +451,8 @@ def _test_json_mode( staging_engine.stop() # Test sync engines - sync_engine = create_engine( - args.model_artifact_path, max_num_batched_tokens, use_staging_engine=False - ) + sync_engine = create_mlc_engine(args) + _test_max_tokens(sync_engine) _test_ignore_eos(sync_engine) _test_stop(sync_engine) From 61b680ed65e99e4150493524e23b1e153a2a3309 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 19:57:05 +0000 Subject: [PATCH 57/73] fix --- serve/tests/unittest/test_engine_with_samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 9d712bf9b4..4b90002a35 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -433,7 +433,7 @@ def _test_json_mode( parser = get_default_mlc_serve_argparser("test engine with samplers") args = parser.parse_args() postproc_mlc_serve_args(args) - max_num_batched_tokens = 2048 + args.max_num_batched_tokens = 2048 # Test staging engines staging_engine = create_mlc_engine(args) @@ -451,6 +451,7 @@ def _test_json_mode( staging_engine.stop() # Test sync engines + args.use_staging_engine = False sync_engine = create_mlc_engine(args) _test_max_tokens(sync_engine) From f128fe6251f11c1fbcf7bfe5552cd7edf3d1aef4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 20:56:05 +0000 Subject: [PATCH 58/73] Properly verify sampling params in api handler --- serve/mlc_serve/api/handler.py | 17 ++++++++--------- serve/mlc_serve/engine/sampling_params.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py index fc2245b769..51ec94ffd3 
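The test changes above now drive both engine variants through create_mlc_engine, differing only in args.use_staging_engine between the two runs. If that duplication grows, one possible consolidation (a sketch on top of create_mlc_engine, not part of these patches) is to loop over the two configurations:

    from mlc_serve.utils import create_mlc_engine

    def run_with_both_engines(args, checks):
        # `checks` is a list of callables that each take an engine; the helper name is illustrative.
        for use_staging in (True, False):
            args.use_staging_engine = use_staging
            engine = create_mlc_engine(args)
            try:
                for check in checks:
                    check(engine)
            finally:
                if use_staging:
                    engine.stop()
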
100644 --- a/serve/mlc_serve/api/handler.py +++ b/serve/mlc_serve/api/handler.py @@ -49,14 +49,11 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse def _get_sampling_params( request: ChatCompletionRequest, model_artifact_config: ModelArtifactConfig ) -> SamplingParams: - sampling_params = SamplingParams( - # These params came from vllm - # TODO(amnalyshe): should they be put into mlc-llm batch serving ChatCompletionRequest? - # best_of=request.best_of, - # top_k=request.top_k, - # ignore_eos=request.ignore_eos, - # use_beam_search=request.use_beam_search, - ) + sampling_params = SamplingParams() + assert model_artifact_config.vocab_size is not None + sampling_params.vocab_size = model_artifact_config.vocab_size + + # Initialize optional parameters if request.presence_penalty is not None: sampling_params.presence_penalty = request.presence_penalty if request.frequency_penalty is not None: @@ -74,7 +71,9 @@ def _get_sampling_params( sampling_params.logprobs = request.logprobs if request.response_format and request.response_format.type == "json_object": sampling_params.json_schema = request.response_format.response_schema - sampling_params.vocab_size = model_artifact_config.vocab_size + + sampling_params.verify() + return sampling_params diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index ae96df7709..43c7040e6f 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -80,7 +80,7 @@ def __post_init__(self): if self.logit_bias: self.logit_bias_index = list(self.logit_bias.keys()) self.logit_bias_value = list(self.logit_bias.values()) - self._verify_args() + self.verify() if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. self.top_p = 1.0 @@ -89,7 +89,7 @@ def __post_init__(self): if not self.logprobs: self.top_logprobs = 0 - def _verify_args(self) -> None: + def verify(self) -> None: if not -2.0 <= self.presence_penalty <= 2.0: raise ValueError( "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." From 568583a663b9ee33f7a5ef1483eb75c8692cddcb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 21:23:46 +0000 Subject: [PATCH 59/73] Create model artifact config before module initialization --- serve/mlc_serve/engine/staging_engine.py | 4 +- serve/mlc_serve/model/base.py | 3 ++ serve/mlc_serve/model/paged_cache_model.py | 6 +-- serve/mlc_serve/run.py | 49 ++--------------- serve/mlc_serve/utils.py | 5 ++ serve/tests/test_engine.py | 37 +++---------- .../unittest/test_engine_with_samplers.py | 52 +++---------------- serve/tests/unittest/test_sync_engine.py | 20 +------ 8 files changed, 30 insertions(+), 146 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index fb7e3f6c03..03cebe963c 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -60,9 +60,7 @@ def __init__( # TODO(@team): This is a temporary solution to expose model config to higher API layer. 
# Follow-up with the proper solution - self.model_artifact_config = get_model_artifact_config( - model_module_loader_kwargs["model_artifact_path"] - ) + self.model_artifact_config = model_module_loader_kwargs["model_artifact_config"] self.tokenizer = tokenizer_module.tokenizer self.conversation_template = tokenizer_module.conversation_template diff --git a/serve/mlc_serve/model/base.py b/serve/mlc_serve/model/base.py index d420334de1..e47cf61267 100644 --- a/serve/mlc_serve/model/base.py +++ b/serve/mlc_serve/model/base.py @@ -4,6 +4,7 @@ import json import inspect + # TODO(@sunggg): consider transition to something like Pydantic @dataclass class ModelArtifactConfig: @@ -32,11 +33,13 @@ def _from_json(config_cls, json_obj: dict): } ) + class AssetNotFound(Exception): def __init__(self, asset_path): self.asset_path = asset_path super().__init__(f"{self.asset_path} should exist. Did you build with `--enable-batching`?") + def get_model_artifact_config(model_artifact_path): json_object = {"model_artifact_path": model_artifact_path} for config_file_name in [ diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 34755a8c36..c8da1382c4 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -2,7 +2,7 @@ import structlog from typing import Sequence, List -from .base import get_model_artifact_config +from .base import ModelArtifactConfig from .paged_cache_manager import CacheManager from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer from .tvm_model import init_tvm_model @@ -18,6 +18,7 @@ TextGenerator, ) + LOG = structlog.stdlib.get_logger(__name__) @@ -78,9 +79,8 @@ def __init__( self, model_artifact_path: Path, engine_config: MLCServeEngineConfig, + model_artifact_config: ModelArtifactConfig ): - model_artifact_config = get_model_artifact_config(model_artifact_path) - # TODO(masahi): Make the model type configurable. model, cache_manager = init_tvm_model(model_artifact_config, engine_config) diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index 50947912cb..e7e1fbe4bd 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -1,14 +1,10 @@ -import argparse import tempfile import os import uvicorn from .api import create_app -from .engine import AsyncEngineConnector, get_engine_config -from .engine.staging_engine import StagingInferenceEngine -from .engine.sync_engine import SynchronousInferenceEngine -from .model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule -from .utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +from .engine import AsyncEngineConnector +from .utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine def parse_args(): @@ -20,52 +16,13 @@ def parse_args(): return args -def create_engine( - args: argparse.Namespace, -): - """ - `model_artifact_path` has the following structure - |- compiled artifact (.so) - |- `build_config.json`: stores compile-time info, such as `num_shards` and `quantization`. - |- params/ : stores weights in mlc format and `ndarray-cache.json`. - | `ndarray-cache.json` is especially important for Disco. 
- |- model/ : stores info from hf model cards such as max context length and tokenizer - """ - # Set the engine config - engine_config = get_engine_config( - { - "use_staging_engine": args.use_staging_engine, - "max_num_batched_tokens": args.max_num_batched_tokens, - "min_decode_steps": args.min_decode_steps, - "max_decode_steps": args.max_decode_steps, - } - ) - - if args.use_staging_engine: - return StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(args.model_artifact_path), - model_module_loader=PagedCacheModelModule, - model_module_loader_kwargs={ - "model_artifact_path": args.model_artifact_path, - "engine_config": engine_config, - }, - ) - else: - return SynchronousInferenceEngine( - PagedCacheModelModule( - model_artifact_path=args.model_artifact_path, - engine_config=engine_config, - ) - ) - - def run_server(): args = parse_args() with tempfile.TemporaryDirectory() as temp_dir: os.environ["PROMETHEUS_MULTIPROC_DIR"] = temp_dir - engine = create_engine(args) + engine = create_mlc_engine(args) connector = AsyncEngineConnector(engine) app = create_app(connector) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index 10c616ac18..cbbca2652a 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -10,6 +10,7 @@ from mlc_serve.logging_utils import configure_logging from mlc_serve.engine.staging_engine import StagingInferenceEngine from mlc_serve.engine.sync_engine import SynchronousInferenceEngine +from mlc_serve.model.base import get_model_artifact_config from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule @@ -55,6 +56,7 @@ def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine: "max_decode_steps": args.max_decode_steps, } ) + model_artifact_config = get_model_artifact_config(args.model_artifact_path) engine: InferenceEngine @@ -65,6 +67,7 @@ def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine: model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, "engine_config": engine_config, + "model_artifact_config": model_artifact_config }, ) engine.start() @@ -73,6 +76,8 @@ def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine: PagedCacheModelModule( model_artifact_path=args.model_artifact_path, engine_config=engine_config, + model_artifact_config=model_artifact_config ) ) + return engine diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index fa8fec34fd..0cb313da9c 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -8,41 +8,16 @@ DebugOptions, SamplingParams, StoppingCriteria, - get_engine_config, ) -from mlc_serve.engine.staging_engine import StagingInferenceEngine -from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule -from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +from mlc_serve.utils import ( + get_default_mlc_serve_argparser, + postproc_mlc_serve_args, + create_mlc_engine, +) def _test(args: argparse.Namespace): - engine_config = get_engine_config( - { - "use_staging_engine": args.use_staging_engine, - "max_num_batched_tokens": args.max_num_batched_tokens, - "min_decode_steps": args.min_decode_steps, - "max_decode_steps": args.max_decode_steps, - } - ) - - if args.use_staging_engine: - engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(args.model_artifact_path), - model_module_loader=PagedCacheModelModule, - model_module_loader_kwargs={ - 
"model_artifact_path": args.model_artifact_path, - "engine_config": engine_config, - }, - ) - engine.start() - else: - engine = SynchronousInferenceEngine( - PagedCacheModelModule( - model_artifact_path=args.model_artifact_path, - engine_config=engine_config, - ) - ) + engine = create_mlc_engine(args) sampling_params_greedy = SamplingParams( temperature=0.0, diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index b589d967a0..18feed5a9d 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -6,51 +6,14 @@ SamplingParams, StoppingCriteria, FinishReason, - get_engine_config, ) -from mlc_serve.engine.staging_engine import StagingInferenceEngine -from mlc_serve.engine.sync_engine import SynchronousInferenceEngine from mlc_serve.model.base import get_model_artifact_config -from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule -from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args +from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine import random from pydantic import BaseModel from typing import List -def create_engine( - model_artifact_path, - max_num_batched_tokens, - use_staging_engine, -): - engine_config = get_engine_config( - { - "use_staging_engine": use_staging_engine, - "max_num_batched_tokens": max_num_batched_tokens, - # Use defaults for "min_decode_steps", "max_decode_steps" - } - ) - - if use_staging_engine: - engine = StagingInferenceEngine( - tokenizer_module=HfTokenizerModule(model_artifact_path), - model_module_loader=PagedCacheModelModule, - model_module_loader_kwargs={ - "model_artifact_path": model_artifact_path, - "engine_config": engine_config, - }, - ) - engine.start() - else: - engine = SynchronousInferenceEngine( - PagedCacheModelModule( - model_artifact_path=model_artifact_path, - engine_config=engine_config, - ) - ) - return engine - - def create_request( idx, prompt, @@ -433,12 +396,11 @@ def _test_json_mode( parser = get_default_mlc_serve_argparser("test engine with samplers") args = parser.parse_args() postproc_mlc_serve_args(args) - max_num_batched_tokens = 2048 + args.max_num_batched_tokens = 2048 # Test staging engines - staging_engine = create_engine( - args.model_artifact_path, max_num_batched_tokens, use_staging_engine=True - ) + staging_engine = create_mlc_engine(args) + _test_max_tokens(staging_engine) _test_ignore_eos(staging_engine) # TODO (@sunggg): There is something stateful. 
@@ -452,9 +414,9 @@ def _test_json_mode( staging_engine.stop() # Test sync engines - sync_engine = create_engine( - args.model_artifact_path, max_num_batched_tokens, use_staging_engine=False - ) + args.use_staging_engine = False + sync_engine = create_mlc_engine(args) + _test_max_tokens(sync_engine) _test_ignore_eos(sync_engine) _test_stop(sync_engine) diff --git a/serve/tests/unittest/test_sync_engine.py b/serve/tests/unittest/test_sync_engine.py index 9b70b10988..a5852058f4 100644 --- a/serve/tests/unittest/test_sync_engine.py +++ b/serve/tests/unittest/test_sync_engine.py @@ -1,5 +1,4 @@ -from typing import Optional, Union - +from typing import Optional from mlc_serve.engine import ( ChatMessage, FinishReason, @@ -8,24 +7,9 @@ RequestOutput, SamplingParams, StoppingCriteria, - get_engine_config -) -from mlc_serve.model.base import ModelArtifactConfig -from mlc_serve.engine.model_module import ( - DecodeRequest, - KVCache, - PrefillRequest, - SequenceId, - TextGenerationResult, ) - from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.engine.staging_engine import StagingInferenceEngine - -from mlc_serve.model.dummy_model import ( - DummyModelModule, - DummyTokenizerModule, -) +from mlc_serve.model.dummy_model import DummyModelModule def create_messages(prompt) -> list[ChatMessage]: From 762012d48fd3ee7108f9c4d09f0d411c3660a883 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 13 Feb 2024 21:49:21 +0000 Subject: [PATCH 60/73] fix engine start --- serve/mlc_serve/run.py | 2 +- serve/mlc_serve/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index e7e1fbe4bd..35e06be6ae 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -22,7 +22,7 @@ def run_server(): with tempfile.TemporaryDirectory() as temp_dir: os.environ["PROMETHEUS_MULTIPROC_DIR"] = temp_dir - engine = create_mlc_engine(args) + engine = create_mlc_engine(args, start_engine=False) connector = AsyncEngineConnector(engine) app = create_app(connector) diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index cbbca2652a..77ce740b10 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -47,7 +47,7 @@ def postproc_mlc_serve_args(args): random.seed(args.seed) -def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine: +def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceEngine: engine_config = get_engine_config( { "use_staging_engine": args.use_staging_engine, @@ -70,7 +70,9 @@ def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine: "model_artifact_config": model_artifact_config }, ) - engine.start() + + if start_engine: + engine.start() else: engine = SynchronousInferenceEngine( PagedCacheModelModule( From dc5fb6e365dedd548a1a0e990cbb3f7002807d57 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 05:01:57 +0000 Subject: [PATCH 61/73] fix --- serve/mlc_serve/model/base.py | 17 ++++++++++++++++ serve/mlc_serve/model/torch_model.py | 30 +++------------------------- serve/mlc_serve/utils.py | 21 +++++++++++++++++-- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/serve/mlc_serve/model/base.py b/serve/mlc_serve/model/base.py index e47cf61267..63daf7895c 100644 --- a/serve/mlc_serve/model/base.py +++ b/serve/mlc_serve/model/base.py @@ -1,5 +1,8 @@ +from transformers import AutoConfig + from dataclasses import dataclass from typing import Optional +from pathlib import Path import os import json import 
inspect @@ -57,3 +60,17 @@ def get_model_artifact_config(model_artifact_path): json_object["paged_kv_cache_type"] = "vllm" return ModelArtifactConfig._from_json(json_object) + + +def get_hf_config(model_path: Path) -> AutoConfig: + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if not hasattr(hf_config, "num_key_value_heads") and hasattr( + hf_config, "num_attention_heads" + ): + hf_config.num_key_value_heads = hf_config.num_attention_heads + + if not hasattr(hf_config, "sliding_window"): + hf_config.sliding_window = None + + return hf_config diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 1347cb2271..a6d99292a3 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -28,7 +28,7 @@ from concurrent.futures import ThreadPoolExecutor -from .base import ModelArtifactConfig +from .base import ModelArtifactConfig, get_hf_config from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( prepare_inputs, @@ -582,35 +582,12 @@ def generate( def init_torch_model( model_path: Path, engine_config: MLCServeEngineConfig -) -> Tuple[TextGenerator, CacheManager, ModelArtifactConfig]: - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - - if not hasattr(hf_config, "num_key_value_heads") and hasattr( - hf_config, "num_attention_heads" - ): - hf_config.num_key_value_heads = hf_config.num_attention_heads - - if not hasattr(hf_config, "sliding_window"): - hf_config.sliding_window = None +) -> Tuple[TextGenerator, CacheManager]: + hf_config = get_hf_config(model_path) if engine_config.num_shards is None: raise RuntimeError("num_shards needs to be specifed for PyTorch models.") - num_shards = engine_config.num_shards - - artifact_config = ModelArtifactConfig( - model_artifact_path=str(model_path), - num_shards=num_shards, - quantization=None, - max_context_length=hf_config.max_position_embeddings, # TODO, - vocab_size=hf_config.vocab_size, - sliding_window=hf_config.sliding_window, - num_key_value_heads=hf_config.num_key_value_heads // num_shards, - num_attention_heads=hf_config.num_attention_heads, - num_hidden_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ) - model = Model(model_path, hf_config, engine_config) cache_manager = CacheManager( @@ -619,5 +596,4 @@ def init_torch_model( hf_config.sliding_window, ) - # TODO return model, cache_manager diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index fd4a75d79c..a0de962c02 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -10,7 +10,7 @@ from mlc_serve.logging_utils import configure_logging from mlc_serve.engine.staging_engine import StagingInferenceEngine from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.model.base import get_model_artifact_config +from mlc_serve.model.base import get_model_artifact_config, ModelArtifactConfig, get_hf_config from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule @@ -72,7 +72,24 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE "num_shards": num_shards, } ) - model_artifact_config = get_model_artifact_config(args.model_artifact_path) + + if model_type == "tvm": + model_artifact_config = get_model_artifact_config(args.model_artifact_path) + else: + hf_config = get_hf_config(args.model_artifact_path) + + model_artifact_config = ModelArtifactConfig( + 
model_artifact_path=str(args.model_artifact_path), + num_shards=num_shards, + quantization=None, + max_context_length=hf_config.max_position_embeddings, + vocab_size=hf_config.vocab_size, + sliding_window=hf_config.sliding_window, + num_key_value_heads=hf_config.num_key_value_heads // num_shards, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ) engine: InferenceEngine From ebe0b4e030e423edbed4baaa7410192ce7b291f1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 05:27:49 +0000 Subject: [PATCH 62/73] black --- serve/mlc_serve/engine/engine_common.py | 2 -- serve/mlc_serve/model/tokenizer.py | 6 ++++-- serve/mlc_serve/model/torch_model.py | 10 +++++----- serve/mlc_serve/utils.py | 17 ++++++++++++----- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 3aedbd7e03..fb387e9b4c 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -118,8 +118,6 @@ def detokenize_incrementally( prefix_begin_offset = generation_sequence.prefix_begin_offset prefix_end_offset = generation_sequence.prefix_end_offset - # assert tokenizer.is_fast - prefix_text = tokenizer.convert_tokens_to_string( output_tokens[prefix_begin_offset:prefix_end_offset] ) diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index cbcd3d6788..a0d6ec1789 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -48,8 +48,10 @@ def apply(self, messages: list[ChatMessage]) -> str: class HfTokenizerModule: def __init__(self, tokenizer_path: Path): hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, trust_remote_code=True, - revision=None, tokenizer_revision=None, + tokenizer_path, + trust_remote_code=True, + revision=None, + tokenizer_revision=None, ) self.tokenizer = Tokenizer(hf_tokenizer) self.conversation_template = ConversationTemplate(hf_tokenizer) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index a6d99292a3..c5c788ec82 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -14,6 +14,8 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.model_executor.models.phi import PhiForCausalLM +from vllm.model_executor.models.mistral import MistralForCausalLM +from vllm.model_executor.models.mixtral import MixtralForCausalLM from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel, @@ -28,7 +30,7 @@ from concurrent.futures import ThreadPoolExecutor -from .base import ModelArtifactConfig, get_hf_config +from .base import get_hf_config from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( prepare_inputs, @@ -98,8 +100,6 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): peak_memory = torch.cuda.max_memory_allocated() - print("peak memory before:", peak_memory / 1e9) - input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping_tensor, @@ -128,8 +128,6 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): torch.cuda.empty_cache() - print("peak memory after:", peak_memory / 1e9) - return peak_memory @@ -179,6 +177,8 @@ def load_model(hf_config, model_path): "LlamaForCausalLM": 
LlamaForCausalLM, "PhiForCausalLM": PhiForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, # requires tiktoken package + "MistralForCausalLM": MistralForCausalLM, + "MixtralForCausalLM": MixtralForCausalLM, } arch = hf_config.architectures[0] diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py index a0de962c02..578f5893b3 100644 --- a/serve/mlc_serve/utils.py +++ b/serve/mlc_serve/utils.py @@ -10,7 +10,11 @@ from mlc_serve.logging_utils import configure_logging from mlc_serve.engine.staging_engine import StagingInferenceEngine from mlc_serve.engine.sync_engine import SynchronousInferenceEngine -from mlc_serve.model.base import get_model_artifact_config, ModelArtifactConfig, get_hf_config +from mlc_serve.model.base import ( + get_model_artifact_config, + ModelArtifactConfig, + get_hf_config, +) from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule @@ -30,7 +34,7 @@ def get_default_mlc_serve_argparser(description="", allow_override=False): parser.add_argument("--max-decode-steps", type=int, default=56) parser.add_argument("--debug-logging", action="store_true") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models + parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models return parser @@ -56,10 +60,13 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE model_type = "torch" num_shards = args.num_shards - assert num_shards is not None, "--num-shards needs to be provided for PT models." + assert ( + num_shards is not None + ), "--num-shards needs to be provided for PT models." if num_shards > 1: import torch + torch.multiprocessing.set_start_method("spawn") engine_config = get_engine_config( @@ -105,7 +112,7 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE model_module_loader_kwargs={ "model_artifact_path": args.model_artifact_path, "engine_config": engine_config, - "model_artifact_config": model_artifact_config + "model_artifact_config": model_artifact_config, }, ) @@ -116,7 +123,7 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE PagedCacheModelModule( model_artifact_path=args.model_artifact_path, engine_config=engine_config, - model_artifact_config=model_artifact_config + model_artifact_config=model_artifact_config, ) ) From 4b2de70ce283c50db534ace5bafd8535abf4a5ef Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 05:49:22 +0000 Subject: [PATCH 63/73] properly handle import failure --- serve/mlc_serve/model/torch_model.py | 176 ++++++++++++++------------- serve/pyproject.toml | 3 + 2 files changed, 97 insertions(+), 82 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index c5c788ec82..5128cb16f1 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -4,31 +4,36 @@ import socket from typing import List, Tuple, Sequence from pathlib import Path +from concurrent.futures import ThreadPoolExecutor import structlog + import torch +import torch.multiprocessing as multiprocessing from transformers import AutoConfig -from vllm.model_executor.layers.sampler import get_logits -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.model_executor.models.phi import PhiForCausalLM -from vllm.model_executor.models.mistral import MistralForCausalLM -from 
vllm.model_executor.models.mixtral import MixtralForCausalLM -from vllm.model_executor import InputMetadata, SamplingMetadata -from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel, -) +try: + from vllm.model_executor.layers.sampler import get_logits + from vllm.model_executor.models.llama import LlamaForCausalLM + from vllm.model_executor.models.qwen import QWenLMHeadModel + from vllm.model_executor.models.phi import PhiForCausalLM + from vllm.model_executor.models.mistral import MistralForCausalLM + from vllm.model_executor.models.mixtral import MixtralForCausalLM + from vllm.model_executor import InputMetadata, SamplingMetadata + from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel, + ) + import rpyc + from rpyc.utils.classic import obtain + from rpyc.utils.server import ThreadedServer + from rpyc.utils.factory import unix_connect -import torch.multiprocessing as multiprocessing + support_torch_model = True -import rpyc -from rpyc.utils.classic import obtain -from rpyc.utils.server import ThreadedServer -from rpyc.utils.factory import unix_connect +except ImportError: + support_torch_model = False -from concurrent.futures import ThreadPoolExecutor from .base import get_hf_config from .paged_cache_manager import KVCacheInfo, CacheManager @@ -335,77 +340,79 @@ def generate( ) -class ModelRpcServer(rpyc.Service): - def exposed_init_model( - self, - tp_rank: int, - num_shards: int, - model_path: Path, - hf_config: AutoConfig, - engine_config: MLCServeEngineConfig, - master_port: int, - ) -> int: - hf_config = obtain(hf_config) - engine_config = obtain(engine_config) - model_path = obtain(model_path) - - self.vocab_size = hf_config.vocab_size - self.sliding_window = hf_config.sliding_window - - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - torch.cuda.set_device(tp_rank) - - os.environ["MASTER_ADDR"] = str("127.0.0.1") - os.environ["MASTER_PORT"] = str(master_port) +if support_torch_model: + class ModelRpcServer(rpyc.Service): + def exposed_init_model( + self, + tp_rank: int, + num_shards: int, + model_path: Path, + hf_config: AutoConfig, + engine_config: MLCServeEngineConfig, + master_port: int, + ) -> int: + hf_config = obtain(hf_config) + engine_config = obtain(engine_config) + model_path = obtain(model_path) + + self.vocab_size = hf_config.vocab_size + self.sliding_window = hf_config.sliding_window + + # This was taken from vLLM + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + torch.cuda.set_device(tp_rank) + + os.environ["MASTER_ADDR"] = str("127.0.0.1") + os.environ["MASTER_PORT"] = str(master_port) - torch.distributed.init_process_group( - backend="nccl", - world_size=num_shards, - rank=tp_rank, - ) - initialize_model_parallel(num_shards) - - # A small all_reduce for warmup. 
- torch.distributed.all_reduce(torch.zeros(1).cuda()) + torch.distributed.init_process_group( + backend="nccl", + world_size=num_shards, + rank=tp_rank, + ) + initialize_model_parallel(num_shards) - self.pt_model = load_model(hf_config, model_path) + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) - self.cache_blocks, num_blocks = profile_and_init_cache( - self.pt_model, - hf_config, - num_shards, - engine_config.max_num_batched_tokens, - ) + self.pt_model = load_model(hf_config, model_path) - return num_blocks + self.cache_blocks, num_blocks = profile_and_init_cache( + self.pt_model, + hf_config, + num_shards, + engine_config.max_num_batched_tokens, + ) - def exposed_generate( - self, - requests: Sequence[RequestType], - cache: KVCacheInfo, - ) -> List[TextGenerationResult]: - # TODO(masahi): Currently, obtaining inputs is the bottleneck. - # We should switch to the architecture used by Disco and vLLM as of - # https://github.com/vllm-project/vllm/pull/2221 - torch.cuda.nvtx.range_push(f"Obtain input") - requests = obtain(requests) - cache = obtain(cache) - torch.cuda.nvtx.range_pop() - return generate( - requests, - cache, - self.pt_model, - self.cache_blocks, - self.sliding_window, - self.vocab_size, - ) + return num_blocks + + def exposed_generate( + self, + requests: Sequence[RequestType], + cache: KVCacheInfo, + ) -> List[TextGenerationResult]: + # TODO(masahi): Currently, obtaining inputs is the bottleneck. + # We should switch to the architecture used by Disco and vLLM as of + # https://github.com/vllm-project/vllm/pull/2221 + torch.cuda.nvtx.range_push(f"Obtain input") + requests = obtain(requests) + cache = obtain(cache) + torch.cuda.nvtx.range_pop() + return generate( + requests, + cache, + self.pt_model, + self.cache_blocks, + self.sliding_window, + self.vocab_size, + ) def _init_service(socket_path): @@ -583,6 +590,11 @@ def generate( def init_torch_model( model_path: Path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager]: + if not support_torch_model: + raise RuntimeError("Running PyTorch models requires vLLM from " + "https://github.com/octoml/vllm/tree/for-mlc-serve installed. 
" + "Furthermore, rpyc is needed for multi-gpu support.") + hf_config = get_hf_config(model_path) if engine_config.num_shards is None: diff --git a/serve/pyproject.toml b/serve/pyproject.toml index f79af292c1..9daef2c049 100644 --- a/serve/pyproject.toml +++ b/serve/pyproject.toml @@ -10,6 +10,9 @@ fastapi = ">=0.103.1" pydantic = ">=1.8.0" prometheus-client = ">=0.18.0" outlines = "0.0.23" +rpyc = "5.3.1" +# TODO vllm dep +# https://github.com/octoml/vllm/tree/for-mlc-serve [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" From f09d458f1bdfcf5dbdd02e44fd5dd007dd7ecd87 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 06:05:50 +0000 Subject: [PATCH 64/73] add titoken dep --- serve/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/pyproject.toml b/serve/pyproject.toml index 9daef2c049..b75d16aa44 100644 --- a/serve/pyproject.toml +++ b/serve/pyproject.toml @@ -10,7 +10,9 @@ fastapi = ">=0.103.1" pydantic = ">=1.8.0" prometheus-client = ">=0.18.0" outlines = "0.0.23" +# PT model dep rpyc = "5.3.1" +tiktoken = "0.5.2" # for Qwen # TODO vllm dep # https://github.com/octoml/vllm/tree/for-mlc-serve From c9ac5bafaf3fa0e0cf4b23e98adf891a59c05832 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 18:59:36 +0000 Subject: [PATCH 65/73] revert logprob change --- serve/mlc_serve/model/sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index cdbf795454..0917b7ae6a 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -503,8 +503,6 @@ def _is_safe_to_sample(prob_like): top_tokens = all_top_tokens[mask] top_logprobs = all_top_logprobs[mask] - logprobs = logprobs.to("cpu") - for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices): next_token = next_tokens[batch_idx] assert sampling_metadata.sampling_params[batch_idx].logprobs From f1cf274d56aa941f83d260b0f3850c074749035d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 14 Feb 2024 19:04:14 +0000 Subject: [PATCH 66/73] restored tokenizer.is_fast assert but commented out --- serve/mlc_serve/engine/engine_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index fb387e9b4c..824c3c495c 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -118,6 +118,9 @@ def detokenize_incrementally( prefix_begin_offset = generation_sequence.prefix_begin_offset prefix_end_offset = generation_sequence.prefix_end_offset + # Some models, such as Qwen, uses a tokenizer from titoken package which fails this assert. 
+ # assert tokenizer.is_fast + prefix_text = tokenizer.convert_tokens_to_string( output_tokens[prefix_begin_offset:prefix_end_offset] ) From 1336fb8195b81feaa2ee25f81f59bfce19b9b5bf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Feb 2024 19:43:48 +0000 Subject: [PATCH 67/73] fix vocab siz --- serve/benchmarks/benchmark_throughput.py | 1 + 1 file changed, 1 insertion(+) diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 3cb6958e47..d3ab1587c9 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -142,6 +142,7 @@ def run_mlc(engine, requests, args) -> float: logprobs=args.sampling_setting["logprobs"], top_logprobs=args.sampling_setting["top_logprobs"], json_schema=args.sampling_setting["json_schema"], + vocab_size=engine.model_artifact_config.vocab_size, ), stopping_criteria=StoppingCriteria( max_tokens=args.num_output_tokens, stop_sequences=None From 6186ef2d9b1d6020a98e6a361ceb1db1f49d6383 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 15 Feb 2024 21:44:54 +0000 Subject: [PATCH 68/73] properly account for logits storage in memory profiling --- serve/mlc_serve/model/torch_model.py | 46 +++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 5128cb16f1..1925e6f3f7 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -83,7 +83,7 @@ def init_cache_blocks(head_size, num_layers, num_heads, block_size, num_gpu_bloc return gpu_cache -def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): +def profile_memory_usage(pt_model, seq_lens, num_hidden_layers, vocab_size): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -98,12 +98,21 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): input_positions.append(list(range(seq_len))) slot_mapping.append([0] * seq_len) + selected_token_indices: List[int] = [] + + max_prompt_len = max(seq_lens) + seq_start = 0 + + for prompt_len in seq_lens: + selected_token_indices.append(seq_start + prompt_len - 1) + seq_start += max_prompt_len + input_ids = torch.cuda.LongTensor(input_tokens) positions = torch.cuda.LongTensor(input_positions) slot_mapping_tensor = torch.cuda.LongTensor(slot_mapping) prompt_lens_tensor = torch.cuda.LongTensor(seq_lens) - peak_memory = torch.cuda.max_memory_allocated() + peak_memory_before = torch.cuda.max_memory_allocated() input_metadata = InputMetadata( is_prompt=True, @@ -117,19 +126,37 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers): use_cuda_graph=False, ) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=seq_lens, + selected_token_indices=torch.tensor( + selected_token_indices, dtype=torch.long, device="cuda" + ), + categorized_sample_indices=None, + ) + kv_caches = [(None, None)] * num_hidden_layers with torch.no_grad(): - pt_model.forward( + hidden_states = pt_model.forward( input_ids, positions, kv_caches, input_metadata, ) + _ = get_logits( + pt_model.lm_head.weight, + hidden_states, + sampling_metadata, + vocab_size, + ) + torch.cuda.synchronize() peak_memory = torch.cuda.max_memory_allocated() + print(f"peak memory during profling: {(peak_memory - peak_memory_before) / 1e9} GB") torch.cuda.empty_cache() @@ -151,7 +178,9 @@ def profile_and_init_cache( if max_num_batched_tokens > 0: LOG.info("Running memory profiling.") seq_lens = [1] * 
max_num_batched_tokens - used_memory_bytes = profile_memory_usage(pt_model, seq_lens, num_hidden_layers) + used_memory_bytes = profile_memory_usage( + pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size + ) num_blocks = get_num_cache_blocks( used_memory_bytes, block_size, @@ -341,6 +370,7 @@ def generate( if support_torch_model: + class ModelRpcServer(rpyc.Service): def exposed_init_model( self, @@ -591,9 +621,11 @@ def init_torch_model( model_path: Path, engine_config: MLCServeEngineConfig ) -> Tuple[TextGenerator, CacheManager]: if not support_torch_model: - raise RuntimeError("Running PyTorch models requires vLLM from " - "https://github.com/octoml/vllm/tree/for-mlc-serve installed. " - "Furthermore, rpyc is needed for multi-gpu support.") + raise RuntimeError( + "Running PyTorch models requires vLLM from " + "https://github.com/octoml/vllm/tree/for-mlc-serve installed. " + "Furthermore, rpyc is needed for multi-gpu support." + ) hf_config = get_hf_config(model_path) From aa4d477624c0b8769439742830779a4c432e667f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Feb 2024 20:01:12 +0000 Subject: [PATCH 69/73] merge fix --- serve/mlc_serve/model/torch_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index 1925e6f3f7..b266d7c456 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -245,6 +245,7 @@ def generate( prompt_lens = [] sampling_params = [] past_decode_tokens = [] + prompt_masks = [] for request in requests: if isinstance(request, PrefillRequest): @@ -260,6 +261,7 @@ def generate( all_token_ids.append(request.token_ids) sampling_params.append(request.sampling_params) + prompt_masks.append(request.prompt_mask) selected_token_indices: List[int] = [] @@ -352,6 +354,7 @@ def generate( sampling_metadata = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, + prompt_masks, torch.float32, "cuda", vocab_size, @@ -366,6 +369,7 @@ def generate( torch.float32, "cuda", past_decode_tokens, + prompt_masks, ) From 8bb96edebfb806d921396ba5ad2bbb211479a7ba Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Feb 2024 20:15:54 +0000 Subject: [PATCH 70/73] validate num_shards in engine creation --- serve/mlc_serve/engine/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index 835874d72f..de1fba57c0 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -59,6 +59,11 @@ def get_engine_config(dict_config): assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0) assert engine_config.max_decode_steps > engine_config.min_decode_steps + if engine_config.model_type == "torch": + assert ( + engine_config.num_shards is not None + ), "num_shards in MLCServeEngineConfig needs to be provided for PT models." 
+ return engine_config From cf0813de29abaf5e4784121fb2c5fccd66fc9585 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Feb 2024 20:35:04 +0000 Subject: [PATCH 71/73] replace print with structlog --- serve/mlc_serve/model/torch_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py index b266d7c456..04381cd3b4 100644 --- a/serve/mlc_serve/model/torch_model.py +++ b/serve/mlc_serve/model/torch_model.py @@ -156,7 +156,7 @@ def profile_memory_usage(pt_model, seq_lens, num_hidden_layers, vocab_size): torch.cuda.synchronize() peak_memory = torch.cuda.max_memory_allocated() - print(f"peak memory during profling: {(peak_memory - peak_memory_before) / 1e9} GB") + LOG.info(f"peak memory during profling: {(peak_memory - peak_memory_before) / 1e9} GB") torch.cuda.empty_cache() From f7168516bce8a63a6406ce8c9c98f71b9132ba76 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Feb 2024 20:44:35 +0000 Subject: [PATCH 72/73] add peak memory log for tvm as well --- serve/mlc_serve/model/tvm_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 1a85ace6c5..67fa7eefbd 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -253,6 +253,8 @@ def profile_memory_usage(self, seq_lens): vm_alloc_after = self.get_used_memory() + LOG.info(f"peak memory during profling: {(vm_alloc_after - vm_alloc_before) / 1e9} GB") + return self.get_param_nbytes() + (vm_alloc_after - vm_alloc_before) def generate_multi_query( From 992b1a0523b95b920a49e2094e265fc48f95fdb0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Feb 2024 20:31:31 +0000 Subject: [PATCH 73/73] add tokenizer.is_fast warning on creation --- serve/mlc_serve/engine/engine_common.py | 3 --- serve/mlc_serve/model/tokenizer.py | 8 ++++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 9dbe8d3800..a381d72f69 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -134,9 +134,6 @@ def detokenize_incrementally( prefix_begin_offset = generation_sequence.prefix_begin_offset prefix_end_offset = generation_sequence.prefix_end_offset - # Some models, such as Qwen, uses a tokenizer from titoken package which fails this assert. - # assert tokenizer.is_fast - prefix_text = tokenizer.convert_tokens_to_string( output_tokens[prefix_begin_offset:prefix_end_offset] ) diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py index a0d6ec1789..3e4a9a0551 100644 --- a/serve/mlc_serve/model/tokenizer.py +++ b/serve/mlc_serve/model/tokenizer.py @@ -1,8 +1,11 @@ +import structlog from typing import List from transformers import AutoTokenizer from ..engine import ChatMessage from pathlib import Path +LOG = structlog.stdlib.get_logger(__name__) + class Tokenizer: def __init__(self, hf_tokenizer, skip_special_tokens=True): @@ -55,3 +58,8 @@ def __init__(self, tokenizer_path: Path): ) self.tokenizer = Tokenizer(hf_tokenizer) self.conversation_template = ConversationTemplate(hf_tokenizer) + + if not self.tokenizer.is_fast: + LOG.warn("tokenizer.is_fast is false. Some models using an external tokenizer package, " + "such as QWen, might hit this condition but that does not imply that their " + "tokenizers are slow.")
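
For quick reference, a minimal usage sketch of the create_mlc_engine helper that this series introduces. It is not part of the patches above: it assumes an environment with these changes installed, a prebuilt model artifact passed through the standard CLI flags, and, for the PyTorch path, the octoml/vllm fork plus rpyc as noted in PATCH 63/73.

    from mlc_serve.utils import (
        get_default_mlc_serve_argparser,
        postproc_mlc_serve_args,
        create_mlc_engine,
    )

    # Parse the standard mlc-serve CLI flags; --num-shards matters for the PyTorch path.
    parser = get_default_mlc_serve_argparser("create_mlc_engine usage sketch")
    args = parser.parse_args()
    postproc_mlc_serve_args(args)  # configures logging and seeds the RNGs

    # start_engine=True (the default) starts a StagingInferenceEngine right away;
    # run.py passes start_engine=False so the server controls when the engine starts.
    engine = create_mlc_engine(args)

    # ... build requests and submit them to the engine, as the unit tests do ...

    if hasattr(engine, "stop"):
        engine.stop()  # the unit tests stop the staging engine explicitly once done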