From 52932008425acee5fa785c0dd805929ad162406d Mon Sep 17 00:00:00 2001 From: jayfeather9 Date: Tue, 18 Mar 2025 17:10:21 +0800 Subject: [PATCH 01/22] initial hicache support not finished --- .../router/dynamic_prompt/cache_controller.py | 230 ++++++++++++++++++ .../router/dynamic_prompt/hiradix_cache.py | 64 +++++ 2 files changed, 294 insertions(+) create mode 100644 lightllm/server/router/dynamic_prompt/cache_controller.py create mode 100644 lightllm/server/router/dynamic_prompt/hiradix_cache.py diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py new file mode 100644 index 000000000..44fb79af1 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -0,0 +1,230 @@ +import torch +import threading +import time +import json +from typing import Dict, List, Tuple, Optional, Set, Any +from queue import Queue +from lightllm.common.mem_manager import MemoryManager + +BLOCK_SIZE = 16384 + +def get_torch_tensor_size(tensor: torch.Tensor): + return tensor.nelement() * tensor.element_size() + +class CacheNode: + def __init__(self, parent=None, split_token_idx=None): + self.parent = parent # 父节点 + self.split_token_idx = split_token_idx # 从父节点分裂的位置 + self.children = {} # (token_id, split_position) -> (child_node, split_position) + self.cache_indices = [] # 存储kv cache在mem_manager中的索引 + self.token_ids = [] # 当前节点存储的token ids + self.hash = None # 存储在磁盘上的唯一标识 + + def serialize(self): + """将节点数据序列化为JSON""" + data = { + "children": {f"{k[0]}_{k[1]}": [c.hash, p] for k, (c, p) in self.children.items()}, + "cache_indices": self.cache_indices, + "token_ids": self.token_ids, + "split_token_idx": self.split_token_idx + } + return json.dumps(data) + + @classmethod + def deserialize(cls, data_str, parent=None): + """从JSON反序列化节点数据""" + data = json.loads(data_str) + node = cls(parent=parent, split_token_idx=data["split_token_idx"]) + node.cache_indices = data["cache_indices"] + node.token_ids = data["token_ids"] + # 子节点需要单独加载 + return node, {(int(k.split('_')[0]), int(k.split('_')[1])): (v[0], v[1]) for k, v in data["children"].items()} + + +class HiCacheController: + def __init__(self, mem_manager: MemoryManager): + self.mem_manager = mem_manager + self.service = None # 将由外部代码初始化 + + self.root = CacheNode() + self.root.hash = "root" + + self.node_cache = {self.root.hash: self.root} # hash -> node + self.read_queue = Queue() + self.write_queue = Queue() + + self.token_kvcache_size = None # 每个token的kvcache大小 + + # 启动后台线程处理读写任务 + self.running = True + self.poll_thread = threading.Thread(target=self._poll_tasks) + self.poll_thread.daemon = True + self.poll_thread.start() + + def reset(self): + """重置缓存控制器""" + self.running = False + self.poll_thread.join(timeout=1) + + self.root = CacheNode() + self.root.hash = "root" + self.node_cache = {self.root.hash: self.root} + + self.read_queue = Queue() + self.write_queue = Queue() + + self.running = True + self.poll_thread = threading.Thread(target=self._poll_tasks) + self.poll_thread.daemon = True + self.poll_thread.start() + + def _poll_tasks(self): + """轮询读写任务,检查是否完成""" + while self.running: + # 处理读任务 + pending_reads = [] + while not self.read_queue.empty(): + task = self.read_queue.get() + if task.ready(): + # TODO: 将读到的内容存入 memory manager 中 + pass + else: + pending_reads.append(task) + + for task in pending_reads: + self.read_queue.put(task) + + # 处理写任务 + pending_writes = [] + while not self.write_queue.empty(): + task = self.write_queue.get() + if not task.ready(): + 
pending_writes.append(task) + + for task in pending_writes: + self.write_queue.put(task) + + time.sleep(0.01) # 避免CPU过度使用 + + def _ensure_node_loaded(self, node_hash): + """确保节点已加载到内存中""" + if node_hash not in self.node_cache and node_hash != "root": + task = self.service.create(hashs=[node_hash], mode="r") + self.service.commit(task) + self.read_queue.put(task) + # 需要等待节点加载完成 + while not task.ready() or node_hash not in self.node_cache: + time.sleep(0.01) + + def _persist_node(self, node): + """将节点持久化到磁盘""" + if not node.hash: + # 为新节点生成hash + node.hash = f"node_{id(node)}_{time.time()}" + + # TODO: 将对应的kvcache写入磁盘 + task = self.service.create(hashs=[node.hash], mode="w") + self.service.commit(task) + self.write_queue.put(task) + self.node_cache[node.hash] = node + + def write(self, key: torch.Tensor, value: torch.Tensor): + """ + 写入token序列及其对应的KV缓存索引 + key: token_ids序列 + value: 对应的KV缓存索引 + """ + token_ids = key.cpu().tolist() + indices = value.cpu().tolist() + + # 首次计算每个token的kvcache大小 + if self.token_kvcache_size is None: + kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache + self.token_kvcache_size = get_torch_tensor_size(kvcache) + + current = self.root + position = 0 + relative_position = 0 + + while position < len(token_ids): + token_id = token_ids[position] + child_key = (token_id, relative_position) + + if child_key in current.children: + child_info = current.children[child_key] + assert isinstance(child_info[0], CacheNode) + child_hash = child_info[0].hash + self._ensure_node_loaded(child_hash) + current = self.node_cache[child_hash] + position += 1 + relative_position = 0 # next time relative pos is 0 + else: + # 计算当前节点剩余空间 + remaining_space = BLOCK_SIZE - len(current.cache_indices) * self.token_kvcache_size + + if self.token_kvcache_size <= remaining_space: + # 当前节点有足够空间 + current.token_ids.append(token_ids[position]) + current.cache_indices.append(indices[position]) + position += 1 + relative_position += 1 + self._persist_node(current) + else: + # 当前节点已满,需要创建新节点 + new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) + + # 将token添加到新节点 + new_node.token_ids.append(token_ids[position]) + new_node.cache_indices.append(indices[position]) + position += 1 + relative_position = 0 # next time relative pos is 0, not affecting child_key + + # 建立父子关系 + current.children[child_key] = (new_node, len(current.cache_indices)) + + # 持久化 + self._persist_node(new_node) + self._persist_node(current) + + current = new_node + + # 确保最后修改的节点被持久化 + self._persist_node(current) + + def read(self, key: torch.Tensor) -> torch.Tensor: + """ + 读取token序列对应的KV缓存索引 + key: token_ids序列 + 返回: 对应的KV缓存索引 + """ + token_ids = key.cpu().tolist() + result_indices = [] + + current = self.root + position = 0 + relative_position = 0 + + while position < len(token_ids): + token_id = token_ids[position] + + # 检查当前节点的token + if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: + # TODO: 将读到的东西存到 result_indices 中 + position += 1 + relative_position += 1 + continue + + # 查找子节点 + child_key = (token_id, relative_position) + if child_key in current.children: + child_info = current.children[child_key] + assert isinstance(child_info[0], CacheNode) + child_hash = child_info[0].hash + self._ensure_node_loaded(child_hash) + current = self.node_cache[child_hash] + relative_position = 0 + else: + # 未找到匹配的路径 + return torch.tensor(result_indices, dtype=torch.int64) + + return torch.tensor(result_indices, dtype=torch.int64) diff --git 
a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py new file mode 100644 index 000000000..151bcd53a --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -0,0 +1,64 @@ +import torch +from .cache_controller import HiCacheController +from .radix_cache import RadixCache, TreeNode, match +from typing import Tuple, Dict, Set, List +from lightllm.common.mem_manager import MemoryManager + + +class HiRadixCache(RadixCache): + def __init__(self, cache_controller: HiCacheController, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None): + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.cache_controller = cache_controller + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + first_key_id = key[0].item() + if first_key_id in node.children.keys(): + child: TreeNode = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) From 013e107cb65ca29ca3edaace846d898d989e2b37 Mon Sep 17 00:00:00 2001 From: jayfeather9 Date: Tue, 18 Mar 2025 22:02:09 +0800 Subject: [PATCH 02/22] add debug outputs --- .../router/dynamic_prompt/cache_controller.py | 14 +- test/server/test_hicache.py | 156 ++++++++++++++++++ 2 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 test/server/test_hicache.py diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py index 44fb79af1..e14106923 100644 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -87,7 +87,11 @@ def _poll_tasks(self): task = self.read_queue.get() if task.ready(): # TODO: 将读到的内容存入 memory manager 中 - pass + node_hash = task.hashs[0] + if node_hash in self.node_cache: + node = self.node_cache[node_hash] + node.cache_indices = self.mem_manager.store(node.cache_indices, task.value) + print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") else: pending_reads.append(task) @@ -118,12 +122,13 @@ def _ensure_node_loaded(self, node_hash): def 
_persist_node(self, node): """将节点持久化到磁盘""" + print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") if not node.hash: # 为新节点生成hash node.hash = f"node_{id(node)}_{time.time()}" # TODO: 将对应的kvcache写入磁盘 - task = self.service.create(hashs=[node.hash], mode="w") + task = self.service.create(hashs=[node.hash], value=self.mem_manager.to_kvcache(node.cache_indices), mode="w") self.service.commit(task) self.write_queue.put(task) self.node_cache[node.hash] = node @@ -141,6 +146,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): if self.token_kvcache_size is None: kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache self.token_kvcache_size = get_torch_tensor_size(kvcache) + print(f"Single token KV cache size: {self.token_kvcache_size} bytes, Block size: {BLOCK_SIZE}") current = self.root position = 0 @@ -148,9 +154,11 @@ def write(self, key: torch.Tensor, value: torch.Tensor): while position < len(token_ids): token_id = token_ids[position] + print(f"Writing token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") child_key = (token_id, relative_position) if child_key in current.children: + print(f"Child key {child_key} found in current.children") child_info = current.children[child_key] assert isinstance(child_info[0], CacheNode) child_hash = child_info[0].hash @@ -172,6 +180,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): else: # 当前节点已满,需要创建新节点 new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) + print(f"Creating new node at split position {new_node.split_token_idx}, parent hash: {current.hash}") # 将token添加到新节点 new_node.token_ids.append(token_ids[position]) @@ -206,6 +215,7 @@ def read(self, key: torch.Tensor) -> torch.Tensor: while position < len(token_ids): token_id = token_ids[position] + print(f"Reading token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") # 检查当前节点的token if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py new file mode 100644 index 000000000..b65b4f9c7 --- /dev/null +++ b/test/server/test_hicache.py @@ -0,0 +1,156 @@ +# test_hicache.py +import torch +import time +import random +from threading import Thread, Event +from queue import Queue +from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE + +class MockMemoryManager: + """模拟内存管理器,仅返回连续的索引值""" + def __init__(self): + self.current_idx = 0 + self.kvcache_store = {} + + def alloc(self, size): + indices = list(range(self.current_idx, self.current_idx + size)) + self.current_idx += size + self.store(indices, torch.tensor([[0] * 512 for _ in range(size)])) + return indices + + def to_kvcache(self, indices): + return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) + + def store(self, indices, value): + for idx, val in zip(indices, value): + self.kvcache_store[idx] = val + + def free(self, indices): + for idx in indices: + del self.kvcache_store[idx] + +class MockTask: + def __init__(self, hashs, mode, value=None): + self.hashs = hashs + self.mode = mode + self._ready = Event() + self.data = value + + def ready(self): + return self._ready.is_set() + + def set_ready(self): + self._ready.set() + +class MockService: + def __init__(self): + self.tasks = Queue() + self.running = True + self.worker = Thread(target=self.process_tasks) + self.worker.daemon = True + self.worker.start() + + def 
process_tasks(self): + while self.running: + if not self.tasks.empty(): + task = self.tasks.get() + # 模拟随机延迟后完成任务 + delay = random.uniform(0.01, 0.1) + time.sleep(delay) + task.set_ready() + print(f"Task for {task.hashs} completed after {delay:.2f}s") + else: + time.sleep(0.01) + + def create(self, hashs, mode, value=None): + task = MockTask(hashs, mode, value) + self.tasks.put(task) + return task + + def commit(self, task): + pass # 在Mock中不需要实现 + + def shutdown(self): + self.running = False + self.worker.join() + +def setup(): + mem_manager = MockMemoryManager() + service = MockService() + hicache = HiCacheController(mem_manager) + hicache.service = service # 注入模拟服务 + + # 预先计算单token大小 + dummy_indices = mem_manager.alloc(1) + kvcache = mem_manager.to_kvcache(dummy_indices[:1]) + token_size = kvcache.nelement() * kvcache.element_size() + print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}") + + return mem_manager, service, hicache, token_size + +def test_basic_write_read(mem_manager, hicache, token_size): + # 计算每个块可容纳的token数量 + tokens_per_block = BLOCK_SIZE // token_size + print(f"[TEST] Each block can hold {tokens_per_block} tokens") + + # 生成测试数据:刚好占满一个块 + token_ids = list(range(tokens_per_block)) + indices = mem_manager.alloc(len(token_ids)) + + # 写入缓存 + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + + # 等待任务完成 + time.sleep(0.5) # 确保后台线程处理完成 + + # 读取验证 + result = hicache.read(torch.tensor(token_ids)) + assert result.tolist() == indices, f"Retrieved indices: {result.tolist()}, Expected indices: {indices}" + print(f"[TEST] Basic test passed. Retrieved indices: {result.tolist()}") + +def test_node_splitting(mem_manager, hicache, token_size): + tokens_per_block = BLOCK_SIZE // token_size + # 生成超过一个块的数据 + token_ids = list(range(tokens_per_block + 1)) + indices = mem_manager.alloc(len(token_ids)) + + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(0.5) + + # 验证根节点应该有子节点 + root = hicache.root + assert len(root.children) > 0 + print(f"\nRoot node has {len(root.children)} children") + + # 读取完整序列 + result = hicache.read(torch.tensor(token_ids)) + assert result.tolist() == indices + print(f"[TEST] Node splitting test passed. 
Retrieved indices: {result.tolist()}") + +def test_partial_read(mem_manager, hicache): + token_ids = [1,2,3,4,5] + indices = mem_manager.alloc(len(token_ids)) + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(0.2) + + # 查询存在的部分前缀 + result = hicache.read(torch.tensor([1,2,3])) + assert result.tolist() == indices[:3] + print(f"[TEST] Partial read result: {result.tolist()}") + + # 查询不存在的前缀 + result = hicache.read(torch.tensor([1,2,9])) + assert len(result) == 0 + print(f"[TEST] Non-existent prefix returned: {result.tolist()}") + +def main(): + mem_manager, service, hicache, token_size = setup() + try: + test_basic_write_read(mem_manager, hicache, token_size) + test_node_splitting(mem_manager, hicache, token_size) + test_partial_read(mem_manager, hicache) + finally: + service.shutdown() + +if __name__ == "__main__": + main() \ No newline at end of file From 7bffa94e2777cea92826b51fa5adfd4634bbf954 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Wed, 19 Mar 2025 11:27:02 +0800 Subject: [PATCH 03/22] fixed various bugs --- .../router/dynamic_prompt/cache_controller.py | 39 +++++---- test/server/test_hicache.py | 84 +++++++++++++++---- 2 files changed, 89 insertions(+), 34 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py index e14106923..ba22146f2 100644 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -47,7 +47,7 @@ def __init__(self, mem_manager: MemoryManager): self.service = None # 将由外部代码初始化 self.root = CacheNode() - self.root.hash = "root" + self.root.hash = "root_" + str(time.time()) self.node_cache = {self.root.hash: self.root} # hash -> node self.read_queue = Queue() @@ -67,7 +67,7 @@ def reset(self): self.poll_thread.join(timeout=1) self.root = CacheNode() - self.root.hash = "root" + self.root.hash = "root_" + str(time.time()) self.node_cache = {self.root.hash: self.root} self.read_queue = Queue() @@ -85,14 +85,7 @@ def _poll_tasks(self): pending_reads = [] while not self.read_queue.empty(): task = self.read_queue.get() - if task.ready(): - # TODO: 将读到的内容存入 memory manager 中 - node_hash = task.hashs[0] - if node_hash in self.node_cache: - node = self.node_cache[node_hash] - node.cache_indices = self.mem_manager.store(node.cache_indices, task.value) - print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") - else: + if not task.ready(): pending_reads.append(task) for task in pending_reads: @@ -112,21 +105,30 @@ def _poll_tasks(self): def _ensure_node_loaded(self, node_hash): """确保节点已加载到内存中""" - if node_hash not in self.node_cache and node_hash != "root": + assert node_hash in self.node_cache, f"Node {node_hash} not found in cache" + assert node_hash[:4] != "root", "Cannot load root node" + if not self.mem_manager.exist(self.node_cache[node_hash].cache_indices): task = self.service.create(hashs=[node_hash], mode="r") self.service.commit(task) self.read_queue.put(task) # 需要等待节点加载完成 - while not task.ready() or node_hash not in self.node_cache: + while not task.ready(): time.sleep(0.01) + for node_hash, node_data in zip(task.hashs, task.data): + assert node_hash in self.node_cache + node = self.node_cache[node_hash] + node.cache_indices = self.mem_manager.store(node.cache_indices, node_data) + print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") + print(f"Node {node_hash} loaded to memory") def _persist_node(self, node): """将节点持久化到磁盘""" - 
print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") if not node.hash: # 为新节点生成hash node.hash = f"node_{id(node)}_{time.time()}" + print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") + # TODO: 将对应的kvcache写入磁盘 task = self.service.create(hashs=[node.hash], value=self.mem_manager.to_kvcache(node.cache_indices), mode="w") self.service.commit(task) @@ -155,6 +157,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): while position < len(token_ids): token_id = token_ids[position] print(f"Writing token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") + print(f"relative_position: {relative_position}, node_hash: {current.hash}") child_key = (token_id, relative_position) if child_key in current.children: @@ -170,7 +173,8 @@ def write(self, key: torch.Tensor, value: torch.Tensor): # 计算当前节点剩余空间 remaining_space = BLOCK_SIZE - len(current.cache_indices) * self.token_kvcache_size - if self.token_kvcache_size <= remaining_space: + # root 不应存储任何内容 + if self.token_kvcache_size <= remaining_space and current != self.root: # 当前节点有足够空间 current.token_ids.append(token_ids[position]) current.cache_indices.append(indices[position]) @@ -189,7 +193,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): relative_position = 0 # next time relative pos is 0, not affecting child_key # 建立父子关系 - current.children[child_key] = (new_node, len(current.cache_indices)) + current.children[(token_id, new_node.split_token_idx)] = (new_node, new_node.split_token_idx) # 持久化 self._persist_node(new_node) @@ -206,6 +210,7 @@ def read(self, key: torch.Tensor) -> torch.Tensor: key: token_ids序列 返回: 对应的KV缓存索引 """ + print(f"Reading key: {key}") token_ids = key.cpu().tolist() result_indices = [] @@ -215,17 +220,19 @@ def read(self, key: torch.Tensor) -> torch.Tensor: while position < len(token_ids): token_id = token_ids[position] - print(f"Reading token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") + print(f"Reading token {token_id} at position {position}, node total {len(current.token_ids)} tokens from node hash {current.hash}") # 检查当前节点的token if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: # TODO: 将读到的东西存到 result_indices 中 + result_indices.append(current.cache_indices[relative_position]) position += 1 relative_position += 1 continue # 查找子节点 child_key = (token_id, relative_position) + print(f"Looking for child {child_key} in node {current.hash}: {current.children}") if child_key in current.children: child_info = current.children[child_key] assert isinstance(child_info[0], CacheNode) diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py index b65b4f9c7..82a7a44de 100644 --- a/test/server/test_hicache.py +++ b/test/server/test_hicache.py @@ -15,17 +15,25 @@ def __init__(self): def alloc(self, size): indices = list(range(self.current_idx, self.current_idx + size)) self.current_idx += size - self.store(indices, torch.tensor([[0] * 512 for _ in range(size)])) + self.store(indices, torch.tensor([[random.randint(0, 0xffff) for __ in range(512)] for _ in range(size)])) return indices def to_kvcache(self, indices): + assert all([idx in self.kvcache_store for idx in indices]), f"Not all of {indices} are not found in kvcache_store" return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) def store(self, indices, value): - for idx, val in zip(indices, value): - self.kvcache_store[idx] = val + print(f"[TEST:MemManager] Storing {value.shape} 
at {indices}") + for idx, value_dim in zip(indices, range(value.shape[0])): + self.kvcache_store[idx] = value[value_dim] + print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}") + return indices + + def exist(self, indices): + return all([idx in self.kvcache_store for idx in indices]) def free(self, indices): + print(f"[TEST:MemManager] Freeing {indices}") for idx in indices: del self.kvcache_store[idx] @@ -45,7 +53,10 @@ def set_ready(self): class MockService: def __init__(self): self.tasks = Queue() + self.added_count = 0 + self.finished_count = 0 self.running = True + self.hash_data = {} self.worker = Thread(target=self.process_tasks) self.worker.daemon = True self.worker.start() @@ -53,20 +64,43 @@ def __init__(self): def process_tasks(self): while self.running: if not self.tasks.empty(): - task = self.tasks.get() # 模拟随机延迟后完成任务 delay = random.uniform(0.01, 0.1) time.sleep(delay) + task = self.tasks.get() + self.complete(task) task.set_ready() print(f"Task for {task.hashs} completed after {delay:.2f}s") else: time.sleep(0.01) + def complete(self, task): + if task.mode == "r": + assert all(hash in self.hash_data for hash in task.hashs) + task.data = torch.stack(list(self.hash_data[hash] for hash in task.hashs)) + elif task.mode == "w": + for hash, value in zip(task.hashs, task.data): + self.hash_data[hash] = value + self.finished_count += 1 + def create(self, hashs, mode, value=None): + assert mode in ["r", "w"] + if not isinstance(value, list): + value = [value] + assert len(value) == len(hashs) task = MockTask(hashs, mode, value) self.tasks.put(task) + self.added_count += 1 return task + def all_finished(self): + return self.tasks.empty() and self.added_count == self.finished_count + + def wait_till_all_finished(self): + time.sleep(1) + while not self.all_finished(): + time.sleep(0.01) + def commit(self, task): pass # 在Mock中不需要实现 @@ -80,6 +114,9 @@ def setup(): hicache = HiCacheController(mem_manager) hicache.service = service # 注入模拟服务 + indices = mem_manager.alloc(5) + print(mem_manager.to_kvcache(indices)) + # 预先计算单token大小 dummy_indices = mem_manager.alloc(1) kvcache = mem_manager.to_kvcache(dummy_indices[:1]) @@ -96,26 +133,32 @@ def test_basic_write_read(mem_manager, hicache, token_size): # 生成测试数据:刚好占满一个块 token_ids = list(range(tokens_per_block)) indices = mem_manager.alloc(len(token_ids)) + kvcache = mem_manager.to_kvcache(indices) + print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}") # 写入缓存 hicache.write(torch.tensor(token_ids), torch.tensor(indices)) # 等待任务完成 - time.sleep(0.5) # 确保后台线程处理完成 + hicache.service.wait_till_all_finished() + + mem_manager.free(indices) # 读取验证 result = hicache.read(torch.tensor(token_ids)) - assert result.tolist() == indices, f"Retrieved indices: {result.tolist()}, Expected indices: {indices}" - print(f"[TEST] Basic test passed. Retrieved indices: {result.tolist()}") + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" + print(f"[TEST] Basic test passed. 
Retrieved kvcache\n\n") def test_node_splitting(mem_manager, hicache, token_size): tokens_per_block = BLOCK_SIZE // token_size # 生成超过一个块的数据 - token_ids = list(range(tokens_per_block + 1)) + token_ids = list(range(12, 12 + tokens_per_block * 3 + 1)) indices = mem_manager.alloc(len(token_ids)) + kvcache = mem_manager.to_kvcache(indices) hicache.write(torch.tensor(token_ids), torch.tensor(indices)) - time.sleep(0.5) + hicache.service.wait_till_all_finished() # 验证根节点应该有子节点 root = hicache.root @@ -124,23 +167,28 @@ def test_node_splitting(mem_manager, hicache, token_size): # 读取完整序列 result = hicache.read(torch.tensor(token_ids)) - assert result.tolist() == indices - print(f"[TEST] Node splitting test passed. Retrieved indices: {result.tolist()}") + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" + print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n") def test_partial_read(mem_manager, hicache): - token_ids = [1,2,3,4,5] + token_ids = [97, 98, 99, 100, 101, 102] indices = mem_manager.alloc(len(token_ids)) + kvcache = mem_manager.to_kvcache(indices) hicache.write(torch.tensor(token_ids), torch.tensor(indices)) - time.sleep(0.2) + hicache.service.wait_till_all_finished() # 查询存在的部分前缀 - result = hicache.read(torch.tensor([1,2,3])) - assert result.tolist() == indices[:3] - print(f"[TEST] Partial read result: {result.tolist()}") + result = hicache.read(torch.tensor([97, 98, 99])) + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache[:3]).all() + print(f"[TEST] Partial read passed") # 查询不存在的前缀 - result = hicache.read(torch.tensor([1,2,9])) - assert len(result) == 0 + result = hicache.read(torch.tensor([97, 98, 100])) + assert len(result) == 2 + result = mem_manager.to_kvcache(result.tolist()) + assert result.eq(kvcache[:2]).all() print(f"[TEST] Non-existent prefix returned: {result.tolist()}") def main(): From 2c649ed0fe9336b356e3d7866b328c129968eba2 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Wed, 19 Mar 2025 17:12:10 +0800 Subject: [PATCH 04/22] add hiradix cache implement untested --- .../router/dynamic_prompt/cache_controller.py | 38 ++++++++- .../router/dynamic_prompt/hiradix_cache.py | 79 +++++++------------ 2 files changed, 64 insertions(+), 53 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py index ba22146f2..e0248fa65 100644 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -204,6 +204,40 @@ def write(self, key: torch.Tensor, value: torch.Tensor): # 确保最后修改的节点被持久化 self._persist_node(current) + def readable_length(self, key: torch.Tensor) -> int: + """ + 计算key对应的KV缓存索引可读取的长度 + """ + token_ids = key.cpu().tolist() + current = self.root + position = 0 + relative_position = 0 + readable_count = 0 + + while position < len(token_ids): + token_id = token_ids[position] + + # 检查当前节点的token + if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: + readable_count += 1 + position += 1 + relative_position += 1 + continue + + # 查找子节点 + child_key = (token_id, relative_position) + if child_key in current.children: + child_info = current.children[child_key] + assert isinstance(child_info[0], CacheNode) + child_hash = child_info[0].hash + current = self.node_cache[child_hash] + relative_position = 0 + else: + # 未找到匹配的路径,返回已读取的长度 + return 
readable_count + + return readable_count + def read(self, key: torch.Tensor) -> torch.Tensor: """ 读取token序列对应的KV缓存索引 @@ -242,6 +276,6 @@ def read(self, key: torch.Tensor) -> torch.Tensor: relative_position = 0 else: # 未找到匹配的路径 - return torch.tensor(result_indices, dtype=torch.int64) + return torch.tensor(result_indices) - return torch.tensor(result_indices, dtype=torch.int64) + return torch.tensor(result_indices) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 151bcd53a..f23102b87 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -10,55 +10,32 @@ def __init__(self, cache_controller: HiCacheController, unique_name, total_token super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) self.cache_controller = cache_controller - def _insert_helper(self, node: TreeNode, key, value): - if node.is_leaf(): - self.evict_tree_set.discard(node) - - try: - first_key_id = key[0].item() - if first_key_id in node.children.keys(): - child: TreeNode = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) - if prefix_len == len(key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - child.update_time() - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len - - elif prefix_len < len(key) and prefix_len < len(child.token_id_key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - - key = key[prefix_len:] - value = value[prefix_len:] - split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) - - if split_parent_node.is_leaf(): - self.evict_tree_set.add(split_parent_node) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len - elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) - else: - assert False, "can not run to here" - + def insert(self, key, value=None): + if value is None: + value = key + + assert len(key) == len(value) # and len(key) >= 1 + if len(key) == 0: + return 0 + + self.cache_controller.write(key, value) + return self._insert_helper(self.root_node, key, value) + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + ans_value_list = [] + available_hi_result = self.cache_controller.readable_length(key) + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + if tree_node == self.root_node or available_hi_result > len(ans_value_list): + hi_result = self.cache_controller.read(key) + self._insert_helper(tree_node, key, hi_result) + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) else: - new_node = node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - return 0 - finally: - node.update_time() - if node.is_leaf(): - self.evict_tree_set.add(node) + assert False, "can not run to here" + return tree_node, len(value), value + else: + 
self.dec_node_ref_counter(self.root_node) + return None, 0, None From 988da741104dfdd01827543ce6ac3ae0b32ca38f Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Thu, 20 Mar 2025 20:49:23 +0800 Subject: [PATCH 05/22] add thread safety & real mem_manager support --- .../router/dynamic_prompt/cache_controller.py | 236 ++++++++++++------ test/server/test_hicache.py | 86 +------ 2 files changed, 167 insertions(+), 155 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py index e0248fa65..82d88ae71 100644 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ b/lightllm/server/router/dynamic_prompt/cache_controller.py @@ -3,7 +3,9 @@ import time import json from typing import Dict, List, Tuple, Optional, Set, Any +from threading import Thread, Event, RLock from queue import Queue +from enum import Enum from lightllm.common.mem_manager import MemoryManager BLOCK_SIZE = 16384 @@ -11,6 +13,11 @@ def get_torch_tensor_size(tensor: torch.Tensor): return tensor.nelement() * tensor.element_size() +class LoadStatus(Enum): + UNLOADED = 0 + LOADING = 1 + LOADED = 2 + class CacheNode: def __init__(self, parent=None, split_token_idx=None): self.parent = parent # 父节点 @@ -19,27 +26,7 @@ def __init__(self, parent=None, split_token_idx=None): self.cache_indices = [] # 存储kv cache在mem_manager中的索引 self.token_ids = [] # 当前节点存储的token ids self.hash = None # 存储在磁盘上的唯一标识 - - def serialize(self): - """将节点数据序列化为JSON""" - data = { - "children": {f"{k[0]}_{k[1]}": [c.hash, p] for k, (c, p) in self.children.items()}, - "cache_indices": self.cache_indices, - "token_ids": self.token_ids, - "split_token_idx": self.split_token_idx - } - return json.dumps(data) - - @classmethod - def deserialize(cls, data_str, parent=None): - """从JSON反序列化节点数据""" - data = json.loads(data_str) - node = cls(parent=parent, split_token_idx=data["split_token_idx"]) - node.cache_indices = data["cache_indices"] - node.token_ids = data["token_ids"] - # 子节点需要单独加载 - return node, {(int(k.split('_')[0]), int(k.split('_')[1])): (v[0], v[1]) for k, v in data["children"].items()} - + self.status = LoadStatus.UNLOADED # 加载状态 class HiCacheController: def __init__(self, mem_manager: MemoryManager): @@ -50,76 +37,72 @@ def __init__(self, mem_manager: MemoryManager): self.root.hash = "root_" + str(time.time()) self.node_cache = {self.root.hash: self.root} # hash -> node - self.read_queue = Queue() - self.write_queue = Queue() self.token_kvcache_size = None # 每个token的kvcache大小 - # 启动后台线程处理读写任务 - self.running = True - self.poll_thread = threading.Thread(target=self._poll_tasks) - self.poll_thread.daemon = True - self.poll_thread.start() + self.node_lock = RLock() + + # 添加写任务队列 + self.writetaskqueue = Queue() + self.write_thread_running = True + + # 启动处理写任务的线程 + self.write_thread = Thread(target=self._process_write_tasks) + self.write_thread.daemon = True + self.write_thread.start() + + def store_mem(self, indices, value): + for idx, value_dim in zip(indices, range(value.shape[0])): + self.mem_manager.load_index_kv_buffer(idx, {"kv_buffer": value[value_dim]}) + return indices + + def get_mem(self, indices): + if len(indices) == 0: + return torch.tensor([]) + return torch.stack([self.mem_manager.get_index_kv_buffer(idx)["kv_buffer"] for idx in indices], dim=0) def reset(self): """重置缓存控制器""" - self.running = False - self.poll_thread.join(timeout=1) + # 停止写任务线程 + self.write_thread_running = False + self.write_thread.join(timeout=1) self.root = CacheNode() self.root.hash = 
"root_" + str(time.time()) self.node_cache = {self.root.hash: self.root} - self.read_queue = Queue() - self.write_queue = Queue() - - self.running = True - self.poll_thread = threading.Thread(target=self._poll_tasks) - self.poll_thread.daemon = True - self.poll_thread.start() - - def _poll_tasks(self): - """轮询读写任务,检查是否完成""" - while self.running: - # 处理读任务 - pending_reads = [] - while not self.read_queue.empty(): - task = self.read_queue.get() - if not task.ready(): - pending_reads.append(task) - - for task in pending_reads: - self.read_queue.put(task) - - # 处理写任务 - pending_writes = [] - while not self.write_queue.empty(): - task = self.write_queue.get() - if not task.ready(): - pending_writes.append(task) - - for task in pending_writes: - self.write_queue.put(task) - - time.sleep(0.01) # 避免CPU过度使用 + # 重新创建队列和启动线程 + self.writetaskqueue = Queue() + self.write_thread_running = True + self.write_thread = Thread(target=self._process_write_tasks) + self.write_thread.daemon = True + self.write_thread.start() def _ensure_node_loaded(self, node_hash): """确保节点已加载到内存中""" assert node_hash in self.node_cache, f"Node {node_hash} not found in cache" assert node_hash[:4] != "root", "Cannot load root node" - if not self.mem_manager.exist(self.node_cache[node_hash].cache_indices): - task = self.service.create(hashs=[node_hash], mode="r") - self.service.commit(task) - self.read_queue.put(task) - # 需要等待节点加载完成 - while not task.ready(): - time.sleep(0.01) - for node_hash, node_data in zip(task.hashs, task.data): - assert node_hash in self.node_cache - node = self.node_cache[node_hash] - node.cache_indices = self.mem_manager.store(node.cache_indices, node_data) - print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") - print(f"Node {node_hash} loaded to memory") + with self.node_lock: + if self.node_cache[node_hash].status == LoadStatus.LOADED: + return + if self.node_cache[node_hash].status == LoadStatus.LOADING: + while self.node_cache[node_hash].status != LoadStatus.LOADED: + time.sleep(0.01) + return + if self.node_cache[node_hash].status == LoadStatus.UNLOADED: + self.node_cache[node_hash].status = LoadStatus.LOADING + + task = self.service.create(hashs=[node_hash], mode="r") + self.service.commit(task) + # 需要等待节点加载完成 + while not task.ready(): + time.sleep(0.01) + for node_hash, node_data in zip(task.hashs, task.data): + assert node_hash in self.node_cache + node = self.node_cache[node_hash] + node.cache_indices = self.store_mem(node.cache_indices, node_data) + print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") + print(f"Node {node_hash} loaded to memory") def _persist_node(self, node): """将节点持久化到磁盘""" @@ -129,15 +112,37 @@ def _persist_node(self, node): print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") - # TODO: 将对应的kvcache写入磁盘 - task = self.service.create(hashs=[node.hash], value=self.mem_manager.to_kvcache(node.cache_indices), mode="w") + task = self.service.create(hashs=[node.hash], value=self.get_mem(node.cache_indices), mode="w") self.service.commit(task) - self.write_queue.put(task) self.node_cache[node.hash] = node def write(self, key: torch.Tensor, value: torch.Tensor): """ - 写入token序列及其对应的KV缓存索引 + 将写任务加入队列,由后台线程异步处理 + + key: token_ids序列 + value: 对应的KV缓存索引 + """ + # 将任务加入队列 + self.writetaskqueue.put((key.clone(), value.clone())) + + def _process_write_tasks(self): + """后台线程处理写任务队列""" + while self.write_thread_running: + if not self.writetaskqueue.empty(): + # 从队列获取任务 + key, value = self.writetaskqueue.get() + # 执行实际的写操作 + 
self._do_write(key, value) + self.writetaskqueue.task_done() + else: + # 队列为空时短暂休眠,避免CPU占用过高 + time.sleep(0.01) + + def _do_write(self, key: torch.Tensor, value: torch.Tensor): + """ + 实际执行写入token序列及其对应的KV缓存索引 + key: token_ids序列 value: 对应的KV缓存索引 """ @@ -146,7 +151,7 @@ def write(self, key: torch.Tensor, value: torch.Tensor): # 首次计算每个token的kvcache大小 if self.token_kvcache_size is None: - kvcache = self.mem_manager.to_kvcache(indices[:1]) # 计算单个token的kvcache + kvcache = self.get_mem(indices[:1]) # 计算单个token的kvcache self.token_kvcache_size = get_torch_tensor_size(kvcache) print(f"Single token KV cache size: {self.token_kvcache_size} bytes, Block size: {BLOCK_SIZE}") @@ -279,3 +284,74 @@ def read(self, key: torch.Tensor) -> torch.Tensor: return torch.tensor(result_indices) return torch.tensor(result_indices) + + +class HiHostTask: + def __init__(self, hashs, mode, value=None): + self.hashs = hashs + self.mode = mode + self._ready = Event() + self.data = value + + def ready(self): + return self._ready.is_set() + + def set_ready(self): + self._ready.set() + +class HiHostService: + def __init__(self): + self.tasks = Queue() + self.added_count = 0 + self.finished_count = 0 + self.running = True + self.hash_data = {} # hash -> (data, device) + self.worker = Thread(target=self.process_tasks) + self.worker.daemon = True + self.worker.start() + + def process_tasks(self): + while self.running: + if not self.tasks.empty(): + start_time = time.time() + task = self.tasks.get() + self.complete(task) + task.set_ready() + print(f"Task for {task.hashs} completed after {time.time() - start_time:.2f}s") + else: + time.sleep(0.01) + + def complete(self, task): + if task.mode == "r": + assert all(hash in self.hash_data for hash in task.hashs) + task.data = torch.stack(list(self.hash_data[hash][0] for hash in task.hashs)) + task.data.to(self.hash_data[task.hashs[0]][1]) + elif task.mode == "w": + device = task.data[0].device + for hash, value in zip(task.hashs, task.data): + self.hash_data[hash] = (value.to("cpu"), device) + self.finished_count += 1 + + def create(self, hashs, mode, value=None): + assert mode in ["r", "w"] + if not isinstance(value, list): + value = [value] + assert len(value) == len(hashs) + task = HiHostTask(hashs, mode, value) + return task + + def all_finished(self): + return self.tasks.empty() and self.added_count == self.finished_count + + def wait_till_all_finished(self): + while not self.all_finished(): + time.sleep(0.01) + + def commit(self, task): + self.tasks.put(task) + self.added_count += 1 + + def shutdown(self): + self.running = False + self.worker.join() + diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py index 82a7a44de..b30c00560 100644 --- a/test/server/test_hicache.py +++ b/test/server/test_hicache.py @@ -4,7 +4,7 @@ import random from threading import Thread, Event from queue import Queue -from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE +from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE, HiHostService, HiHostTask class MockMemoryManager: """模拟内存管理器,仅返回连续的索引值""" @@ -18,6 +18,12 @@ def alloc(self, size): self.store(indices, torch.tensor([[random.randint(0, 0xffff) for __ in range(512)] for _ in range(size)])) return indices + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kvcache_store[index] = load_tensor_dict["kv_buffer"] + + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kvcache_store[index]} + def 
to_kvcache(self, indices): assert all([idx in self.kvcache_store for idx in indices]), f"Not all of {indices} are not found in kvcache_store" return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) @@ -29,88 +35,15 @@ def store(self, indices, value): print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}") return indices - def exist(self, indices): - return all([idx in self.kvcache_store for idx in indices]) - def free(self, indices): print(f"[TEST:MemManager] Freeing {indices}") for idx in indices: del self.kvcache_store[idx] -class MockTask: - def __init__(self, hashs, mode, value=None): - self.hashs = hashs - self.mode = mode - self._ready = Event() - self.data = value - - def ready(self): - return self._ready.is_set() - - def set_ready(self): - self._ready.set() - -class MockService: - def __init__(self): - self.tasks = Queue() - self.added_count = 0 - self.finished_count = 0 - self.running = True - self.hash_data = {} - self.worker = Thread(target=self.process_tasks) - self.worker.daemon = True - self.worker.start() - - def process_tasks(self): - while self.running: - if not self.tasks.empty(): - # 模拟随机延迟后完成任务 - delay = random.uniform(0.01, 0.1) - time.sleep(delay) - task = self.tasks.get() - self.complete(task) - task.set_ready() - print(f"Task for {task.hashs} completed after {delay:.2f}s") - else: - time.sleep(0.01) - - def complete(self, task): - if task.mode == "r": - assert all(hash in self.hash_data for hash in task.hashs) - task.data = torch.stack(list(self.hash_data[hash] for hash in task.hashs)) - elif task.mode == "w": - for hash, value in zip(task.hashs, task.data): - self.hash_data[hash] = value - self.finished_count += 1 - - def create(self, hashs, mode, value=None): - assert mode in ["r", "w"] - if not isinstance(value, list): - value = [value] - assert len(value) == len(hashs) - task = MockTask(hashs, mode, value) - self.tasks.put(task) - self.added_count += 1 - return task - - def all_finished(self): - return self.tasks.empty() and self.added_count == self.finished_count - - def wait_till_all_finished(self): - time.sleep(1) - while not self.all_finished(): - time.sleep(0.01) - - def commit(self, task): - pass # 在Mock中不需要实现 - - def shutdown(self): - self.running = False - self.worker.join() def setup(): mem_manager = MockMemoryManager() - service = MockService() + service = HiHostService() hicache = HiCacheController(mem_manager) hicache.service = service # 注入模拟服务 @@ -138,6 +71,7 @@ def test_basic_write_read(mem_manager, hicache, token_size): # 写入缓存 hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(2) # 等待任务完成 hicache.service.wait_till_all_finished() @@ -158,6 +92,7 @@ def test_node_splitting(mem_manager, hicache, token_size): kvcache = mem_manager.to_kvcache(indices) hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(2) hicache.service.wait_till_all_finished() # 验证根节点应该有子节点 @@ -176,6 +111,7 @@ def test_partial_read(mem_manager, hicache): indices = mem_manager.alloc(len(token_ids)) kvcache = mem_manager.to_kvcache(indices) hicache.write(torch.tensor(token_ids), torch.tensor(indices)) + time.sleep(2) hicache.service.wait_till_all_finished() # 查询存在的部分前缀 From bf46d17c03657098b57d7d0bdb4157126c15b30b Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Mon, 31 Mar 2025 08:42:34 +0800 Subject: [PATCH 06/22] delete cache controller --- .../router/dynamic_prompt/cache_controller.py | 357 ------------------ 1 file changed, 357 deletions(-) delete mode 100644 
lightllm/server/router/dynamic_prompt/cache_controller.py diff --git a/lightllm/server/router/dynamic_prompt/cache_controller.py b/lightllm/server/router/dynamic_prompt/cache_controller.py deleted file mode 100644 index 82d88ae71..000000000 --- a/lightllm/server/router/dynamic_prompt/cache_controller.py +++ /dev/null @@ -1,357 +0,0 @@ -import torch -import threading -import time -import json -from typing import Dict, List, Tuple, Optional, Set, Any -from threading import Thread, Event, RLock -from queue import Queue -from enum import Enum -from lightllm.common.mem_manager import MemoryManager - -BLOCK_SIZE = 16384 - -def get_torch_tensor_size(tensor: torch.Tensor): - return tensor.nelement() * tensor.element_size() - -class LoadStatus(Enum): - UNLOADED = 0 - LOADING = 1 - LOADED = 2 - -class CacheNode: - def __init__(self, parent=None, split_token_idx=None): - self.parent = parent # 父节点 - self.split_token_idx = split_token_idx # 从父节点分裂的位置 - self.children = {} # (token_id, split_position) -> (child_node, split_position) - self.cache_indices = [] # 存储kv cache在mem_manager中的索引 - self.token_ids = [] # 当前节点存储的token ids - self.hash = None # 存储在磁盘上的唯一标识 - self.status = LoadStatus.UNLOADED # 加载状态 - -class HiCacheController: - def __init__(self, mem_manager: MemoryManager): - self.mem_manager = mem_manager - self.service = None # 将由外部代码初始化 - - self.root = CacheNode() - self.root.hash = "root_" + str(time.time()) - - self.node_cache = {self.root.hash: self.root} # hash -> node - - self.token_kvcache_size = None # 每个token的kvcache大小 - - self.node_lock = RLock() - - # 添加写任务队列 - self.writetaskqueue = Queue() - self.write_thread_running = True - - # 启动处理写任务的线程 - self.write_thread = Thread(target=self._process_write_tasks) - self.write_thread.daemon = True - self.write_thread.start() - - def store_mem(self, indices, value): - for idx, value_dim in zip(indices, range(value.shape[0])): - self.mem_manager.load_index_kv_buffer(idx, {"kv_buffer": value[value_dim]}) - return indices - - def get_mem(self, indices): - if len(indices) == 0: - return torch.tensor([]) - return torch.stack([self.mem_manager.get_index_kv_buffer(idx)["kv_buffer"] for idx in indices], dim=0) - - def reset(self): - """重置缓存控制器""" - # 停止写任务线程 - self.write_thread_running = False - self.write_thread.join(timeout=1) - - self.root = CacheNode() - self.root.hash = "root_" + str(time.time()) - self.node_cache = {self.root.hash: self.root} - - # 重新创建队列和启动线程 - self.writetaskqueue = Queue() - self.write_thread_running = True - self.write_thread = Thread(target=self._process_write_tasks) - self.write_thread.daemon = True - self.write_thread.start() - - def _ensure_node_loaded(self, node_hash): - """确保节点已加载到内存中""" - assert node_hash in self.node_cache, f"Node {node_hash} not found in cache" - assert node_hash[:4] != "root", "Cannot load root node" - with self.node_lock: - if self.node_cache[node_hash].status == LoadStatus.LOADED: - return - if self.node_cache[node_hash].status == LoadStatus.LOADING: - while self.node_cache[node_hash].status != LoadStatus.LOADED: - time.sleep(0.01) - return - if self.node_cache[node_hash].status == LoadStatus.UNLOADED: - self.node_cache[node_hash].status = LoadStatus.LOADING - - task = self.service.create(hashs=[node_hash], mode="r") - self.service.commit(task) - # 需要等待节点加载完成 - while not task.ready(): - time.sleep(0.01) - for node_hash, node_data in zip(task.hashs, task.data): - assert node_hash in self.node_cache - node = self.node_cache[node_hash] - node.cache_indices = self.store_mem(node.cache_indices, node_data) - 
print(f"Node {node_hash} loaded with {len(node.cache_indices)} cache indices") - print(f"Node {node_hash} loaded to memory") - - def _persist_node(self, node): - """将节点持久化到磁盘""" - if not node.hash: - # 为新节点生成hash - node.hash = f"node_{id(node)}_{time.time()}" - - print(f"Persisting node {node.hash} with {len(node.token_ids)} tokens") - - task = self.service.create(hashs=[node.hash], value=self.get_mem(node.cache_indices), mode="w") - self.service.commit(task) - self.node_cache[node.hash] = node - - def write(self, key: torch.Tensor, value: torch.Tensor): - """ - 将写任务加入队列,由后台线程异步处理 - - key: token_ids序列 - value: 对应的KV缓存索引 - """ - # 将任务加入队列 - self.writetaskqueue.put((key.clone(), value.clone())) - - def _process_write_tasks(self): - """后台线程处理写任务队列""" - while self.write_thread_running: - if not self.writetaskqueue.empty(): - # 从队列获取任务 - key, value = self.writetaskqueue.get() - # 执行实际的写操作 - self._do_write(key, value) - self.writetaskqueue.task_done() - else: - # 队列为空时短暂休眠,避免CPU占用过高 - time.sleep(0.01) - - def _do_write(self, key: torch.Tensor, value: torch.Tensor): - """ - 实际执行写入token序列及其对应的KV缓存索引 - - key: token_ids序列 - value: 对应的KV缓存索引 - """ - token_ids = key.cpu().tolist() - indices = value.cpu().tolist() - - # 首次计算每个token的kvcache大小 - if self.token_kvcache_size is None: - kvcache = self.get_mem(indices[:1]) # 计算单个token的kvcache - self.token_kvcache_size = get_torch_tensor_size(kvcache) - print(f"Single token KV cache size: {self.token_kvcache_size} bytes, Block size: {BLOCK_SIZE}") - - current = self.root - position = 0 - relative_position = 0 - - while position < len(token_ids): - token_id = token_ids[position] - print(f"Writing token {token_id} at position {position}, current node has {len(current.token_ids)} tokens") - print(f"relative_position: {relative_position}, node_hash: {current.hash}") - child_key = (token_id, relative_position) - - if child_key in current.children: - print(f"Child key {child_key} found in current.children") - child_info = current.children[child_key] - assert isinstance(child_info[0], CacheNode) - child_hash = child_info[0].hash - self._ensure_node_loaded(child_hash) - current = self.node_cache[child_hash] - position += 1 - relative_position = 0 # next time relative pos is 0 - else: - # 计算当前节点剩余空间 - remaining_space = BLOCK_SIZE - len(current.cache_indices) * self.token_kvcache_size - - # root 不应存储任何内容 - if self.token_kvcache_size <= remaining_space and current != self.root: - # 当前节点有足够空间 - current.token_ids.append(token_ids[position]) - current.cache_indices.append(indices[position]) - position += 1 - relative_position += 1 - self._persist_node(current) - else: - # 当前节点已满,需要创建新节点 - new_node = CacheNode(parent=current, split_token_idx=len(current.token_ids)) - print(f"Creating new node at split position {new_node.split_token_idx}, parent hash: {current.hash}") - - # 将token添加到新节点 - new_node.token_ids.append(token_ids[position]) - new_node.cache_indices.append(indices[position]) - position += 1 - relative_position = 0 # next time relative pos is 0, not affecting child_key - - # 建立父子关系 - current.children[(token_id, new_node.split_token_idx)] = (new_node, new_node.split_token_idx) - - # 持久化 - self._persist_node(new_node) - self._persist_node(current) - - current = new_node - - # 确保最后修改的节点被持久化 - self._persist_node(current) - - def readable_length(self, key: torch.Tensor) -> int: - """ - 计算key对应的KV缓存索引可读取的长度 - """ - token_ids = key.cpu().tolist() - current = self.root - position = 0 - relative_position = 0 - readable_count = 0 - - while position < len(token_ids): - token_id 
= token_ids[position] - - # 检查当前节点的token - if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: - readable_count += 1 - position += 1 - relative_position += 1 - continue - - # 查找子节点 - child_key = (token_id, relative_position) - if child_key in current.children: - child_info = current.children[child_key] - assert isinstance(child_info[0], CacheNode) - child_hash = child_info[0].hash - current = self.node_cache[child_hash] - relative_position = 0 - else: - # 未找到匹配的路径,返回已读取的长度 - return readable_count - - return readable_count - - def read(self, key: torch.Tensor) -> torch.Tensor: - """ - 读取token序列对应的KV缓存索引 - key: token_ids序列 - 返回: 对应的KV缓存索引 - """ - print(f"Reading key: {key}") - token_ids = key.cpu().tolist() - result_indices = [] - - current = self.root - position = 0 - relative_position = 0 - - while position < len(token_ids): - token_id = token_ids[position] - print(f"Reading token {token_id} at position {position}, node total {len(current.token_ids)} tokens from node hash {current.hash}") - - # 检查当前节点的token - if relative_position < len(current.token_ids) and current.token_ids[relative_position] == token_id: - # TODO: 将读到的东西存到 result_indices 中 - result_indices.append(current.cache_indices[relative_position]) - position += 1 - relative_position += 1 - continue - - # 查找子节点 - child_key = (token_id, relative_position) - print(f"Looking for child {child_key} in node {current.hash}: {current.children}") - if child_key in current.children: - child_info = current.children[child_key] - assert isinstance(child_info[0], CacheNode) - child_hash = child_info[0].hash - self._ensure_node_loaded(child_hash) - current = self.node_cache[child_hash] - relative_position = 0 - else: - # 未找到匹配的路径 - return torch.tensor(result_indices) - - return torch.tensor(result_indices) - - -class HiHostTask: - def __init__(self, hashs, mode, value=None): - self.hashs = hashs - self.mode = mode - self._ready = Event() - self.data = value - - def ready(self): - return self._ready.is_set() - - def set_ready(self): - self._ready.set() - -class HiHostService: - def __init__(self): - self.tasks = Queue() - self.added_count = 0 - self.finished_count = 0 - self.running = True - self.hash_data = {} # hash -> (data, device) - self.worker = Thread(target=self.process_tasks) - self.worker.daemon = True - self.worker.start() - - def process_tasks(self): - while self.running: - if not self.tasks.empty(): - start_time = time.time() - task = self.tasks.get() - self.complete(task) - task.set_ready() - print(f"Task for {task.hashs} completed after {time.time() - start_time:.2f}s") - else: - time.sleep(0.01) - - def complete(self, task): - if task.mode == "r": - assert all(hash in self.hash_data for hash in task.hashs) - task.data = torch.stack(list(self.hash_data[hash][0] for hash in task.hashs)) - task.data.to(self.hash_data[task.hashs[0]][1]) - elif task.mode == "w": - device = task.data[0].device - for hash, value in zip(task.hashs, task.data): - self.hash_data[hash] = (value.to("cpu"), device) - self.finished_count += 1 - - def create(self, hashs, mode, value=None): - assert mode in ["r", "w"] - if not isinstance(value, list): - value = [value] - assert len(value) == len(hashs) - task = HiHostTask(hashs, mode, value) - return task - - def all_finished(self): - return self.tasks.empty() and self.added_count == self.finished_count - - def wait_till_all_finished(self): - while not self.all_finished(): - time.sleep(0.01) - - def commit(self, task): - self.tasks.put(task) - self.added_count += 1 
- - def shutdown(self): - self.running = False - self.worker.join() - From 051479c33fd2b10f0c6aa1a223d06082d76d5cd9 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Mon, 31 Mar 2025 09:23:30 +0800 Subject: [PATCH 07/22] add write support --- .../router/dynamic_prompt/hiradix_cache.py | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index f23102b87..32e16e875 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -1,14 +1,26 @@ import torch -from .cache_controller import HiCacheController from .radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager +from lightllm.utils.log_utils import init_logger +from threading import Lock, Thread +logger = init_logger(__name__) class HiRadixCache(RadixCache): - def __init__(self, cache_controller: HiCacheController, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length, py_cache_service): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - self.cache_controller = cache_controller + try: + assert py_cache_service is not None + self.py_cache_service = py_cache_service + self.hi_cache_buffer_len = 0 + self.hi_cache_key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu") + self.hi_cache_kv_buffer = self.mem_manager.alloc(max_seq_length) + self.moving = False + self.moving_lock = Lock() + except Exception as e: + logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache") + self.hi_cache_kv_buffer = None def insert(self, key, value=None): if value is None: @@ -18,9 +30,39 @@ def insert(self, key, value=None): if len(key) == 0: return 0 - self.cache_controller.write(key, value) + # current implement is serial, TODO: make it parallel + # if no hi_cache_buffer, work with normal radix cache + if self.hi_cache_kv_buffer is not None: + do_copy = False + # and if is moving, ignore this insert request + with self.moving_lock: + if not self.moving: + self.moving = True + do_copy = True + if do_copy: + # copy the key and value to the hi_cache_buffer + self.hi_cache_key_buffer[:len(key)].copy_(key) + self.hi_cache_buffer_len = len(key) + for buffer_index, index in enumerate(value): + kv_data = self.mem_manager.get_index_kv_buffer(index) + self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data) + # create a new thread to store the buffer + thread = Thread(target=self._store_buffer) + thread.start() + return self._insert_helper(self.root_node, key, value) + def _store_buffer(self): + assert self.moving + assert self.hi_cache_buffer_len > 0 + assert self.hi_cache_kv_buffer is not None + key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() + write_task = self.py_cache_service.create(tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") + while not write_task.ready(): + pass + with self.moving_lock: + self.moving = False + def match_prefix(self, key, update_refs=False): assert len(key) != 0 ans_value_list = [] From 1cc2ca4b0ff253c222bb2482cd4689041d14210f Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Mon, 7 Apr 2025 17:45:09 +0800 Subject: [PATCH 08/22] finish hi cache support first edition --- 
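The write path added in PATCH 07 above (insert() staging tokens and KV pages into hi_cache_kv_buffer, then _store_buffer() flushing them on a worker thread) follows a claim-copy-store pattern: a lock-guarded "moving" flag ensures only one store is in flight, the key and KV data are copied into a staging buffer, and the blocking write is done off the main thread. The sketch below restates that flow without the radix-tree bookkeeping. It is a simplified illustration under stated assumptions, not the actual LightLLM implementation: BufferedKVStore is an invented name, and cache_service / mem_manager are placeholders assumed to expose the same create()/ready(), alloc(), get_index_kv_buffer() and load_index_kv_buffer() methods used in the patch.

# Simplified sketch of the claim-copy-store write path from PATCH 07.
import time

import torch
from threading import Lock, Thread


class BufferedKVStore:
    """Stage one request's KV cache and persist it in the background."""

    def __init__(self, cache_service, mem_manager, max_seq_length):
        # cache_service and mem_manager are placeholder dependencies here.
        self.cache_service = cache_service
        self.mem_manager = mem_manager
        self.key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu")
        self.kv_buffer = mem_manager.alloc(max_seq_length)   # staging slots
        self.buffer_len = 0
        self.moving = False
        self.moving_lock = Lock()

    def try_store(self, key, value):
        """Stage (key, value) and store it in the background.
        Returns False if a previous store is still in flight."""
        with self.moving_lock:
            if self.moving:
                return False            # only one store at a time
            self.moving = True

        # Copy token ids and the referenced KV pages into the staging buffer.
        self.key_buffer[: len(key)].copy_(key)
        self.buffer_len = len(key)
        for buf_idx, idx in enumerate(value.tolist()):
            kv_data = self.mem_manager.get_index_kv_buffer(idx)
            self.mem_manager.load_index_kv_buffer(self.kv_buffer[buf_idx], kv_data)

        Thread(target=self._store_buffer, daemon=True).start()
        return True

    def _store_buffer(self):
        tokens = self.key_buffer[: self.buffer_len].tolist()
        task = self.cache_service.create(
            tokens=tokens,
            kv_page_indexer=self.kv_buffer[: self.buffer_len],
            mode="w",
        )
        while not task.ready():         # poll off the main thread; later patches
            time.sleep(0.01)            # in this series also poll with a short sleep
        with self.moving_lock:
            self.moving = False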
lightllm/server/api_cli.py | 1 + lightllm/server/api_start.py | 4 ++ lightllm/server/core/objs/start_args_type.py | 1 + .../router/dynamic_prompt/hiradix_cache.py | 42 +++++++++++++++---- lightllm/server/router/manager.py | 1 + .../model_infer/mode_backend/base_backend.py | 12 +++++- 6 files changed, 51 insertions(+), 10 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 601b2a48a..e6cff42be 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -205,6 +205,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true", help="enable hierachy prompt cache") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index de1e690a2..482cf7522 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -157,6 +157,10 @@ def normal_or_p_d_start(args): assert ( args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" + + # if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True + if args.use_hi_dynamic_prompt_cache: + assert args.use_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" # help to manage data stored on Ceph if "s3://" in args.model_dir: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 8a43d983d..ac76d27ed 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -45,6 +45,7 @@ class StartArgs: router_max_wait_tokens: int = field(default=6) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) + use_hi_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 32e16e875..67c6c162d 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -4,15 +4,20 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.utils.log_utils import init_logger from threading import Lock, Thread +from cache.ffi.pywarp import PyLocalCacheService logger = init_logger(__name__) class HiRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length, py_cache_service): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + print(f"Initializing HiRadixCache") try: - assert py_cache_service is not None - self.py_cache_service = py_cache_service + all_buffers = self.mem_manager.kv_buffer + all_buffers = 
all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) + self.py_cache_service = PyLocalCacheService( + file="cache/cache_file", storage_size=32 * (1024**3), + num_shard=32, kvcache=all_buffers, num_worker=32) self.hi_cache_buffer_len = 0 self.hi_cache_key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu") self.hi_cache_kv_buffer = self.mem_manager.alloc(max_seq_length) @@ -53,24 +58,32 @@ def insert(self, key, value=None): return self._insert_helper(self.root_node, key, value) def _store_buffer(self): + logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") assert self.moving assert self.hi_cache_buffer_len > 0 assert self.hi_cache_kv_buffer is not None key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() - write_task = self.py_cache_service.create(tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") + write_task = self.py_cache_service.create(tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len].type(torch.int64).cuda(), mode="w") while not write_task.ready(): pass + logger.info(f"HiCache: stored one kvcache with len = {self.hi_cache_buffer_len}") with self.moving_lock: self.moving = False def match_prefix(self, key, update_refs=False): assert len(key) != 0 ans_value_list = [] - available_hi_result = self.cache_controller.readable_length(key) - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) - if tree_node == self.root_node or available_hi_result > len(ans_value_list): - hi_result = self.cache_controller.read(key) - self._insert_helper(tree_node, key, hi_result) + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + use_hi_cache = self._query_hi_cache(key, len(ans_value_list)) + if use_hi_cache: + self.free_radix_cache_to_get_enough_token(len(key)) + buffers = self.mem_manager.alloc(len(key)).type(torch.int64).cuda() + read_task = self.py_cache_service.create(tokens=key, kv_page_indexer=buffers, mode="r") + while not read_task.ready(): + pass + logger.info(f"HiCache pulled one cache with len = {len(key)}") + self._insert_helper(self.root_node, key, buffers) + ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: if len(ans_value_list) != 0: @@ -81,3 +94,14 @@ def match_prefix(self, key, update_refs=False): else: self.dec_node_ref_counter(self.root_node) return None, 0, None + + def _query_hi_cache(self, key, gpu_ans_len) -> bool: + query_result = self.py_cache_service.query(key) + # query_result is a list of bool, find out the max len true continuous from start + max_len = 0 + for result in query_result: + if result: + max_len += 1 + else: + break + return max_len > gpu_ans_len \ No newline at end of file diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 45e82ff3d..d9387a1d9 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -167,6 +167,7 @@ async def wait_to_model_ready(self): "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs, "use_reward_model": self.args.use_reward_model, "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache, + "use_hi_dynamic_prompt_cache": self.args.use_hi_dynamic_prompt_cache, "data_type": self.args.data_type, "eos_id": self.eos_id, "diverse_mode": self.args.diverse_mode, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 329dc9f3b..7122a6cc2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -12,6 +12,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -53,6 +54,7 @@ def init_model(self, kvargs): self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache + self.use_hi_dynamic_prompt_cache = self.args.use_hi_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph @@ -118,7 +120,15 @@ def init_model(self, kvargs): self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.radix_cache = ( - RadixCache( + HiRadixCache( + get_unique_server_name(), + self.model.mem_manager.size, + self.rank_in_node, + mem_manager=self.model.mem_manager, + max_seq_length=kvargs.get("max_seq_length", 1024 * 5), + ) + if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache + else RadixCache( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, From 570508a632552638b4e62afa0f326560a0140a0f Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Mon, 7 Apr 2025 18:38:31 +0800 Subject: [PATCH 09/22] fix minor bug --- .../server/router/dynamic_prompt/hiradix_cache.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 67c6c162d..3dc5453ff 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -76,15 +76,20 @@ def match_prefix(self, key, update_refs=False): tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) use_hi_cache = self._query_hi_cache(key, len(ans_value_list)) if use_hi_cache: - self.free_radix_cache_to_get_enough_token(len(key)) + pull_hi_cache = True + try: + self.free_radix_cache_to_get_enough_token(len(key)) + except: + pull_hi_cache = False + if pull_hi_cache: buffers = self.mem_manager.alloc(len(key)).type(torch.int64).cuda() read_task = self.py_cache_service.create(tokens=key, kv_page_indexer=buffers, mode="r") while not read_task.ready(): pass logger.info(f"HiCache pulled one cache with len = {len(key)}") self._insert_helper(self.root_node, key, buffers) - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: if len(ans_value_list) != 0: value = torch.concat(ans_value_list) From f8bcc679f9d76286697869627e0304dccb5ef4dd Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Mon, 7 Apr 2025 19:10:59 +0800 Subject: [PATCH 10/22] fix minor bug --- .../router/dynamic_prompt/hiradix_cache.py | 25 
+++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 3dc5453ff..b85c24c19 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -5,13 +5,14 @@ from lightllm.utils.log_utils import init_logger from threading import Lock, Thread from cache.ffi.pywarp import PyLocalCacheService +import time logger = init_logger(__name__) class HiRadixCache(RadixCache): def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - print(f"Initializing HiRadixCache") + logger.info(f"Initializing HiRadixCache") try: all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) @@ -65,7 +66,7 @@ def _store_buffer(self): key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() write_task = self.py_cache_service.create(tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len].type(torch.int64).cuda(), mode="w") while not write_task.ready(): - pass + time.sleep(0.5) logger.info(f"HiCache: stored one kvcache with len = {self.hi_cache_buffer_len}") with self.moving_lock: self.moving = False @@ -74,19 +75,21 @@ def match_prefix(self, key, update_refs=False): assert len(key) != 0 ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - use_hi_cache = self._query_hi_cache(key, len(ans_value_list)) - if use_hi_cache: + max_len = self._query_hi_cache(key) + logger.info(f"Matched {len(ans_value_list)} from gpu and {max_len} from disk.") + pull_hi_cache = False + if max_len > len(ans_value_list): pull_hi_cache = True try: - self.free_radix_cache_to_get_enough_token(len(key)) + self.free_radix_cache_to_get_enough_token(max_len) except: pull_hi_cache = False if pull_hi_cache: - buffers = self.mem_manager.alloc(len(key)).type(torch.int64).cuda() - read_task = self.py_cache_service.create(tokens=key, kv_page_indexer=buffers, mode="r") + buffers = self.mem_manager.alloc(max_len).type(torch.int64).cuda() + read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): - pass - logger.info(f"HiCache pulled one cache with len = {len(key)}") + time.sleep(0.5) + logger.info(f"HiCache pulled one cache with len = {max_len}") self._insert_helper(self.root_node, key, buffers) ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) @@ -100,7 +103,7 @@ def match_prefix(self, key, update_refs=False): self.dec_node_ref_counter(self.root_node) return None, 0, None - def _query_hi_cache(self, key, gpu_ans_len) -> bool: + def _query_hi_cache(self, key) -> bool: query_result = self.py_cache_service.query(key) # query_result is a list of bool, find out the max len true continuous from start max_len = 0 @@ -109,4 +112,4 @@ def _query_hi_cache(self, key, gpu_ans_len) -> bool: max_len += 1 else: break - return max_len > gpu_ans_len \ No newline at end of file + return max_len \ No newline at end of file From a9ba57da3fe6eea0697c41b4d01a0344ec9d4df4 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Tue, 15 Apr 2025 16:56:52 +0800 Subject: [PATCH 11/22] add little fixes, still buggy --- .../router/dynamic_prompt/hiradix_cache.py | 55 
+++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index b85c24c19..a804490fc 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -3,7 +3,7 @@ from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager from lightllm.utils.log_utils import init_logger -from threading import Lock, Thread +from threading import Lock from cache.ffi.pywarp import PyLocalCacheService import time @@ -13,21 +13,27 @@ class HiRadixCache(RadixCache): def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) logger.info(f"Initializing HiRadixCache") + self.rank_in_node = rank_in_node + self.write_task = None + self.starting = True try: all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) self.py_cache_service = PyLocalCacheService( - file="cache/cache_file", storage_size=32 * (1024**3), + file="cache/cache_file", storage_size=64 * (1024**3), num_shard=32, kvcache=all_buffers, num_worker=32) self.hi_cache_buffer_len = 0 self.hi_cache_key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu") self.hi_cache_kv_buffer = self.mem_manager.alloc(max_seq_length) - self.moving = False + self.start_store_task = False self.moving_lock = Lock() except Exception as e: logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache") self.hi_cache_kv_buffer = None + # write a new function, only insert input(after prefill), call after prefill, + # then when the decode finishes, do syncronize to see whether this can be free + # no buffer, parallel insert inputs def insert(self, key, value=None): if value is None: value = key @@ -42,9 +48,16 @@ def insert(self, key, value=None): do_copy = False # and if is moving, ignore this insert request with self.moving_lock: - if not self.moving: - self.moving = True + if (not self.start_store_task) and self.write_task is not None: + if self.write_task.ready(): + logger.info(f"HiCache of [{self.rank_in_node}]: stored one kvcache with len = {self.hi_cache_buffer_len}") + self.start_store_task = True # ensure ready => start new only one kvcache stores + do_copy = True + elif self.write_task is None and self.starting: + self.starting = False + self.start_store_task = True do_copy = True + if do_copy: # copy the key and value to the hi_cache_buffer self.hi_cache_key_buffer[:len(key)].copy_(key) @@ -53,29 +66,31 @@ def insert(self, key, value=None): kv_data = self.mem_manager.get_index_kv_buffer(index) self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data) # create a new thread to store the buffer - thread = Thread(target=self._store_buffer) - thread.start() + self._store_buffer() return self._insert_helper(self.root_node, key, value) def _store_buffer(self): logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") - assert self.moving assert self.hi_cache_buffer_len > 0 assert self.hi_cache_kv_buffer is not None key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() - write_task = self.py_cache_service.create(tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len].type(torch.int64).cuda(), mode="w") - while not write_task.ready(): - time.sleep(0.5) - 
logger.info(f"HiCache: stored one kvcache with len = {self.hi_cache_buffer_len}") + self.write_task = self.py_cache_service.create( + tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len].type(torch.int64).cuda(), mode="w") with self.moving_lock: - self.moving = False + self.start_store_task = False def match_prefix(self, key, update_refs=False): + st_time = time.time() assert len(key) != 0 ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - max_len = self._query_hi_cache(key) + # add a parameter if get long enough (>50%) + first_query_time = time.time() + logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}") + max_len = self._query_hi_cache(key) # x64 + hi_cache_query_time = time.time() + logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}") logger.info(f"Matched {len(ans_value_list)} from gpu and {max_len} from disk.") pull_hi_cache = False if max_len > len(ans_value_list): @@ -85,14 +100,20 @@ def match_prefix(self, key, update_refs=False): except: pull_hi_cache = False if pull_hi_cache: - buffers = self.mem_manager.alloc(max_len).type(torch.int64).cuda() + buffers = self.mem_manager.alloc(max_len).type(torch.int64).cuda() # type change & no .cuda() read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): - time.sleep(0.5) + time.sleep(0.01) + hicache_pull_time = time.time() + logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - hi_cache_query_time}") logger.info(f"HiCache pulled one cache with len = {max_len}") - self._insert_helper(self.root_node, key, buffers) + # maybe try: add a function to only insert middle part of kv cache + self._insert_helper(self.root_node, key, buffers.cpu()) # no .cpu() + insert_time = time.time() + logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + logger.info(f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}") if tree_node != self.root_node: if len(ans_value_list) != 0: value = torch.concat(ans_value_list) From 3e9154ff41a9f4debaa0e7d56fabbcb13837e3f9 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 18 Apr 2025 12:19:59 +0800 Subject: [PATCH 12/22] fix cuda() and type --- .../server/router/dynamic_prompt/hiradix_cache.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index a804490fc..6818f6814 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -4,7 +4,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.utils.log_utils import init_logger from threading import Lock -from cache.ffi.pywarp import PyLocalCacheService +from kvcache.python.jit import PyLocalCacheService import time logger = init_logger(__name__) @@ -21,7 +21,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) self.py_cache_service = PyLocalCacheService( file="cache/cache_file", storage_size=64 * (1024**3), - num_shard=32, 
kvcache=all_buffers, num_worker=32) + num_shard=32, kvcache_tensor=all_buffers, num_worker=32) self.hi_cache_buffer_len = 0 self.hi_cache_key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu") self.hi_cache_kv_buffer = self.mem_manager.alloc(max_seq_length) @@ -76,7 +76,7 @@ def _store_buffer(self): assert self.hi_cache_kv_buffer is not None key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() self.write_task = self.py_cache_service.create( - tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len].type(torch.int64).cuda(), mode="w") + tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") with self.moving_lock: self.start_store_task = False @@ -100,7 +100,7 @@ def match_prefix(self, key, update_refs=False): except: pull_hi_cache = False if pull_hi_cache: - buffers = self.mem_manager.alloc(max_len).type(torch.int64).cuda() # type change & no .cuda() + buffers = self.mem_manager.alloc(max_len) read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): time.sleep(0.01) @@ -108,7 +108,7 @@ def match_prefix(self, key, update_refs=False): logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - hi_cache_query_time}") logger.info(f"HiCache pulled one cache with len = {max_len}") # maybe try: add a function to only insert middle part of kv cache - self._insert_helper(self.root_node, key, buffers.cpu()) # no .cpu() + self._insert_helper(self.root_node, key, buffers) insert_time = time.time() logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") ans_value_list = [] @@ -133,4 +133,4 @@ def _query_hi_cache(self, key) -> bool: max_len += 1 else: break - return max_len \ No newline at end of file + return max_len * self.py_cache_service.tokens_per_block \ No newline at end of file From b64cf2f2205691ffdfacfd83ce9876f376733b90 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 18 Apr 2025 16:20:50 +0800 Subject: [PATCH 13/22] Change to only store prefill --- .../router/dynamic_prompt/hiradix_cache.py | 113 ++++++++++-------- .../router/dynamic_prompt/radix_cache.py | 2 + .../server/router/model_infer/infer_batch.py | 3 + .../model_infer/mode_backend/base_backend.py | 7 ++ .../mode_backend/chunked_prefill/impl.py | 1 + .../mode_backend/continues_batch/impl.py | 1 + 6 files changed, 77 insertions(+), 50 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 6818f6814..451b7eb8a 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -4,6 +4,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.utils.log_utils import init_logger from threading import Lock +from enum import Enum from kvcache.python.jit import PyLocalCacheService import time @@ -14,71 +15,81 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) logger.info(f"Initializing HiRadixCache") self.rank_in_node = rank_in_node - self.write_task = None - self.starting = True try: + self.is_hi_radix_cache = True all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) self.py_cache_service = PyLocalCacheService( - file="cache/cache_file", storage_size=64 * (1024**3), + 
file="cache/cache_file", storage_size=128 * (1024**3), num_shard=32, kvcache_tensor=all_buffers, num_worker=32) - self.hi_cache_buffer_len = 0 - self.hi_cache_key_buffer = torch.empty(max_seq_length, dtype=torch.int64, device="cpu") - self.hi_cache_kv_buffer = self.mem_manager.alloc(max_seq_length) - self.start_store_task = False - self.moving_lock = Lock() + self.working_tasks = {} except Exception as e: logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache") self.hi_cache_kv_buffer = None + self.is_hi_radix_cache = False # write a new function, only insert input(after prefill), call after prefill, # then when the decode finishes, do syncronize to see whether this can be free # no buffer, parallel insert inputs - def insert(self, key, value=None): - if value is None: - value = key + def insert_disk(self, req_id, key, value): + if req_id in self.working_tasks: + self.wait_till_finish(req_id) + self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w") + logger.info(f"Created store task for req {req_id}.") + + def wait_till_finish(self, req_id): + if req_id not in self.working_tasks: + return + starting_time = time.time() + while not self.working_tasks[req_id].ready(): + time.sleep(0.01) + logger.info(f"Waited {time.time() - starting_time}s for req {req_id}.") + + # def insert(self, key, value=None): + # if value is None: + # value = key - assert len(key) == len(value) # and len(key) >= 1 - if len(key) == 0: - return 0 + # assert len(key) == len(value) # and len(key) >= 1 + # if len(key) == 0: + # return 0 - # current implement is serial, TODO: make it parallel - # if no hi_cache_buffer, work with normal radix cache - if self.hi_cache_kv_buffer is not None: - do_copy = False - # and if is moving, ignore this insert request - with self.moving_lock: - if (not self.start_store_task) and self.write_task is not None: - if self.write_task.ready(): - logger.info(f"HiCache of [{self.rank_in_node}]: stored one kvcache with len = {self.hi_cache_buffer_len}") - self.start_store_task = True # ensure ready => start new only one kvcache stores - do_copy = True - elif self.write_task is None and self.starting: - self.starting = False - self.start_store_task = True - do_copy = True + # # current implement is serial, TODO: make it parallel + # # if no hi_cache_buffer, work with normal radix cache + # if self.hi_cache_kv_buffer is not None: + # do_copy = False + # # and if is moving, ignore this insert request + # with self.moving_lock: + # if (not self.start_store_task) and self.write_task is not None: + # if self.write_task.ready(): + # logger.info(f"HiCache of [{self.rank_in_node}]: stored one kvcache with len = {self.hi_cache_buffer_len}") + # self.start_store_task = True # ensure ready => start new only one kvcache stores + # do_copy = True + # elif self.write_task is None and self.starting: + # self.starting = False + # self.start_store_task = True + # do_copy = True - if do_copy: - # copy the key and value to the hi_cache_buffer - self.hi_cache_key_buffer[:len(key)].copy_(key) - self.hi_cache_buffer_len = len(key) - for buffer_index, index in enumerate(value): - kv_data = self.mem_manager.get_index_kv_buffer(index) - self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data) - # create a new thread to store the buffer - self._store_buffer() + # if do_copy: + # # copy the key and value to the hi_cache_buffer + # self.hi_cache_key_buffer[:len(key)].copy_(key) + # self.hi_cache_buffer_len = len(key) + # for 
buffer_index, index in enumerate(value): + # kv_data = self.mem_manager.get_index_kv_buffer(index) + # self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data) + # # create a new thread to store the buffer + # self._store_buffer() - return self._insert_helper(self.root_node, key, value) + # return self._insert_helper(self.root_node, key, value) - def _store_buffer(self): - logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") - assert self.hi_cache_buffer_len > 0 - assert self.hi_cache_kv_buffer is not None - key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() - self.write_task = self.py_cache_service.create( - tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") - with self.moving_lock: - self.start_store_task = False + # def _store_buffer(self): + # logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") + # assert self.hi_cache_buffer_len > 0 + # assert self.hi_cache_kv_buffer is not None + # key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() + # self.write_task = self.py_cache_service.create( + # tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") + # with self.moving_lock: + # self.start_store_task = False def match_prefix(self, key, update_refs=False): st_time = time.time() @@ -101,11 +112,13 @@ def match_prefix(self, key, update_refs=False): pull_hi_cache = False if pull_hi_cache: buffers = self.mem_manager.alloc(max_len) + before_pull_time = time.time() + logger.info(f"HiCache of [{self.rank_in_node}]: No.2.5 Before disk pull took {before_pull_time - hi_cache_query_time}") read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): - time.sleep(0.01) + time.sleep(0.1) hicache_pull_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - hi_cache_query_time}") + logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - before_pull_time}") logger.info(f"HiCache pulled one cache with len = {max_len}") # maybe try: add a function to only insert middle part of kv cache self._insert_helper(self.root_node, key, buffers) diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index aeffd3a67..0695d5e77 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -115,6 +115,8 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 ) self.tree_total_tokens_num.arr[0] = 0 + + self.is_hi_radix_cache = False def insert(self, key, value=None): if value is None: diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..497a26cf5 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -109,6 +109,9 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis if req.shared_kv_node is not None: self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + + if self.radix_cache.is_hi_radix_cache: + self.radix_cache.wait_till_finish(req.req_id) def _save_promptcache_kvbuffer(self): """ diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 7122a6cc2..a724a6a42 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -167,6 +167,13 @@ def prefill(self, reqs: List[Tuple]): def decode(self): """This method can be overridden in subclasses.""" raise NotImplementedError() + + def store_hicache_after_prefill(self, run_reqs): + if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None: + for req in run_reqs: + key = torch.tensor(req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + self.radix_cache.insert_disk(req.req_id, key, value) def pause_reqs(self, req_ids): if self.dp_size_in_node != 1: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 00528fec7..b83b68e97 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -59,6 +59,7 @@ def decode(self): prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) model_output = self.model.forward(model_input) + self.store_hicache_after_prefill(run_reqs) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index b0eb2b58f..38f785501 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -35,6 +35,7 @@ def decode(self): ) model_output = self.model.forward(model_input) logits = model_output.logits + self.store_hicache_after_prefill(run_reqs) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True From 4f2be9ca3786206f1d5e6319522ad09138b39f74 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 18 Apr 2025 17:13:56 +0800 Subject: [PATCH 14/22] format --- lightllm/server/api_start.py | 2 +- .../router/dynamic_prompt/hiradix_cache.py | 27 ++++++++++++------- .../router/dynamic_prompt/radix_cache.py | 2 +- .../server/router/model_infer/infer_batch.py | 2 +- .../model_infer/mode_backend/base_backend.py | 2 +- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 482cf7522..bca4f23b4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -157,7 +157,7 @@ def normal_or_p_d_start(args): assert ( args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" - + # if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True if args.use_hi_dynamic_prompt_cache: assert args.use_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 451b7eb8a..fc466ce0c 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -10,6 +10,7 @@ logger = init_logger(__name__) + class 
HiRadixCache(RadixCache): def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) @@ -20,14 +21,18 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) self.py_cache_service = PyLocalCacheService( - file="cache/cache_file", storage_size=128 * (1024**3), - num_shard=32, kvcache_tensor=all_buffers, num_worker=32) + file="cache/cache_file", + storage_size=128 * (1024 ** 3), + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=32, + ) self.working_tasks = {} except Exception as e: logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache") self.hi_cache_kv_buffer = None self.is_hi_radix_cache = False - + # write a new function, only insert input(after prefill), call after prefill, # then when the decode finishes, do syncronize to see whether this can be free # no buffer, parallel insert inputs @@ -36,7 +41,7 @@ def insert_disk(self, req_id, key, value): self.wait_till_finish(req_id) self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w") logger.info(f"Created store task for req {req_id}.") - + def wait_till_finish(self, req_id): if req_id not in self.working_tasks: return @@ -52,7 +57,7 @@ def wait_till_finish(self, req_id): # assert len(key) == len(value) # and len(key) >= 1 # if len(key) == 0: # return 0 - + # # current implement is serial, TODO: make it parallel # # if no hi_cache_buffer, work with normal radix cache # if self.hi_cache_kv_buffer is not None: @@ -80,7 +85,7 @@ def wait_till_finish(self, req_id): # self._store_buffer() # return self._insert_helper(self.root_node, key, value) - + # def _store_buffer(self): # logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") # assert self.hi_cache_buffer_len > 0 @@ -99,7 +104,7 @@ def match_prefix(self, key, update_refs=False): # add a parameter if get long enough (>50%) first_query_time = time.time() logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}") - max_len = self._query_hi_cache(key) # x64 + max_len = self._query_hi_cache(key) # x64 hi_cache_query_time = time.time() logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}") logger.info(f"Matched {len(ans_value_list)} from gpu and {max_len} from disk.") @@ -113,7 +118,9 @@ def match_prefix(self, key, update_refs=False): if pull_hi_cache: buffers = self.mem_manager.alloc(max_len) before_pull_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.2.5 Before disk pull took {before_pull_time - hi_cache_query_time}") + logger.info( + f"HiCache of [{self.rank_in_node}]: No.2.5 Before disk pull took {before_pull_time - hi_cache_query_time}" + ) read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): time.sleep(0.1) @@ -136,7 +143,7 @@ def match_prefix(self, key, update_refs=False): else: self.dec_node_ref_counter(self.root_node) return None, 0, None - + def _query_hi_cache(self, key) -> bool: query_result = self.py_cache_service.query(key) # query_result is a list of bool, find out the max len true continuous from start @@ -146,4 +153,4 @@ def _query_hi_cache(self, key) -> bool: max_len += 1 else: break - return max_len * self.py_cache_service.tokens_per_block \ 
No newline at end of file + return max_len * self.py_cache_service.tokens_per_block diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 0695d5e77..45dd50099 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -115,7 +115,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 ) self.tree_total_tokens_num.arr[0] = 0 - + self.is_hi_radix_cache = False def insert(self, key, value=None): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 497a26cf5..18e524629 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -109,7 +109,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis if req.shared_kv_node is not None: self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - + if self.radix_cache.is_hi_radix_cache: self.radix_cache.wait_till_finish(req.req_id) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a724a6a42..d0da54a1d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -167,7 +167,7 @@ def prefill(self, reqs: List[Tuple]): def decode(self): """This method can be overridden in subclasses.""" raise NotImplementedError() - + def store_hicache_after_prefill(self, run_reqs): if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None: for req in run_reqs: From 67da4cb722b119337a7cee685fa86e1122761fa3 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 18 Apr 2025 17:14:50 +0800 Subject: [PATCH 15/22] format2 --- lightllm/server/router/dynamic_prompt/hiradix_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index fc466ce0c..2a62b28de 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -14,7 +14,7 @@ class HiRadixCache(RadixCache): def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - logger.info(f"Initializing HiRadixCache") + logger.info("Initializing HiRadixCache") self.rank_in_node = rank_in_node try: self.is_hi_radix_cache = True @@ -66,7 +66,7 @@ def wait_till_finish(self, req_id): # with self.moving_lock: # if (not self.start_store_task) and self.write_task is not None: # if self.write_task.ready(): - # logger.info(f"HiCache of [{self.rank_in_node}]: stored one kvcache with len = {self.hi_cache_buffer_len}") + # logger.info(f"HiCache of [{self.rank_in_node}]: stored len = {self.hi_cache_buffer_len}") # self.start_store_task = True # ensure ready => start new only one kvcache stores # do_copy = True # elif self.write_task is None and self.starting: @@ -119,7 +119,7 @@ def match_prefix(self, key, update_refs=False): buffers = self.mem_manager.alloc(max_len) before_pull_time = time.time() logger.info( - f"HiCache of [{self.rank_in_node}]: No.2.5 Before disk pull took {before_pull_time - hi_cache_query_time}" + 
f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_query_time}" ) read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): From 8f79f8149317b42c58f4f3738b30bc9756e6a38a Mon Sep 17 00:00:00 2001 From: Feiyang Wu Date: Fri, 18 Apr 2025 17:29:57 +0800 Subject: [PATCH 16/22] format3 --- test/server/test_hicache.py | 63 +++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py index b30c00560..bb82457c4 100644 --- a/test/server/test_hicache.py +++ b/test/server/test_hicache.py @@ -4,10 +4,18 @@ import random from threading import Thread, Event from queue import Queue -from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE, HiHostService, HiHostTask +from lightllm.server.router.dynamic_prompt.cache_controller import ( + HiCacheController, + CacheNode, + BLOCK_SIZE, + HiHostService, + HiHostTask, +) + class MockMemoryManager: """模拟内存管理器,仅返回连续的索引值""" + def __init__(self): self.current_idx = 0 self.kvcache_store = {} @@ -15,26 +23,28 @@ def __init__(self): def alloc(self, size): indices = list(range(self.current_idx, self.current_idx + size)) self.current_idx += size - self.store(indices, torch.tensor([[random.randint(0, 0xffff) for __ in range(512)] for _ in range(size)])) + self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)])) return indices - + def load_index_kv_buffer(self, index, load_tensor_dict): self.kvcache_store[index] = load_tensor_dict["kv_buffer"] - + def get_index_kv_buffer(self, index): return {"kv_buffer": self.kvcache_store[index]} - + def to_kvcache(self, indices): - assert all([idx in self.kvcache_store for idx in indices]), f"Not all of {indices} are not found in kvcache_store" + assert all( + [idx in self.kvcache_store for idx in indices] + ), f"Not all of {indices} are not found in kvcache_store" return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices]) - + def store(self, indices, value): print(f"[TEST:MemManager] Storing {value.shape} at {indices}") for idx, value_dim in zip(indices, range(value.shape[0])): self.kvcache_store[idx] = value[value_dim] print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}") return indices - + def free(self, indices): print(f"[TEST:MemManager] Freeing {indices}") for idx in indices: @@ -46,43 +56,45 @@ def setup(): service = HiHostService() hicache = HiCacheController(mem_manager) hicache.service = service # 注入模拟服务 - + indices = mem_manager.alloc(5) print(mem_manager.to_kvcache(indices)) - + # 预先计算单token大小 dummy_indices = mem_manager.alloc(1) kvcache = mem_manager.to_kvcache(dummy_indices[:1]) token_size = kvcache.nelement() * kvcache.element_size() print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}") - + return mem_manager, service, hicache, token_size + def test_basic_write_read(mem_manager, hicache, token_size): # 计算每个块可容纳的token数量 tokens_per_block = BLOCK_SIZE // token_size print(f"[TEST] Each block can hold {tokens_per_block} tokens") - + # 生成测试数据:刚好占满一个块 token_ids = list(range(tokens_per_block)) indices = mem_manager.alloc(len(token_ids)) kvcache = mem_manager.to_kvcache(indices) print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}") - + # 写入缓存 hicache.write(torch.tensor(token_ids), torch.tensor(indices)) time.sleep(2) - + # 等待任务完成 
hicache.service.wait_till_all_finished() - + mem_manager.free(indices) - + # 读取验证 result = hicache.read(torch.tensor(token_ids)) result = mem_manager.to_kvcache(result.tolist()) assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" - print(f"[TEST] Basic test passed. Retrieved kvcache\n\n") + print("[TEST] Basic test passed. Retrieved kvcache\n\n") + def test_node_splitting(mem_manager, hicache, token_size): tokens_per_block = BLOCK_SIZE // token_size @@ -90,22 +102,23 @@ def test_node_splitting(mem_manager, hicache, token_size): token_ids = list(range(12, 12 + tokens_per_block * 3 + 1)) indices = mem_manager.alloc(len(token_ids)) kvcache = mem_manager.to_kvcache(indices) - + hicache.write(torch.tensor(token_ids), torch.tensor(indices)) time.sleep(2) hicache.service.wait_till_all_finished() - + # 验证根节点应该有子节点 root = hicache.root assert len(root.children) > 0 print(f"\nRoot node has {len(root.children)} children") - + # 读取完整序列 result = hicache.read(torch.tensor(token_ids)) result = mem_manager.to_kvcache(result.tolist()) assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}" print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n") + def test_partial_read(mem_manager, hicache): token_ids = [97, 98, 99, 100, 101, 102] indices = mem_manager.alloc(len(token_ids)) @@ -113,13 +126,13 @@ def test_partial_read(mem_manager, hicache): hicache.write(torch.tensor(token_ids), torch.tensor(indices)) time.sleep(2) hicache.service.wait_till_all_finished() - + # 查询存在的部分前缀 result = hicache.read(torch.tensor([97, 98, 99])) result = mem_manager.to_kvcache(result.tolist()) assert result.eq(kvcache[:3]).all() - print(f"[TEST] Partial read passed") - + print("[TEST] Partial read passed") + # 查询不存在的前缀 result = hicache.read(torch.tensor([97, 98, 100])) assert len(result) == 2 @@ -127,6 +140,7 @@ def test_partial_read(mem_manager, hicache): assert result.eq(kvcache[:2]).all() print(f"[TEST] Non-existent prefix returned: {result.tolist()}") + def main(): mem_manager, service, hicache, token_size = setup() try: @@ -136,5 +150,6 @@ def main(): finally: service.shutdown() + if __name__ == "__main__": - main() \ No newline at end of file + main() From 50d0faab91e352df82cc08cd69fe15e7e1389ecd Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 9 May 2025 03:28:51 +0800 Subject: [PATCH 17/22] Support only one GPU store partially with bug --- .../router/dynamic_prompt/hiradix_cache.py | 148 +++++++++++------- .../server/router/model_infer/infer_batch.py | 2 +- .../model_infer/mode_backend/base_backend.py | 24 ++- .../mode_backend/chunked_prefill/impl.py | 1 + 4 files changed, 110 insertions(+), 65 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 2a62b28de..7f5b29661 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -17,6 +17,9 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ logger.info("Initializing HiRadixCache") self.rank_in_node = rank_in_node try: + # TODO: determine by model type && dp, tp + store_once = True # Deepseek -> True, Llama -> False + self.do_store = store_once and self.rank_in_node == 0 self.is_hi_radix_cache = True all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) @@ -37,83 +40,111 @@ def __init__(self, unique_name, 
total_token_num, rank_in_node, mem_manager, max_ # then when the decode finishes, do syncronize to see whether this can be free # no buffer, parallel insert inputs def insert_disk(self, req_id, key, value): + if not self.do_store: + return if req_id in self.working_tasks: - self.wait_till_finish(req_id) + self.abort_req_store_task(req_id) self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w") logger.info(f"Created store task for req {req_id}.") - def wait_till_finish(self, req_id): - if req_id not in self.working_tasks: + def abort_req_store_task(self, req_id): + if not self.do_store: + return + if self.working_tasks[req_id].ready(): + logger.info(f"Calling abort for req {req_id}, but is finished.") return - starting_time = time.time() - while not self.working_tasks[req_id].ready(): - time.sleep(0.01) - logger.info(f"Waited {time.time() - starting_time}s for req {req_id}.") - - # def insert(self, key, value=None): - # if value is None: - # value = key - - # assert len(key) == len(value) # and len(key) >= 1 - # if len(key) == 0: - # return 0 - - # # current implement is serial, TODO: make it parallel - # # if no hi_cache_buffer, work with normal radix cache - # if self.hi_cache_kv_buffer is not None: - # do_copy = False - # # and if is moving, ignore this insert request - # with self.moving_lock: - # if (not self.start_store_task) and self.write_task is not None: - # if self.write_task.ready(): - # logger.info(f"HiCache of [{self.rank_in_node}]: stored len = {self.hi_cache_buffer_len}") - # self.start_store_task = True # ensure ready => start new only one kvcache stores - # do_copy = True - # elif self.write_task is None and self.starting: - # self.starting = False - # self.start_store_task = True - # do_copy = True - - # if do_copy: - # # copy the key and value to the hi_cache_buffer - # self.hi_cache_key_buffer[:len(key)].copy_(key) - # self.hi_cache_buffer_len = len(key) - # for buffer_index, index in enumerate(value): - # kv_data = self.mem_manager.get_index_kv_buffer(index) - # self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data) - # # create a new thread to store the buffer - # self._store_buffer() - - # return self._insert_helper(self.root_node, key, value) - - # def _store_buffer(self): - # logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}") - # assert self.hi_cache_buffer_len > 0 - # assert self.hi_cache_kv_buffer is not None - # key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist() - # self.write_task = self.py_cache_service.create( - # tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w") - # with self.moving_lock: - # self.start_store_task = False + logger.info(f"Aborting req {req_id} unfinished.") + self.py_cache_service.az5(self.working_tasks[req_id]) + + # TODO: finish this function to only update new ones + def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + try: + if len(key) == 0: + return node + + first_key_id = key[0].item() + if first_key_id in node.children.keys(): + child: TreeNode = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + 
child.update_time() + ans_value_list.append(child.token_mem_index_value) + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + ans_value_list.append(new_node.token_mem_index_value) + if update_refs: + new_node.ref_counter += 1 + if new_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return new_node + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) def match_prefix(self, key, update_refs=False): st_time = time.time() assert len(key) != 0 ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) # add a parameter if get long enough (>50%) first_query_time = time.time() logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}") max_len = self._query_hi_cache(key) # x64 hi_cache_query_time = time.time() logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}") - logger.info(f"Matched {len(ans_value_list)} from gpu and {max_len} from disk.") + logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.") pull_hi_cache = False - if max_len > len(ans_value_list): + if max_len > sum(len(s) for s in ans_value_list): pull_hi_cache = True try: self.free_radix_cache_to_get_enough_token(max_len) except: + if update_refs: + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) pull_hi_cache = False if pull_hi_cache: buffers = self.mem_manager.alloc(max_len) @@ -133,7 +164,10 @@ def match_prefix(self, key, update_refs=False): logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - logger.info(f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}") + logger.info( + f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}" + + f" matched {sum(len(s) for s in ans_value_list)} tokens" + ) if tree_node != self.root_node: if len(ans_value_list) != 0: value = torch.concat(ans_value_list) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py 
index 18e524629..d866cc1b7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -111,7 +111,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis req.shared_kv_node = None if self.radix_cache.is_hi_radix_cache: - self.radix_cache.wait_till_finish(req.req_id) + self.radix_cache.abort_req_store_task(req.req_id) def _save_promptcache_kvbuffer(self): """ diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index d0da54a1d..793d246e1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -168,13 +168,6 @@ def decode(self): """This method can be overridden in subclasses.""" raise NotImplementedError() - def store_hicache_after_prefill(self, run_reqs): - if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None: - for req in run_reqs: - key = torch.tensor(req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - self.radix_cache.insert_disk(req.req_id, key, value) - def pause_reqs(self, req_ids): if self.dp_size_in_node != 1: req_ids = [req_id for req_id in req_ids if req_id in g_infer_context.requests_mapping] @@ -371,6 +364,23 @@ def _overlap_req_init_and_filter( return + def _overlap_store_prefill_reqs(self, run_reqs: List[InferReq]): + if run_reqs: + with torch.cuda.stream(g_infer_context.get_overlap_stream()): + if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None: + for req in run_reqs: + if req.cur_output_len > 1: + continue + key = torch.tensor( + req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu" + ) + value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + self.radix_cache.insert_disk(req.req_id, key, value) + + torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) + + return + # 一些可以复用的通用功能函数 def _post_init_reqs(self, uninit_reqs: List[InferReq]): """ diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index b83b68e97..0edb3a173 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -43,6 +43,7 @@ def decode(self): self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) + self._overlap_store_prefill_reqs(run_reqs=run_reqs) next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() From 9e6274c6e08366261596c41f0fd80b0e1cf398a2 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Fri, 9 May 2025 19:00:59 +0800 Subject: [PATCH 18/22] fixed problems --- .../router/dynamic_prompt/hiradix_cache.py | 88 ++++++++++++------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index 7f5b29661..be2ce1608 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -1,4 
+1,5 @@ import torch +import torch.distributed as dist from .radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager @@ -23,12 +24,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ self.is_hi_radix_cache = True all_buffers = self.mem_manager.kv_buffer all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) - self.py_cache_service = PyLocalCacheService( - file="cache/cache_file", - storage_size=128 * (1024 ** 3), - num_shard=32, - kvcache_tensor=all_buffers, - num_worker=32, + self.py_cache_service = ( + PyLocalCacheService( + file="cache/cache_file", + storage_size=128 * (1024 ** 3), + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=32, + ) + if self.do_store + else None ) self.working_tasks = {} except Exception as e: @@ -48,7 +53,7 @@ def insert_disk(self, req_id, key, value): logger.info(f"Created store task for req {req_id}.") def abort_req_store_task(self, req_id): - if not self.do_store: + if not self.do_store or req_id not in self.working_tasks: return if self.working_tasks[req_id].ready(): logger.info(f"Calling abort for req {req_id}, but is finished.") @@ -126,48 +131,63 @@ def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, upd self.evict_tree_set.add(node) def match_prefix(self, key, update_refs=False): - st_time = time.time() assert len(key) != 0 ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) - # add a parameter if get long enough (>50%) - first_query_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}") - max_len = self._query_hi_cache(key) # x64 - hi_cache_query_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}") - logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.") + pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node) + if self.do_store: + # st_time = time.time() + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + # add a parameter if get long enough (>50%) + # first_query_time = time.time() + # logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s") + max_len = self._query_hi_cache(key) # x64 + # hi_cache_q_time = time.time() + # logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s") + logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.") + pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0 + # hi_cache_q_time = time.time() + dist.broadcast(pull_hi_cache_tensor, src=0) + # logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}") pull_hi_cache = False - if max_len > sum(len(s) for s in ans_value_list): + # logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}") + + if pull_hi_cache_tensor[0] == 0 and not self.do_store: + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + elif pull_hi_cache_tensor[0] > 0: pull_hi_cache = True + max_len = pull_hi_cache_tensor[0] try: self.free_radix_cache_to_get_enough_token(max_len) except: - if update_refs: - tree_node = 
self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + logger.info(f"Unable to free on rank {self.rank_in_node}") + pull_hi_cache_tensor[0] = 0 pull_hi_cache = False + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if pull_hi_cache: buffers = self.mem_manager.alloc(max_len) - before_pull_time = time.time() - logger.info( - f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_query_time}" - ) - read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") - while not read_task.ready(): - time.sleep(0.1) - hicache_pull_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - before_pull_time}") + # before_pull_time = time.time() + # logger.info( + # f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}" + # ) + if self.do_store: + read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") + while not read_task.ready(): + time.sleep(0.05) + dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0) + # hicache_pull_time = time.time() + # logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s") logger.info(f"HiCache pulled one cache with len = {max_len}") # maybe try: add a function to only insert middle part of kv cache self._insert_helper(self.root_node, key, buffers) - insert_time = time.time() - logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") + # insert_time = time.time() + # logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - logger.info( - f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}" - + f" matched {sum(len(s) for s in ans_value_list)} tokens" - ) + # logger.info( + # f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}" + # + f" matched {sum(len(s) for s in ans_value_list)} tokens" + # ) if tree_node != self.root_node: if len(ans_value_list) != 0: value = torch.concat(ans_value_list) From add815722a07fb284eda7166bbb973ca68fd9276 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Wed, 21 May 2025 17:53:42 +0800 Subject: [PATCH 19/22] delete useless codes --- .../router/dynamic_prompt/hiradix_cache.py | 94 ------------------- 1 file changed, 94 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index be2ce1608..d716d9634 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -41,9 +41,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_ self.hi_cache_kv_buffer = None self.is_hi_radix_cache = False - # write a new function, only insert input(after prefill), call after prefill, - # then when the decode finishes, do syncronize to see whether this can be free - # no buffer, parallel insert inputs def insert_disk(self, req_id, key, value): if not self.do_store: return @@ -61,95 +58,17 @@ def abort_req_store_task(self, req_id): logger.info(f"Aborting req {req_id} unfinished.") 
self.py_cache_service.az5(self.working_tasks[req_id]) - # TODO: finish this function to only update new ones - def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False): - if node.is_leaf(): - self.evict_tree_set.discard(node) - - if update_refs: - node.ref_counter += 1 - # from 0 to 1 need update refs token num - if node.ref_counter == 1: - self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) - - try: - if len(key) == 0: - return node - - first_key_id = key[0].item() - if first_key_id in node.children.keys(): - child: TreeNode = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) - if prefix_len == len(key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - child.update_time() - ans_value_list.append(child.token_mem_index_value) - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len - - elif prefix_len < len(key) and prefix_len < len(child.token_id_key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - - key = key[prefix_len:] - value = value[prefix_len:] - split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) - - if split_parent_node.is_leaf(): - self.evict_tree_set.add(split_parent_node) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - - if child.is_leaf(): - self.evict_tree_set.add(child) - return prefix_len - elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) - else: - assert False, "can not run to here" - - else: - new_node = node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) - ans_value_list.append(new_node.token_mem_index_value) - if update_refs: - new_node.ref_counter += 1 - if new_node.ref_counter == 1: - self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - return new_node - finally: - node.update_time() - if node.is_leaf(): - self.evict_tree_set.add(node) - def match_prefix(self, key, update_refs=False): assert len(key) != 0 ans_value_list = [] pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node) if self.do_store: - # st_time = time.time() tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) - # add a parameter if get long enough (>50%) - # first_query_time = time.time() - # logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s") max_len = self._query_hi_cache(key) # x64 - # hi_cache_q_time = time.time() - # logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s") logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.") pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0 - # hi_cache_q_time = time.time() dist.broadcast(pull_hi_cache_tensor, src=0) - # logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}") pull_hi_cache = False - # logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}") if pull_hi_cache_tensor[0] == 0 and not self.do_store: tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, 
update_refs=False) @@ -166,28 +85,15 @@ def match_prefix(self, key, update_refs=False): tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if pull_hi_cache: buffers = self.mem_manager.alloc(max_len) - # before_pull_time = time.time() - # logger.info( - # f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}" - # ) if self.do_store: read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") while not read_task.ready(): time.sleep(0.05) dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0) - # hicache_pull_time = time.time() - # logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s") logger.info(f"HiCache pulled one cache with len = {max_len}") - # maybe try: add a function to only insert middle part of kv cache self._insert_helper(self.root_node, key, buffers) - # insert_time = time.time() - # logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}") ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - # logger.info( - # f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}" - # + f" matched {sum(len(s) for s in ans_value_list)} tokens" - # ) if tree_node != self.root_node: if len(ans_value_list) != 0: value = torch.concat(ans_value_list) From 8fd842c8363b46d7ffe659903037bb8cf0dbe242 Mon Sep 17 00:00:00 2001 From: wufeiyang Date: Thu, 22 May 2025 10:15:08 +0800 Subject: [PATCH 20/22] little bug fix --- lightllm/server/router/dynamic_prompt/hiradix_cache.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index d716d9634..ea2a9deeb 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -70,8 +70,9 @@ def match_prefix(self, key, update_refs=False): dist.broadcast(pull_hi_cache_tensor, src=0) pull_hi_cache = False - if pull_hi_cache_tensor[0] == 0 and not self.do_store: - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + if pull_hi_cache_tensor[0] == 0: + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) elif pull_hi_cache_tensor[0] > 0: pull_hi_cache = True max_len = pull_hi_cache_tensor[0] From 39ca1281f3a42c4f16521bd764e98f94e92dc07b Mon Sep 17 00:00:00 2001 From: Feiyang Wu <42094053+jayfeather9@users.noreply.github.com> Date: Fri, 23 May 2025 18:51:11 +0800 Subject: [PATCH 21/22] add cli param fix --- lightllm/server/api_cli.py | 2 +- lightllm/server/api_start.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index e6cff42be..3e37142b1 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -312,7 +312,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" ) parser.add_argument( - "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2" + "--visual_gpu_ids", nargs="+", type=int, default=[0, 1, 2, 3, 4, 5, 6, 7], help="List of GPU IDs to use, e.g., 0 1 2" ) 
parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index bca4f23b4..3a4093482 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -160,7 +160,7 @@ def normal_or_p_d_start(args): # if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True if args.use_hi_dynamic_prompt_cache: - assert args.use_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" + assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache" # help to manage data stored on Ceph if "s3://" in args.model_dir: From 22205442dd1b0189edeb448a4b9f55dbda045b98 Mon Sep 17 00:00:00 2001 From: yujinbiao Date: Fri, 6 Jun 2025 20:22:02 +0800 Subject: [PATCH 22/22] refactor hicache --- .../router/dynamic_prompt/hiradix_cache.py | 197 +++++++++--------- .../server/router/model_infer/infer_batch.py | 2 - .../model_infer/mode_backend/base_backend.py | 20 +- .../mode_backend/chunked_prefill/impl.py | 1 - 4 files changed, 105 insertions(+), 115 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix_cache.py index ea2a9deeb..31a306f67 100644 --- a/lightllm/server/router/dynamic_prompt/hiradix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hiradix_cache.py @@ -1,117 +1,128 @@ import torch +import time +import tempfile +import numpy as np import torch.distributed as dist +from os.path import join from .radix_cache import RadixCache, TreeNode, match from typing import Tuple, Dict, Set, List from lightllm.common.mem_manager import MemoryManager from lightllm.utils.log_utils import init_logger from threading import Lock from enum import Enum +from .shared_arr import SharedArray from kvcache.python.jit import PyLocalCacheService -import time logger = init_logger(__name__) +def wait_until_ready(task, timeout=10.0, check_interval=0.01): + start_time = time.time() + while not task.ready(): + time.sleep(check_interval) + if time.time() - start_time > timeout: + logger.error("Current kv cache task not ready in time") + return False + return True -class HiRadixCache(RadixCache): - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_seq_length): - super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - logger.info("Initializing HiRadixCache") - self.rank_in_node = rank_in_node - try: - # TODO: determine by model type && dp, tp - store_once = True # Deepseek -> True, Llama -> False - self.do_store = store_once and self.rank_in_node == 0 - self.is_hi_radix_cache = True - all_buffers = self.mem_manager.kv_buffer - all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) - self.py_cache_service = ( - PyLocalCacheService( - file="cache/cache_file", - storage_size=128 * (1024 ** 3), - num_shard=32, - kvcache_tensor=all_buffers, - num_worker=32, - ) - if self.do_store - else None - ) - self.working_tasks = {} - except Exception as e: - logger.error(f"error alloc hi cache buffer {e}, fallback to normal radix cache") - self.hi_cache_kv_buffer = None - self.is_hi_radix_cache = False +class LocalCacheManager: - def insert_disk(self, req_id, key, value): - if not self.do_store: - return - if req_id in self.working_tasks: - 
self.abort_req_store_task(req_id) - self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w") - logger.info(f"Created store task for req {req_id}.") + def __init__(self, unique_name: str, rank_in_node: int, mem_manager): + tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") + self.cache_file = join(tmp_dir, "cache_file") + all_buffers = mem_manager.kv_buffer + all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) - def abort_req_store_task(self, req_id): - if not self.do_store or req_id not in self.working_tasks: - return - if self.working_tasks[req_id].ready(): - logger.info(f"Calling abort for req {req_id}, but is finished.") - return - logger.info(f"Aborting req {req_id} unfinished.") - self.py_cache_service.az5(self.working_tasks[req_id]) + self.py_cache_service = PyLocalCacheService( + file=self.cache_file, + storage_size=128 * (1024 ** 3), # 128GB + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=8 + ) - def match_prefix(self, key, update_refs=False): - assert len(key) != 0 - ans_value_list = [] - pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node) - if self.do_store: - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) - max_len = self._query_hi_cache(key) # x64 - logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.") - pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0 - dist.broadcast(pull_hi_cache_tensor, src=0) - pull_hi_cache = False + def insert(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="w", + start_pos=start_pos) + res = wait_until_ready(t) + if not res: + self.py_cache_service.az5(t) - if pull_hi_cache_tensor[0] == 0: - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - elif pull_hi_cache_tensor[0] > 0: - pull_hi_cache = True - max_len = pull_hi_cache_tensor[0] - try: - self.free_radix_cache_to_get_enough_token(max_len) - except: - logger.info(f"Unable to free on rank {self.rank_in_node}") - pull_hi_cache_tensor[0] = 0 - pull_hi_cache = False - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - if pull_hi_cache: - buffers = self.mem_manager.alloc(max_len) - if self.do_store: - read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r") - while not read_task.ready(): - time.sleep(0.05) - dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0) - logger.info(f"HiCache pulled one cache with len = {max_len}") - self._insert_helper(self.root_node, key, buffers) - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - if tree_node != self.root_node: - if len(ans_value_list) != 0: - value = torch.concat(ans_value_list) - else: - assert False, "can not run to here" - return tree_node, len(value), value - else: - self.dec_node_ref_counter(self.root_node) - return None, 0, None + def read(self, tokens, kv_page_indexer, start_pos=0): + t = self.py_cache_service.create( + tokens=tokens, + kv_page_indexer=kv_page_indexer, + mode="r", + start_pos=start_pos) + res = wait_until_ready(t) + return res - def _query_hi_cache(self, key) -> bool: - query_result = 
self.py_cache_service.query(key) - # query_result is a list of bool, find out the max len true continuous from start + def query(self, tokens): + query_result = self.py_cache_service.query(tokens) max_len = 0 for result in query_result: if result: max_len += 1 else: break - return max_len * self.py_cache_service.tokens_per_block + return max_len * self.block_size + + @property + def block_size(self,): + return self.py_cache_service.tokens_per_block + +class HiRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager): + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.rank_in_node = rank_in_node + self.local_cache_manager = LocalCacheManager( + unique_name, + rank_in_node, + mem_manager, + ) + self.is_hi_radix_cache = True + self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.disk_cache_match_count.arr[0] = 0 + self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.total_match_count.arr[0] = 0 + self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32) + self.disk_cache_match_ratio.arr[0] = 0.0 + logger.info(f"Initializing HiRadixCache {rank_in_node}") + + def insert(self, key, value=None): + share_len = super().insert(key, value) + if share_len == 0: + return 0 + self.local_cache_manager.insert(key, value) + return share_len + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + self.total_match_count.arr[0] += 1 + ans_value_list = [] + ans_value = None + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + if tree_node.node_prefix_total_len != 0: + ans_value = torch.concat(ans_value_list) + max_len = 0 + if tree_node.node_prefix_total_len < len(key): + max_len = self.local_cache_manager.query(key) + if max_len > tree_node.node_prefix_total_len: + pull_len = max_len - tree_node.node_prefix_total_len + self.disk_cache_match_count.arr[0] += 1 + self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0] + self.free_radix_cache_to_get_enough_token(pull_len) + buffers = self.mem_manager.alloc(pull_len) + start_pos = 0 + if ans_value is not None: + buffers = torch.concat([ans_value, buffers]) + start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size + logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk") + res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos) + if res: + super().insert(key[:max_len], buffers) + else: + self.mem_manager.free(buffers[tree_node.node_prefix_total_len:]) + return super().match_prefix(key, update_refs=update_refs) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index d866cc1b7..0774244b6 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -110,8 +110,6 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - if self.radix_cache.is_hi_radix_cache: - self.radix_cache.abort_req_store_task(req.req_id) def _save_promptcache_kvbuffer(self): """ diff --git 
a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 793d246e1..50b4aa547 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -124,8 +124,7 @@ def init_model(self, kvargs): get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, - max_seq_length=kvargs.get("max_seq_length", 1024 * 5), + mem_manager=self.model.mem_manager ) if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache else RadixCache( @@ -364,23 +363,6 @@ def _overlap_req_init_and_filter( return - def _overlap_store_prefill_reqs(self, run_reqs: List[InferReq]): - if run_reqs: - with torch.cuda.stream(g_infer_context.get_overlap_stream()): - if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None: - for req in run_reqs: - if req.cur_output_len > 1: - continue - key = torch.tensor( - req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu" - ) - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - self.radix_cache.insert_disk(req.req_id, key, value) - - torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) - - return - # 一些可以复用的通用功能函数 def _post_init_reqs(self, uninit_reqs: List[InferReq]): """ diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 0edb3a173..b83b68e97 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -43,7 +43,6 @@ def decode(self): self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) - self._overlap_store_prefill_reqs(run_reqs=run_reqs) next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
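A minimal, self-contained sketch of the block-aligned disk-pull planning used by the refactored match_prefix in hiradix_cache.py above. The helper name plan_disk_pull and the example numbers are illustrative only; the start_pos / pull_len arithmetic is the part taken from the patch (node_prefix_total_len from the GPU radix tree, the disk-side match length from LocalCacheManager.query, and PyLocalCacheService.tokens_per_block as the block size).

# Sketch (not part of the patch): plan how much KV cache to pull from disk
# and where the read must start so it lands on a cache-block boundary.

def plan_disk_pull(gpu_prefix_len: int, disk_match_len: int, block_size: int):
    """Return (pull_len, start_pos), or None when the disk offers nothing extra.

    gpu_prefix_len: tokens already matched in the GPU radix tree
                    (tree_node.node_prefix_total_len in the patch)
    disk_match_len: tokens matched on disk (LocalCacheManager.query result)
    block_size:     PyLocalCacheService.tokens_per_block
    """
    if disk_match_len <= gpu_prefix_len:
        return None  # GPU cache already covers everything the disk has
    pull_len = disk_match_len - gpu_prefix_len  # newly allocated KV slots
    # The disk read has to begin on a block boundary, so back up to the block
    # that contains the last GPU-resident token; start_pos stays 0 when no
    # prefix is resident yet (same branch structure as the patch).
    start_pos = 0
    if gpu_prefix_len > 0:
        start_pos = (gpu_prefix_len - 1) // block_size * block_size
    return pull_len, start_pos


if __name__ == "__main__":
    # 1000 tokens already on GPU, 4096 matched on disk, 64-token blocks:
    # allocate 3096 new slots and start the disk read at token 960.
    print(plan_disk_pull(1000, 4096, 64))   # (3096, 960)
    print(plan_disk_pull(0, 4096, 64))      # (4096, 0)
    print(plan_disk_pull(4096, 4096, 64))   # None

In the patch itself the allocated buffers are concatenated behind the already-matched indices before LocalCacheManager.read is called with this start_pos, and the freshly allocated tail is freed again if the read does not become ready within the wait_until_ready timeout.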