Integrate Flash-Decoding into engine #181

Merged
merged 30 commits on Feb 12, 2024
Changes from 22 commits
Commits
30 commits
30e57a0
test stub
masahi Jan 29, 2024
6f3429a
wip
masahi Jan 29, 2024
97a4366
wip
masahi Jan 29, 2024
7279cb6
wip
masahi Jan 29, 2024
7348f0e
compiled
masahi Jan 29, 2024
b692376
wip
masahi Jan 29, 2024
1df6cac
fix
masahi Jan 29, 2024
8c8872c
fix
masahi Jan 29, 2024
6a8272f
wip, decode with flash decoding works
masahi Jan 29, 2024
487129c
all work
masahi Jan 29, 2024
8114197
add paged_kv_cache_type option
masahi Jan 30, 2024
2d6c81b
read kv_type from artifact
masahi Jan 30, 2024
67353b2
black
masahi Jan 30, 2024
b9e41e1
refactor attention backend
masahi Jan 30, 2024
910e31b
minor clean up
masahi Jan 30, 2024
ab910f2
Integrate flash-decoding into mlc-serve
masahi Jan 30, 2024
4c8a75b
remove --use-vllm-attention
masahi Jan 30, 2024
00e1d09
wip decode_multi_query integration
masahi Jan 30, 2024
5fbf671
temp handling for multi-query logits
masahi Jan 30, 2024
2eff7b0
remove tmp support for multi-query decode
masahi Jan 30, 2024
c51c2a4
Merge branch 'batch-serving' into flash-decoding-engine
masahi Jan 30, 2024
d7704e2
Merge branch 'batch-serving' into flash-decoding-engine
masahi Jan 31, 2024
404b305
typo
masahi Jan 31, 2024
d87506c
Merge branch 'batch-serving' into flash-decoding-engine
masahi Jan 31, 2024
99af3fb
Merge branch 'batch-serving' into flash-decoding-engine
masahi Feb 1, 2024
a003965
use block size 128 or 64 when possible
masahi Feb 1, 2024
780e244
Merge branch 'batch-serving' into flash-decoding-engine
masahi Feb 3, 2024
56d7a23
remove unused var
masahi Feb 3, 2024
1b976dc
Merge branch 'batch-serving' into flash-decoding-engine
masahi Feb 8, 2024
a028c7d
merge fix
masahi Feb 8, 2024
53 changes: 24 additions & 29 deletions mlc_llm/core.py
@@ -4,7 +4,6 @@
import json
import os
import pickle
import shutil
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Optional

@@ -157,8 +156,9 @@ class BuildArgs:
pdb: bool
If set, drop into a pdb debugger on error.

use_vllm_attention: bool
Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True.
paged_kv_cache_type: str
The type of paged KV cache to use, only relevant when enable_batching=True.
Currently "vllm" and "flash-decoding" are supported.
"""
model: str = field(
default="auto",
@@ -392,19 +392,8 @@ class BuildArgs:
"action": "store_true",
},
)
# TODO(masahi): Remove the use of this option with paged_kv_cache_type
use_vllm_attention: bool = field(
default=False,
metadata={
"help": (
"Use vLLM paged KV cache and attention kernel, only relevant when "
"enable_batching=True."
),
"action": "store_true",
},
)
paged_kv_cache_type: str = field(
default="vllm",
default="",
metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
)
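
For reference, a minimal sketch of selecting a backend through the new option (not part of this PR): field names follow the BuildArgs dataclass above, the model name is illustrative, and enable_batching is assumed to be an existing BuildArgs field, as the help text implies.

```python
# Illustrative only: choose the paged KV cache backend via BuildArgs.
# How the args object reaches the build entry point is project-specific
# and omitted here.
from mlc_llm.core import BuildArgs

args = BuildArgs(
    model="Llama-2-7b-chat-hf",            # example model, not from this PR
    enable_batching=True,                  # required for any paged KV cache backend
    paged_kv_cache_type="flash-decoding",  # or "vllm"; "" is the new default
)
```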

@@ -462,12 +451,18 @@ def _parse_args(parsed) -> argparse.Namespace:
utils.parse_target(parsed)
utils.argparse_postproc_common(parsed)

if parsed.use_vllm_attention:
assert parsed.enable_batching, "--enable_batching is required for using vLLM attention."
assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA."
assert tvm.get_global_func(
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON."
if parsed.paged_kv_cache_type in ["vllm", "flash-decoding"]:
assert parsed.enable_batching, "--enable_batching is required for using vLLM or Flash-Decoding."
assert parsed.target_kind == "cuda", "vLLM and Flash-Decoding are only supported for CUDA."

if parsed.paged_kv_cache_type == "vllm":
assert tvm.get_global_func(
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON to use vLLM."
elif parsed.paged_kv_cache_type == "flash-decoding":
assert tvm.get_global_func(
"tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache", True
), "TVM needs to be built with -DUSE_CUTLASS=ON to use Flash-Ddecoding."

model_name = [
parsed.model,
@@ -588,20 +583,20 @@ def mod_transform_before_build(
"decode",
]

if not args.use_vllm_attention:
model_names += [
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
]
else:
if args.paged_kv_cache_type in ["vllm", "flash-decoding"]:
# This is equivalent to prefill but without KV cache. It is used for
# determining the number of paged cache blocks that can be allocated.
model_names.append("evaluate")
model_names.append("evaluate_multi_query")

if args.paged_kv_cache_type == "flash-decoding":
model_names.append("decode_multi_query")
else:
model_names += [
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
]

if args.sep_embed:
model_names = ["embed", "prefill_with_embed"] + model_names[1:]
@@ -879,7 +874,7 @@ def build_model_from_args(args: argparse.Namespace):
"mixtral": llama,
}

if args.use_vllm_attention:
if args.paged_kv_cache_type in ["vllm", "flash-decoding"]:
model_generators["llama"] = llama_batched_vllm
model_generators["mistral"] = llama_batched_vllm
model_generators["mixtral"] = llama_batched_vllm
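The assertions in _parse_args above tie each backend to a TVM packed function. A standalone sketch, using only the function names from those asserts, to check which backends a local TVM build can serve:

```python
# Report whether the local TVM build ships the attention kernel each paged
# KV cache backend relies on. Function names are taken from the asserts above.
import tvm

KERNELS = {
    "vllm": "tvm.contrib.vllm.single_query_cached_kv_attention",                   # needs -DUSE_VLLM=ON
    "flash-decoding": "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache",  # needs -DUSE_CUTLASS=ON
}

for backend, name in KERNELS.items():
    ok = tvm.get_global_func(name, allow_missing=True) is not None
    print(f"{backend}: {'available' if ok else 'missing'} ({name})")
```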
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -28,7 +28,7 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
prompt_token_counts: int
# All tokens for this request, including prompt
# Decoded tokens for this sequence
token_ids: List[int]
sampling_params: SamplingParams

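The comment change above narrows what token_ids carries during decode. A hypothetical helper showing the intended split (the import path and argument names are assumptions, not taken from this PR):

```python
from mlc_serve.engine.model_module import DecodeRequest

def make_decode_request(seq_id, prompt_tokens, generated_tokens, sampling_params):
    # token_ids now holds only the tokens generated so far for this sequence;
    # the prompt length travels separately in prompt_token_counts.
    return DecodeRequest(
        sequence_id=seq_id,
        prompt_token_counts=len(prompt_tokens),
        token_ids=generated_tokens,   # decoded tokens only, prompt excluded
        sampling_params=sampling_params,
    )
```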
4 changes: 4 additions & 0 deletions serve/mlc_serve/model/base.py
@@ -10,6 +10,7 @@ class ModelArtifactConfig:
model_artifact_path: Optional[str] = None
num_shards: Optional[int] = None
quantization: Optional[str] = None
paged_kv_cache_type: Optional[str] = None
model_type: Optional[str] = None
library_name: Optional[str] = None
max_context_length: Optional[int] = None
@@ -49,4 +50,7 @@ def get_model_artifact_config(model_artifact_path):
with open(config_file_path, mode="rt", encoding="utf-8") as f:
json_object.update(json.load(f))

if not "paged_kv_cache_type" in json_object:
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)
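
A minimal reproduction of the fallback added above: artifacts built before this change have no paged_kv_cache_type entry in their build config, so the loader treats them as vllm, matching the previous default (the config values below are illustrative):

```python
# Build config from an older artifact that predates paged_kv_cache_type.
json_object = {"model_artifact_path": "dist/example-model", "num_shards": 1}

if "paged_kv_cache_type" not in json_object:
    json_object["paged_kv_cache_type"] = "vllm"

assert json_object["paged_kv_cache_type"] == "vllm"
```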
14 changes: 10 additions & 4 deletions serve/mlc_serve/model/model_common.py
@@ -20,6 +20,7 @@ def get_gpu_memory(gpu: int = 0) -> int:

def get_num_cache_blocks(
model,
block_size,
seq_lens,
num_layers,
num_kv_heads,
@@ -28,7 +29,7 @@
):
used_memory_bytes = model.profile_memory_usage(seq_lens)
cache_block_size = CacheManager.get_cache_block_size(
num_layers, num_kv_heads, head_size
block_size, num_layers, num_kv_heads, head_size
)
total_vram = get_gpu_memory()
return int(
@@ -185,6 +186,7 @@ def prepare_inputs(
all_decode_block_tables,
sliding_window,
is_prefill,
num_decode_query_tokens=1,
):
block_tables = []
seq_lens = []
@@ -211,13 +213,17 @@
start_idx += prompt_len

else:
input_ids.append(token_ids[-1])
seq_len = prompt_lens[i] + len(token_ids)
positions.append(seq_len - 1)
input_ids += token_ids[-num_decode_query_tokens:]

for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]

block_table = all_decode_block_tables[sequence_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table.get_blocks())
slot_mapping.append(all_slot_mappings[sequence_id][-1])

if sliding_window:
seq_lens.append(min(seq_len, sliding_window))
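A standalone illustration of the new num_decode_query_tokens handling above (not the real prepare_inputs): the last num_decode_query_tokens tokens of a sequence are fed as decode queries, and their positions count back from seq_len - 1.

```python
def decode_positions(seq_len, num_decode_query_tokens=1):
    # Mirrors the loop above: positions of the trailing query tokens of a
    # sequence whose total length (prompt + generated) is seq_len.
    return [seq_len - (num_decode_query_tokens - i) for i in range(num_decode_query_tokens)]

assert decode_positions(10) == [9]        # ordinary single-token decode
assert decode_positions(10, 2) == [8, 9]  # multi-query decode (decode_multi_query path)
```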
8 changes: 4 additions & 4 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -120,12 +120,10 @@ def __init__(


class CacheManager:
block_size: int = 16

@staticmethod
def get_cache_block_size(num_layers, num_heads, head_size):
def get_cache_block_size(block_size, num_layers, num_heads, head_size):
# Taken from vllm/worker/cache_engine.py
key_cache_block = CacheManager.block_size * num_heads * head_size
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = 2 # fp16
@@ -134,9 +132,11 @@ def get_cache_block_size(num_layers, num_heads, head_size):
def __init__(
self,
num_blocks: int,
block_size: int,
sliding_window: Optional[int] = None,
):
self.num_blocks = num_blocks
self.block_size = block_size
self.free_blocks = list(range(num_blocks))
self.kv_cache_info = KVCacheInfo(self.block_size)
self.token_counts = dict[SequenceId, int]()
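To make the get_cache_block_size arithmetic above concrete, a worked example for a hypothetical Llama-7B-like configuration (32 layers, 32 KV heads, head size 128, fp16), comparing the previously fixed block size of 16 with a block size of 128 as mentioned for Flash-Decoding in one of the commits; the final multiply-by-dtype_size step is assumed from the truncated diff.

```python
def cache_block_size_bytes(block_size, num_layers, num_heads, head_size, dtype_size=2):
    # Same arithmetic as CacheManager.get_cache_block_size above; the final
    # return (bytes = dtype_size * total elements) is assumed.
    key_cache_block = block_size * num_heads * head_size
    value_cache_block = key_cache_block
    total = num_layers * (key_cache_block + value_cache_block)
    return dtype_size * total

assert cache_block_size_bytes(16, 32, 32, 128) == 8 * 1024 * 1024    # 8 MiB per block
assert cache_block_size_bytes(128, 32, 32, 128) == 64 * 1024 * 1024  # 64 MiB per block
```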