
Multimodal improve #951


Open
wants to merge 8 commits into base: main
Changes from all commits
22 changes: 22 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
from typing import Tuple, Any, Optional
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj


class InferStateInfo:
@@ -98,3 +99,24 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
attr_.copy_(attr_value, non_blocking=True)
return

def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
"""
Helper that marks which multimodal objects have tokens that actually need to take part in the computation during chunked prefill.
Because the prompt is processed in chunks, not every multimodal object's tokens are involved in the current chunk.
"""
multi_objs = []
for _, p in enumerate(self.multimodal_params):
for obj in p["images"] + p["audios"]:
multi_objs.append(obj)

if multi_objs:
obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
marks = mark_multimodal_obj(
obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
)
marks_array = marks.detach().cpu().numpy()
for mark, obj in zip(marks_array, multi_objs):
obj["_prefill_"] = mark > 0
return
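For reference, a minimal sketch of the multimodal_params layout the method above reads and the "_prefill_" flags it writes back; the uuids, token ids, and sizes below are made up for illustration and are not part of this PR.

import torch

# Hypothetical contents of infer_state.multimodal_params for one request.
multimodal_params = [
    {
        "images": [
            {"uuid": "img-0", "token_id": 320000, "token_num": 512},  # placeholder ids 320000..320511
            {"uuid": "img-1", "token_id": 320512, "token_num": 512},  # placeholder ids 320512..321023
        ],
        "audios": [],
    }
]

# Suppose the current prefill chunk only contains the first image's placeholder tokens.
input_ids = torch.arange(320000, 320512)

# After infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) one would expect:
#   multimodal_params[0]["images"][0]["_prefill_"]  -> truthy (its tokens appear in this chunk)
#   multimodal_params[0]["images"][1]["_prefill_"]  -> falsy  (skipped until a later chunk)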
138 changes: 110 additions & 28 deletions lightllm/common/basemodel/triton_kernel/multimodal_emb.py
@@ -5,48 +5,78 @@

@triton.jit
def _fwd_kernel(
Prompt_ids,
Text_weight_embs,
Img_embs,
Out,
Img_token_lens,
Img_start_token_ids,
Img_start_locs,
stride_text_emb_s,
stride_text_emb_d, # text_stride
stride_img_emb_s,
stride_img_emb_d, # img_stride
stride_out_s,
stride_out_d,
tp_text_start_token_id,
tp_text_end_token_id,
hidden_size,
BLOCK_HIDDEN_DIM: tl.constexpr,
):

seq_index = tl.program_id(0).to(tl.int64)
img_handle_id = tl.program_id(1)

token_id = tl.load(Prompt_ids + seq_index)
off_d = tl.arange(0, BLOCK_HIDDEN_DIM)

# load store text emb
for _ in range(
0,
tl.where((img_handle_id == 0) & (token_id < tp_text_end_token_id) & (token_id >= tp_text_start_token_id), 1, 0),
1,
):
load_emb = tl.load(
Text_weight_embs + stride_text_emb_s * (token_id - tp_text_start_token_id) + off_d * stride_text_emb_d,
mask=off_d < hidden_size,
other=0,
)
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)

img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
# load store img emb
for _ in range(
0,
tl.where(
(img_handle_id != 0) & (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len),
1,
0,
),
1,
):
load_emb = tl.load(
Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
mask=off_d < hidden_size,
other=0,
)
tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
return


@torch.no_grad()
def multimodal_emb(
out: torch.Tensor,
prompt_ids: torch.Tensor,
text_weight_embs: torch.Tensor,
img_embs: torch.Tensor,
img_token_lens: torch.Tensor,
img_start_token_ids: torch.Tensor,
img_start_locs: torch.Tensor,
tp_text_start_token_id,
tp_text_end_token_id,
):
total_len = prompt_ids.shape[0]
BLOCK = triton.next_power_of_2(out.shape[1])
# print(len(img_token_lens))
@@ -60,9 +90,12 @@ def multimodal_emb(out: torch.Tensor, prompt_ids: torch.Tensor, text_weight_embs
img_token_lens,
img_start_token_ids,
img_start_locs,
text_weight_embs.stride(0),
text_weight_embs.stride(1),
img_embs.stride(0),
img_embs.stride(1),
out.stride(0),
out.stride(1),
tp_text_start_token_id,
tp_text_end_token_id,
hidden_size=out.shape[1],
@@ -73,34 +106,84 @@ def test():
return


@triton.jit
def _mark_multimodal_obj_need_kernel(
obj_start_token_ids_ptr,
obj_token_lens_ptr,
obj_marks_ptr,
input_ids_ptr,
input_size,
BLOCK_SIZE: tl.constexpr,
):

obj_index = tl.program_id(0)
start_id = tl.load(obj_start_token_ids_ptr + obj_index)
token_len = tl.load(obj_token_lens_ptr + obj_index)

for block_start in range(0, input_size, BLOCK_SIZE):
block_range = block_start + tl.arange(0, BLOCK_SIZE)
cur_input_ids = tl.load(input_ids_ptr + block_range, mask=block_range < input_size, other=0)
mark = tl.where((cur_input_ids >= start_id) & (cur_input_ids < start_id + token_len), 1, 0)
mark = tl.sum(mark)
tl.store(obj_marks_ptr + obj_index, 1, mask=mark > 0)
return


@torch.no_grad()
def mark_multimodal_obj(obj_start_token_ids: torch.Tensor, obj_token_lens: torch.Tensor, input_ids: torch.Tensor):
out_mark = torch.empty_like(obj_start_token_ids)
out_mark.fill_(0)
assert obj_start_token_ids.shape == obj_token_lens.shape
BLOCK = 512
grid = (obj_start_token_ids.shape[0],)
_mark_multimodal_obj_need_kernel[grid](
obj_start_token_ids_ptr=obj_start_token_ids,
obj_token_lens_ptr=obj_token_lens,
obj_marks_ptr=out_mark,
input_ids_ptr=input_ids,
input_size=input_ids.shape[0],
BLOCK_SIZE=BLOCK,
num_warps=1,
num_stages=1,
)
return out_mark


def test():
S, D = 1024 * 1000, 128 * 64
vob_size = 320000
image_size = 10
image_token_size = 512

text_weight = torch.randn((vob_size, D), device="cuda", dtype=torch.float16)
img_weight = torch.randn((image_size * image_token_size, D), device="cuda", dtype=torch.float16)
img_token_lens = torch.full((image_size,), image_token_size, device="cuda", dtype=torch.long)
img_start_token_ids = (
(torch.arange(0, image_size * image_token_size, image_token_size) + vob_size * 10).cuda().long()
)
img_start_locs = torch.arange(0, image_size * image_token_size, image_token_size).cuda().long()

prompt_ids = torch.arange(0, S, 1).cuda().long()
prompt_ids[0 : image_size * image_token_size] = (
(vob_size * 10 + torch.arange(0, image_size * image_token_size, 1)).cuda().long()
)

out = torch.zeros((S, D), dtype=torch.float16, device="cuda")
print(out.shape)

import time


multimodal_emb(
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
)

torch.cuda.synchronize()
iters = 20
t1 = time.time()
for _ in range(iters):
multimodal_emb(
out, prompt_ids, text_weight, img_weight, img_token_lens, img_start_token_ids, img_start_locs, 0, vob_size
)
torch.cuda.synchronize()
t2 = time.time()
print("Triton time cost", (t2 - t1) / iters)
@@ -109,4 +192,3 @@ def test():

# if __name__ == "__main__":
# test()
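For clarity, a minimal pure-PyTorch reference of the predicate that mark_multimodal_obj computes with the Triton kernel above: an object is marked 1 if any entry of input_ids falls inside its placeholder range [start_id, start_id + token_len). This is a sketch for sanity checking, not part of this PR, and mark_multimodal_obj_ref is a hypothetical helper name.

import torch

def mark_multimodal_obj_ref(obj_start_token_ids, obj_token_lens, input_ids):
    # Broadcast (num_objs, 1) against (1, seq_len): hit[i, j] is True when
    # input_ids[j] lies inside object i's placeholder token range.
    starts = obj_start_token_ids.view(-1, 1)
    lens = obj_token_lens.view(-1, 1)
    ids = input_ids.view(1, -1)
    hit = (ids >= starts) & (ids < starts + lens)
    # 1 if the object has at least one token in this chunk, else 0.
    return hit.any(dim=1).to(obj_start_token_ids.dtype)

# Expected to agree with the kernel on CUDA tensors, e.g.:
#   ref = mark_multimodal_obj_ref(obj_start_ids, obj_token_lens, input_ids)
#   tri = mark_multimodal_obj(obj_start_ids, obj_token_lens, input_ids)
#   assert torch.equal(ref, tri)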

76 changes: 76 additions & 0 deletions lightllm/common/image_cache_manager.py
@@ -0,0 +1,76 @@
from collections import OrderedDict
from lightllm.utils.dist_utils import get_current_device_id


class ImageCacheManager:
def __init__(self):
"""
Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
"""
self._gpu_cache = dict()
self._cpu_cache = OrderedDict()
Comment on lines +10 to +11


critical

The _max_size attribute is used in the filter method but is not initialized in the __init__ method, which can lead to an AttributeError if filter() is called before set_max_size().

Suggested change
self._gpu_cache = dict()
self._cpu_cache = OrderedDict()
self._max_size = 0


def set_max_size(self, max_size: int):
"""
Set the maximum number of items to keep in the CPU cache.
:param max_size: Maximum number of items to keep in the CPU cache.
"""
if max_size <= 0:
raise ValueError("max_size must be greater than 0")
self._max_size = max_size

def set_embed(self, uuid, embed):
"""
Store the embedding for the given uuid in the GPU cache.
:param uuid: Unique identifier for the image
:param embed: Embedding vector for the image (on GPU)
"""
self._gpu_cache[uuid] = embed

def get_embed(self, uuid):
"""
Retrieve the embedding for the given uuid. Prefer GPU cache,
otherwise return CPU cache and move to GPU (simulate .cuda()).
:param uuid: Unique identifier for the image
:return: Embedding vector (on GPU if possible, else move from CPU to GPU)
"""
if uuid in self._gpu_cache:
return self._gpu_cache[uuid]
elif uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
embed = self._cpu_cache[uuid].cuda(get_current_device_id())
return embed
return None

def query_embed(self, uuid):
"""
Query if the embedding for the given uuid is in the cache.
:param uuid: Unique identifier for the image
:return: True if the embedding is in the cache, False otherwise
"""
return uuid in self._gpu_cache or uuid in self._cpu_cache

def filter(self, uuid_list):
"""
Given a list of uuids, move their embeddings from GPU cache to CPU cache if present,
and return a dict of those found in the cache and their embeddings (on CPU).
:param uuid_list: List of uuids
"""
for uuid in uuid_list:
if uuid in self._gpu_cache:
embed_cpu = self._gpu_cache[uuid].cpu()
# Move to CPU cache and remove from GPU cache
self._gpu_cache.pop(uuid)
if uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
self._cpu_cache[uuid] = embed_cpu
if len(self._cpu_cache) > self._max_size:
self._cpu_cache.popitem(last=False)
elif uuid in self._cpu_cache:
self._cpu_cache.move_to_end(uuid)
print(self._gpu_cache.keys())
print(self._cpu_cache.keys())
Comment on lines +71 to +72


medium

These print statements are likely for debugging and should be removed or replaced with proper logging.

return


image_cache_manager = ImageCacheManager()
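As a usage illustration, a minimal sketch of how the cache manager above could be driven, assuming image embeddings are CUDA tensors and that set_max_size() is called before filter() (see the review comment about _max_size); the uuid and tensor sizes are made up.

import torch
from lightllm.common.image_cache_manager import image_cache_manager

# Bound the CPU LRU cache before anything can be evicted into it.
image_cache_manager.set_max_size(64)

# Keep a freshly computed image embedding on the GPU.
embed = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
image_cache_manager.set_embed("img-uuid-0", embed)

# Later, check whether the embedding is cached anywhere ...
if image_cache_manager.query_embed("img-uuid-0"):
    # ... and fetch it; get_embed() returns a GPU tensor, moving it back
    # from the CPU cache if it was offloaded in the meantime.
    cached = image_cache_manager.get_embed("img-uuid-0")

# Move the listed embeddings from the GPU cache into the CPU LRU cache,
# evicting the oldest CPU entries beyond max_size.
image_cache_manager.filter(["img-uuid-0"])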
4 changes: 3 additions & 1 deletion lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -35,10 +35,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
else:
weight_mask[idx] = scale

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"]:
# skip the same image
if img["token_id"] in img_start_token_ids:
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))