diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index 021de6843..e45cc11c7 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
 from typing import Tuple, Any, Optional
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
+from .triton_kernel.multimodal_emb import mark_multimodal_obj
 
 
 class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
             if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
                 attr_.copy_(attr_value, non_blocking=True)
         return
+
+    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
+        """
+        Helper that marks which multimodal objects have tokens taking part in this chunked-prefill step.
+        Because the prompt is split into chunks, not every multimodal object's tokens join the computation.
+        """
+        multi_objs = []
+        for _, p in enumerate(self.multimodal_params):
+            for obj in p["images"] + p["audios"]:
+                multi_objs.append(obj)
+
+        if multi_objs:
+            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
+            marks = mark_multimodal_obj(
+                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+            )
+            marks_array = marks.detach().cpu().numpy()
+            for mark, obj in zip(marks_array, multi_objs):
+                obj["_prefill_"] = mark > 0
+        return
diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
index 8b66827a5..64d45e0dc 100644
--- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
+++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -5,48 +5,78 @@
 @triton.jit
 def _fwd_kernel(
-    Prompt_ids,
+    Prompt_ids,
     Text_weight_embs,
     Img_embs,
     Out,
     Img_token_lens,
     Img_start_token_ids,
     Img_start_locs,
-    stride_text_emb_s, stride_text_emb_d,  # text_stride
-    stride_img_emb_s, stride_img_emb_d,  # img_stride
-    stride_out_s, stride_out_d,
+    stride_text_emb_s,
+    stride_text_emb_d,  # text_stride
+    stride_img_emb_s,
+    stride_img_emb_d,  # img_stride
+    stride_out_s,
+    stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
-    BLOCK_HIDDEN_DIM: tl.constexpr
-    ):
+    BLOCK_HIDDEN_DIM: tl.constexpr,
+):
     seq_index = tl.program_id(0).to(tl.int64)
     img_handle_id = tl.program_id(1)
     token_id = tl.load(Prompt_ids + seq_index)
     off_d = tl.arange(0, BLOCK_HIDDEN_DIM)
-
+
     # load store text emb
-    for _ in range(0, tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0), 1):
-        load_emb = tl.load(Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d, mask=off_d < hidden_size, other=0)
+    for _ in range(
+        0,
+        tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
+        1,
+    ):
+        load_emb = tl.load(
+            Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
+            mask=off_d < hidden_size,
+            other=0,
+        )
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
-
+
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1,
other=0) # load store img emb - for _ in range(0, tl.where((img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), 1, 0), 1): - load_emb = tl.load(Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, mask=off_d < hidden_size, other=0) + for _ in range( + 0, + tl.where( + (img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len), + 1, + 0, + ), + 1, + ): + load_emb = tl.load( + Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d, + mask=off_d < hidden_size, + other=0, + ) tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size) return @torch.no_grad() -def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs: torch.Tensor, img_embs: torch.Tensor, - img_token_lens: torch.Tensor, img_start_token_ids: torch.Tensor, img_start_locs: torch.Tensor, - tp_text_start_token_id, - tp_text_end_token_id): +def multimodal_emb( + out: torch.Tensor, + prompt_ids: torch.Tensor, + text_weight_embs: torch.Tensor, + img_embs: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs: torch.Tensor, + tp_text_start_token_id, + tp_text_end_token_id, +): total_len = prompt_ids.shape[0] BLOCK = triton.next_power_of_2(out.shape[1]) # print(len(img_token_lens)) @@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs img_token_lens, img_start_token_ids, img_start_locs, - text_weight_embs.stride(0), text_weight_embs.stride(1), - img_embs.stride(0), img_embs.stride(1), - out.stride(0), out.stride(1), + text_weight_embs.stride(0), + text_weight_embs.stride(1), + img_embs.stride(0), + img_embs.stride(1), + out.stride(0), + out.stride(1), tp_text_start_token_id, tp_text_end_token_id, hidden_size=out.shape[1], @@ -73,6 +106,48 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs return +@triton.jit +def _mark_multimodal_obj_need_kernel( + obj_start_token_ids_ptr, + obj_token_lens_ptr, + obj_marks_ptr, + input_ids_ptr, + input_size, + BLOCK_SIZE: tl.constexpr, +): + + obj_index = tl.program_id(0) + start_id = tl.load(obj_start_token_ids_ptr + obj_index) + token_len = tl.load(obj_token_lens_ptr + obj_index) + + for block_start in range(0, input_size, BLOCK_SIZE): + block_range = block_start + tl.arange(0, BLOCK_SIZE) + cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0) + mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0) + mark = tl.sum(mark) + tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0) + return + + +@torch.no_grad() +def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor): + out_mark = torch.empty_like(obj_start_token_ids) + out_mark.fill_(0) + assert obj_start_token_ids.shape == obj_token_lens.shape + BLOCK = 512 + grid = (obj_start_token_ids.shape[0],) + _mark_multimodal_obj_need_kernel[grid]( + obj_start_token_ids_ptr=obj_start_token_ids, + obj_token_lens_ptr=obj_token_lens, + obj_marks_ptr=out_mark, + input_ids_ptr=input_ids, + input_size=input_ids.shape[0], + BLOCK_SIZE=BLOCK, + num_warps=1, + num_stages=1, + ) + return out_mark + def test(): S, D = 1024 * 1000, 128 * 64 @@ -80,27 +155,35 @@ def test(): image_size = 10 image_token_size = 512 - text_weight = torch.randn((vob_size, D), 
device='cuda', dtype=torch.float16) - img_weight = torch.randn((image_size * image_token_size, D), device='cuda', dtype=torch.float16) - img_token_lens = torch.full((image_size,), image_token_size, device='cuda', dtype=torch.long) - img_start_token_ids = (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long() + text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16) + img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16) + img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long) + img_start_token_ids = ( + (torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long() + ) img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long() prompt_ids = torch.arange(0, S, 1).cuda().long() - prompt_ids[0: image_size * image_token_size] = (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long() + prompt_ids[0 : image_size * image_token_size] = ( + (vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long() + ) out = torch.zeros((S, D), dtype=torch.float16, device="cuda") print(out.shape) import time - - triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size) + + multimodal_emb( + out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size + ) torch.cuda.synchronize() iters = 20 t1 = time.time() for _ in range(iters): - triton_output = multimodal_emb(out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size) + multimodal_emb( + out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size + ) torch.cuda.synchronize() t2 = time.time() print("Triton time cost", (t2 - t1) / iters) @@ -109,4 +192,3 @@ def test(): # if __name__ == "__main__": # test() - diff --git a/lightllm/common/image_cache_manager.py b/lightllm/common/image_cache_manager.py new file mode 100644 index 000000000..fb04e4b59 --- /dev/null +++ b/lightllm/common/image_cache_manager.py @@ -0,0 +1,76 @@ +from collections import OrderedDict +from lightllm.utils.dist_utils import get_current_device_id + + +class ImageCacheManager: + def __init__(self): + """ + Initialize the image cache manager with a simple GPU cache and an LRU CPU cache. + """ + self._gpu_cache = dict() + self._cpu_cache = OrderedDict() + + def set_max_size(self, max_size: int): + """ + Set the maximum number of items to keep in the CPU cache. + :param max_size: Maximum number of items to keep in the CPU cache. + """ + if max_size <= 0: + raise ValueError("max_size must be greater than 0") + self._max_size = max_size + + def set_embed(self, uuid, embed): + """ + Store the embedding for the given uuid in the GPU cache. + :param uuid: Unique identifier for the image + :param embed: Embedding vector for the image (on GPU) + """ + self._gpu_cache[uuid] = embed + + def get_embed(self, uuid): + """ + Retrieve the embedding for the given uuid. Prefer GPU cache, + otherwise return CPU cache and move to GPU (simulate .cuda()). 
+        :param uuid: Unique identifier for the image
+        :return: Embedding vector (on GPU if possible, else moved from CPU to GPU)
+        """
+        if uuid in self._gpu_cache:
+            return self._gpu_cache[uuid]
+        elif uuid in self._cpu_cache:
+            self._cpu_cache.move_to_end(uuid)
+            embed = self._cpu_cache[uuid].cuda(get_current_device_id())
+            return embed
+        return None
+
+    def query_embed(self, uuid):
+        """
+        Query if the embedding for the given uuid is in the cache.
+        :param uuid: Unique identifier for the image
+        :return: True if the embedding is in the cache, False otherwise
+        """
+        return uuid in self._gpu_cache or uuid in self._cpu_cache
+
+    def filter(self, uuid_list):
+        """
+        Given a list of uuids, demote their embeddings from the GPU cache to the CPU LRU cache
+        if present; entries already in the CPU cache are refreshed as most recently used, and the
+        least recently used entries are evicted once the CPU cache exceeds max_size.
+        :param uuid_list: List of uuids
+        """
+        for uuid in uuid_list:
+            if uuid in self._gpu_cache:
+                embed_cpu = self._gpu_cache[uuid].cpu()
+                # Move to CPU cache and remove from GPU cache
+                self._gpu_cache.pop(uuid)
+                if uuid in self._cpu_cache:
+                    self._cpu_cache.move_to_end(uuid)
+                self._cpu_cache[uuid] = embed_cpu
+                if len(self._cpu_cache) > self._max_size:
+                    self._cpu_cache.popitem(last=False)
+            elif uuid in self._cpu_cache:
+                self._cpu_cache.move_to_end(uuid)
+        return
+
+
+image_cache_manager = ImageCacheManager()
diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
index 89c8e0d8d..46b782879 100644
--- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
             else:
                 weight_mask[idx] = scale
 
+        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids:
+                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                     continue
                 # pull the img_embeds by uid from shm
                 data = read_shm(get_shm_name_embed(img["uuid"]))
diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py
index f19563932..486319495 100644
--- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py
@@ -3,6 +3,9 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
 
+from lightllm.models.vit.model import VisionTransformer
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.common.image_cache_manager import image_cache_manager
 
 
 # add key: language_model.xxx -> xxx
@@ -15,9 +18,45 @@ def rename_weight_keys(weights):
             weights[k[len(prefix) :]] = weights[k]
 
 
+class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # if we don't assign an extra process for visual model, we need initialize the image cache manager here
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            kvargs = {
+                "weight_dir": get_env_start_args().model_dir,
+                "data_type": self.data_type_,
+                "quant_type":
get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + + class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -29,6 +68,19 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): @@ -40,6 +92,19 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + # if we don't assign an extra process for visual model, we need initialize the image cache manager here + if get_env_start_args().disable_extra_process_for_multimodal: + kvargs = { + "weight_dir": get_env_start_args().model_dir, + "data_type": self.data_type_, + "quant_type": get_env_start_args().vit_quant_type, + "quant_cfg": get_env_start_args().vit_quant_cfg, + "max_batch_size": get_env_start_args().visual_infer_batch_size, + } + self.visual_model = VisionTransformer( + kvargs=kvargs, + ) + image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 60c9e0564..f4ac4c326 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,9 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time +from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.common.image_cache_manager import image_cache_manager 
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce @@ -29,8 +31,24 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal return + def _infer_image_embeds(self, infer_state, layer_weight): + if not self.disable_extra_process_for_multimodal: + return + infer_images = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if (img["_prefill_"] is True) and (not image_cache_manager.query_embed(img["uuid"])): + infer_images.append(img) + if len(infer_images) > 0: + infer_batch_size = get_env_start_args().visual_infer_batch_size + for i in range(0, len(infer_images), infer_batch_size): + img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images[i : i + infer_batch_size]) + for uuid, valid_id in zip(uuids, valid_ids): + image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]]) + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_weight = [] @@ -42,14 +60,22 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei device = layer_weight.wte_weight_.device dtype = layer_weight.wte_weight_.dtype hidden_size = layer_weight.wte_weight_.shape[1] + + infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) + + self._infer_image_embeds(infer_state, layer_weight) for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image - if img["token_id"] in img_start_token_ids: + if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + if self.disable_extra_process_for_multimodal: + img_embed = image_cache_manager.get_embed(img["uuid"]) + img_weight.append(img_embed.reshape(img["token_num"], -1)) + else: + data = read_shm(get_shm_name_embed(img["uuid"])) + img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 55d73fa73..3c42f712e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,7 +11,9 @@ MultiROWMMWeight, TpNormWeight, ) -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import ( + get_current_device_id, +) class ViTTransformerLayerWeight(TransformerLayerWeight): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..a8b475889 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -18,7 +18,8 @@ from io import BytesIO from rpyc.utils.classic import obtain from lightllm.common.quantization import Quantcfg -from lightllm.utils.dist_utils import get_dp_world_size +from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size +from lightllm.utils.envs_utils import get_env_start_args from 
lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -47,6 +48,7 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal self._init_datatype() self._init_config() @@ -63,6 +65,7 @@ def _check_max_len_infer(self): disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None if disable_check_max_len_infer: return + self.enable_tensor_cache = True try: dummy_images = torch.randn( @@ -70,6 +73,7 @@ def _check_max_len_infer(self): ).cuda() all_img_embeds = self.forward(dummy_images) del all_img_embeds + del dummy_images logger.info(f"vit check max_len {self.max_batch_size} infer ok") except (RuntimeError, torch.OutOfMemoryError) as e: logger.exception(str(e)) @@ -78,6 +82,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + self.enable_tensor_cache = not get_env_start_args().disable_extra_process_for_multimodal return def _init_config(self): @@ -150,6 +155,8 @@ def _init_infer_layer(self): return def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return if self.data_type in ["fp16", "float16"]: self.data_type = torch.float16 elif self.data_type in ["bf16", "bfloat16"]: @@ -161,12 +168,14 @@ def _init_datatype(self): @torch.no_grad() def forward(self, pixel_values): - g_cache_manager.cache_env_in() + if self.enable_tensor_cache: + g_cache_manager.cache_env_in() input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight) for i in range(self.layers_num + self.select_layer + 1): input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i]) input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight) - g_cache_manager.cache_env_out() + if self.enable_tensor_cache: + g_cache_manager.cache_env_out() return input_embs @torch.no_grad() @@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) + elif isinstance(img, dict): + uuids.append(img["uuid"]) + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"]) + img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3f3eaf96f..3cb62d842 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -233,6 +233,11 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models." 
) + parser.add_argument( + "--disable_extra_process_for_multimodal", + action="store_true", + help="Whether or not to disable extra process for multimodal.", + ) parser.add_argument( "--enable_multimodal_audio", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index de1e690a2..1f9107fa7 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -243,7 +243,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio: + if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal: from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( @@ -263,7 +263,7 @@ def normal_or_p_d_start(args): ], ) - else: + elif not args.disable_extra_process_for_multimodal: process_manager.start_submodule_processes( start_funcs=[ start_visual_process, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..967c716dd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,10 +81,12 @@ def __init__( ) self.enable_multimodal = enable_multimodal + self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + if not self.disable_extra_process_for_multimodal: + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.shm_req_manager = ShmReqManager() @@ -449,7 +451,7 @@ async def transfer_to_next_module( ): if self.pd_mode == NodeRole.P: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -470,7 +472,7 @@ async def transfer_to_next_module( return if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: + if self.enable_multimodal and not self.disable_extra_process_for_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bf320e199..97d456355 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -119,6 +119,7 @@ def to_dict(self): ret["uuid"] = self.uuid ret["token_id"] = self.token_id ret["token_num"] = self.token_num + ret["extra_params"] = self.extra_params return ret def to_origin_dict(self): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..551a70a22 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Union, Any from lightllm.common.req_manager import ReqManager +from lightllm.common.image_cache_manager import image_cache_manager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode @@ -131,6 +132,7 @@ def filter(self, finished_request_ids: List[int]): free_req_index = 
[]
         free_token_index = []
+        image_uuid_list = []
         for request_id in finished_request_ids:
             req: InferReq = self.requests_mapping.pop(request_id)
             group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id)
@@ -145,6 +147,10 @@
                 # logger.info(f"infer release req id {req.shm_req.request_id}")
                 req.shm_req.shm_infer_released = True
                 self.shm_req_manager.put_back_req_obj(req.shm_req)
+                if req.multimodal_params is not None and get_env_start_args().disable_extra_process_for_multimodal:
+                    for img in req.multimodal_params["images"]:
+                        image_uuid_list.append(img["uuid"])
+        image_cache_manager.filter(image_uuid_list)
 
         free_token_index = custom_cat(free_token_index)
         self.req_manager.free(free_req_index, free_token_index)
diff --git a/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py
new file mode 100755
index 000000000..49fcc2b60
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_multimodal_emb.py
@@ -0,0 +1,22 @@
+import torch
+import pytest
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import mark_multimodal_obj
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def test_mark_multimodal_obj():
+    obj_start_ids = torch.tensor([1, 4, 100], device="cuda", dtype=torch.int64)
+    obj_token_lens = torch.tensor([1, 3, 2], device="cuda", dtype=torch.int64)
+    input_ids = torch.tensor([1, 7, 9, 333], device="cuda", dtype=torch.int64)
+
+    mark_obj = mark_multimodal_obj(
+        obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
+    )
+
+    assert torch.equal(mark_obj, torch.tensor([1, 0, 0], device="cuda"))
+
+
+if __name__ == "__main__":
+    pytest.main()
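
Editor's note (not part of the patch): the new mark_multimodal_obj kernel answers, per multimodal object, "does any id in input_ids fall inside [token_id, token_id + token_num)?", which is how chunked prefill decides whether an image's embedding is needed in the current chunk. The sketch below is a rough CPU-side PyTorch equivalent for illustration only; mark_multimodal_obj_ref is a made-up name and does not exist in lightllm.

import torch


def mark_multimodal_obj_ref(obj_start_token_ids, obj_token_lens, input_ids):
    # 1 if any input token id falls inside [start, start + length), else 0,
    # mirroring what the Triton kernel computes for each object.
    marks = torch.zeros_like(obj_start_token_ids)
    for i in range(obj_start_token_ids.shape[0]):
        start = obj_start_token_ids[i]
        end = start + obj_token_lens[i]
        marks[i] = int(((input_ids >= start) & (input_ids < end)).any())
    return marks


# Same inputs as the unit test above: only the first object overlaps input_ids.
starts = torch.tensor([1, 4, 100], dtype=torch.int64)
lens = torch.tensor([1, 3, 2], dtype=torch.int64)
ids = torch.tensor([1, 7, 9, 333], dtype=torch.int64)
print(mark_multimodal_obj_ref(starts, lens, ids))  # tensor([1, 0, 0])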
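
Editor's note (not part of the patch): with --disable_extra_process_for_multimodal, image embeddings are computed in the main inference process and kept in the new ImageCacheManager: set_embed() stores them on GPU, filter() demotes them to a bounded CPU LRU cache when the owning request finishes, and get_embed() copies them back to GPU on reuse. A minimal usage sketch follows, assuming a single-GPU setup where the module is importable outside the server processes; the uuid and tensor shape are made up.

import torch

from lightllm.common.image_cache_manager import image_cache_manager

image_cache_manager.set_max_size(8)                 # capacity of the CPU-side LRU cache
embed = torch.randn(256, 4096, device="cuda")       # hypothetical image embedding
image_cache_manager.set_embed("img-uuid-0", embed)  # hot: kept in the GPU cache

image_cache_manager.filter(["img-uuid-0"])          # request finished: demote to the CPU LRU cache

if image_cache_manager.query_embed("img-uuid-0"):   # a later request reuses the same image
    embed_again = image_cache_manager.get_embed("img-uuid-0")  # copied back to the GPU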