From 038f6bf5679024942543c239ab7fef22df9343a3 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 16 Jul 2024 19:32:11 +0800 Subject: [PATCH 01/89] attn layer --- lmdeploy/pytorch/attention/__init__.py | 6 ++ lmdeploy/pytorch/attention/ascend.py | 55 ++++++++++ lmdeploy/pytorch/attention/base.py | 103 ++++++++++++++++++ lmdeploy/pytorch/attention/layer.py | 55 ++++++++++ lmdeploy/pytorch/attention/selector.py | 22 ++++ lmdeploy/pytorch/attention/triton.py | 134 ++++++++++++++++++++++++ lmdeploy/pytorch/engine/cache_engine.py | 44 ++++++-- lmdeploy/pytorch/engine/devices/cuda.py | 15 +++ lmdeploy/pytorch/models/llama.py | 44 ++++---- 9 files changed, 443 insertions(+), 35 deletions(-) create mode 100644 lmdeploy/pytorch/attention/__init__.py create mode 100644 lmdeploy/pytorch/attention/ascend.py create mode 100644 lmdeploy/pytorch/attention/base.py create mode 100644 lmdeploy/pytorch/attention/layer.py create mode 100644 lmdeploy/pytorch/attention/selector.py create mode 100644 lmdeploy/pytorch/attention/triton.py diff --git a/lmdeploy/pytorch/attention/__init__.py b/lmdeploy/pytorch/attention/__init__.py new file mode 100644 index 000000000..dbe52bb1c --- /dev/null +++ b/lmdeploy/pytorch/attention/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# attention module is modified from: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ +from .selector import get_attn_backend + +__all__ = ['get_attn_backend'] diff --git a/lmdeploy/pytorch/attention/ascend.py b/lmdeploy/pytorch/attention/ascend.py new file mode 100644 index 000000000..572a698eb --- /dev/null +++ b/lmdeploy/pytorch/attention/ascend.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Type + +import torch + +from .base import AttentionBackend, AttentionImpl, AttentionMetadata + + +class AscendAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + raise 'ascend' + + @staticmethod + def get_impl_cls() -> Type['AttentionImpl']: + return AscendAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type['AttentionMetadata']: + return AscendAttentionMetadata + + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + +class AscendAttentionMetadata(AttentionMetadata): + pass + + +class AscendAttentionImpl(AttentionImpl[AscendAttentionMetadata]): + pass diff --git a/lmdeploy/pytorch/attention/base.py b/lmdeploy/pytorch/attention/base.py new file mode 100644 index 000000000..359397167 --- /dev/null +++ b/lmdeploy/pytorch/attention/base.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/abstract.py +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Tuple, Type, TypeVar + +import torch + + +class AttentionBackend(ABC): + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type['AttentionImpl']: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type['AttentionMetadata']: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + is_decoding: bool + block_offsets: torch.Tensor + q_start_loc: torch.Tensor = None + q_seqlens: torch.Tensor = None + kv_seqlens: torch.Tensor = None + max_q_seqlen: int = 0 + max_kv_seqlen: int = 0 + + +T = TypeVar('T', bound=AttentionMetadata) + + +class AttentionImpl(ABC, Generic[T]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + **kwargs, + ) -> None: + if scale is None: + scale = 1.0 / (head_size**0.5) + + if num_kv_heads is None: + num_kv_heads = num_heads + + if v_head_size is None: + v_head_size = head_size + + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_size = v_head_size + self.alibi_scale = alibi_scale + self.sliding_window = sliding_window + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/lmdeploy/pytorch/attention/layer.py b/lmdeploy/pytorch/attention/layer.py new file mode 100644 index 000000000..c0b18648c --- /dev/null +++ b/lmdeploy/pytorch/attention/layer.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from .base import AttentionMetadata + + +class Attention(nn.Module): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + **kwargs, + ): + from .selector import get_attn_backend + super().__init__() + attn_backend = get_attn_backend() + impl_cls = attn_backend.get_impl_cls() + + self.impl = impl_cls( + num_heads, + head_size, + scale, + num_kv_heads, + v_head_size, + alibi_scale, + sliding_window, + **kwargs, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + inplace: bool = True, + ) -> torch.Tensor: + return self.impl.forward( + query, + key, + value, + k_cache, + v_cache, + attn_metadata=attn_metadata, + inplace=inplace, + ) diff --git a/lmdeploy/pytorch/attention/selector.py b/lmdeploy/pytorch/attention/selector.py new file mode 100644 index 000000000..5f4045a18 --- /dev/null +++ b/lmdeploy/pytorch/attention/selector.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.devices import get_device_manager +from lmdeploy.utils import get_logger + + +def get_attn_backend(): + """get attention backend.""" + device_mgr = get_device_manager() + device_ctx = device_mgr.current_context() + + device_type = device_ctx.device_type + + if device_type == 'cuda': + from .triton import TritonAttentionBackend + return TritonAttentionBackend + if device_type == 'ascend': + from .ascend import AscendAttentionBackend + return AscendAttentionBackend + else: + logger = get_logger('lmdeploy') + logger.warning(f'Unsupported device type: {device_type}') + return None diff --git a/lmdeploy/pytorch/attention/triton.py b/lmdeploy/pytorch/attention/triton.py new file mode 100644 index 000000000..b193dcc64 --- /dev/null +++ b/lmdeploy/pytorch/attention/triton.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Type + +import torch + +from .base import AttentionBackend, AttentionImpl, AttentionMetadata + + +class TritonAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + raise 'triton' + + @staticmethod + def get_impl_cls() -> Type['AttentionImpl']: + return TritonAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type['AttentionMetadata']: + return TritonAttentionMetadata + + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + +class TritonAttentionMetadata(AttentionMetadata): + pass + + +class TritonAttentionImpl(AttentionImpl[TritonAttentionMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + **kwargs, + ): + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + v_head_size, + alibi_scale, + sliding_window, + **kwargs, + ) + + from lmdeploy.pytorch.kernels.cuda import (fill_kv_cache, + paged_attention_fwd) + self.fill_kv_cache = fill_kv_cache + self.paged_attention_fwd = paged_attention_fwd + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + inplace: bool = True, + ) -> torch.Tensor: + + block_offsets = attn_metadata.block_offsets + q_start_loc = attn_metadata.q_start_loc + q_seqlens = attn_metadata.q_seqlens + kv_seqlens = attn_metadata.kv_seqlens + max_q_seqlen = attn_metadata.max_q_seqlen + + # fill kv cache + self.fill_kv_cache( + key, + value, + k_cache, + v_cache, + q_start_loc, + q_seqlens, + kv_seq_length=kv_seqlens, + max_q_seq_length=max_q_seqlen, + block_offsets=block_offsets, + ) + + if inplace: + attn_output = query[..., :self.v_head_size] + else: + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_size, ) + attn_output = query.new_empty(o_shape) + + self.paged_attention_fwd( + query, + k_cache, + v_cache, + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + ) + + return attn_output diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index a8d133a08..82a4d082a 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -4,6 +4,7 @@ import torch +from lmdeploy.pytorch.attention import get_attn_backend from lmdeploy.utils import get_logger from ..config import CacheConfig, ModelConfig @@ -78,26 +79,47 @@ def num_cpu_blocks(self): return self.cache_config.num_cpu_blocks @classmethod - def _get_block_shape_impl(cls, - model_config: ModelConfig, - block_size: int, - head_size: int, - world_size: int = 1, - local: bool = True): + def _get_key_block_shape_impl(cls, + model_config: ModelConfig, + block_size: int, + head_size: int, + world_size: int = 1, + local: bool = True): """get single block shape.""" + attn_backend = get_attn_backend() + dtype = model_config.dtype + num_heads = model_config.num_key_value_heads + if local and not model_config.multi_query_attention: + assert num_heads % world_size == 0, \ + f'num_heads: {num_heads}, world_size: {world_size}' + num_heads = num_heads // world_size + return attn_backend.get_k_block_shape(block_size, num_heads, head_size, + dtype) + + @classmethod + def _get_value_block_shape_impl(cls, + model_config: ModelConfig, + block_size: int, + head_size: int, + world_size: int = 1, + local: bool = True): + """get single block shape.""" + attn_backend = get_attn_backend() + dtype = model_config.dtype num_heads = model_config.num_key_value_heads if local and not model_config.multi_query_attention: assert num_heads % world_size == 0, \ f'num_heads: {num_heads}, world_size: {world_size}' num_heads = num_heads // world_size - return (block_size, num_heads, head_size) + return attn_backend.get_v_block_shape(block_size, num_heads, head_size, + dtype) def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]: """get shape of key block.""" head_size = self.model_config.k_head_dim if head_size is None: head_size = self.model_config.head_dim - return self._get_block_shape_impl( + return self._get_key_block_shape_impl( self.model_config, block_size=self.block_size, head_size=head_size, @@ -111,7 +133,7 @@ def get_value_block_shape(self, head_size = self.model_config.v_head_dim if head_size is None: head_size = self.model_config.head_dim - return self._get_block_shape_impl( + return self._get_value_block_shape_impl( self.model_config, block_size=self.block_size, head_size=head_size, @@ -222,14 +244,14 @@ def get_cache_block_size(cls, key_head_size = model_config.head_dim if value_head_size is None: value_head_size = model_config.head_dim - key_shape = cls._get_block_shape_impl( + key_shape = cls._get_key_block_shape_impl( model_config, block_size=block_size, head_size=key_head_size, world_size=world_size, local=True, ) - value_shape = cls._get_block_shape_impl( + value_shape = cls._get_value_block_shape_impl( model_config, block_size=block_size, head_size=value_head_size, diff --git a/lmdeploy/pytorch/engine/devices/cuda.py b/lmdeploy/pytorch/engine/devices/cuda.py index 95c78df97..12d06fd6b 100644 --- a/lmdeploy/pytorch/engine/devices/cuda.py +++ b/lmdeploy/pytorch/engine/devices/cuda.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.attention import get_attn_backend + from .base_device_utils import BaseDeviceUtils @@ -9,4 +11,17 @@ class CUDADeviceUtils(BaseDeviceUtils): @classmethod def update_step_context(cls, step_context): """update step context.""" + attn_backend = get_attn_backend() + attn_meta_cls = attn_backend.get_metadata_cls() + attn_meta = attn_meta_cls( + step_context.is_decoding, + step_context.block_offsets, + q_start_loc=step_context.q_start_loc, + q_seqlens=step_context.q_seq_length, + kv_seqlens=step_context.kv_seq_length, + max_q_seqlen=step_context.max_q_seq_length, + max_kv_seqlen=step_context.max_kv_seq_length, + ) + + step_context.attn_meta = attn_meta return step_context diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 386becea7..ed398cc95 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -64,6 +64,22 @@ def _load_weights(self, loader, rank: int, world_size: int, world_size=world_size, prefix='o_proj') + def _update_model_fn(self): + """update model.""" + from lmdeploy.pytorch.attention.layer import Attention + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + head_size = self.head_dim + self.attn_fwd = Attention( + num_heads, + head_size, + num_kv_heads=num_kv_heads, + v_head_size=head_size, + ) + @classmethod def _distribute_output_fn(cls, outputs, **kwargs): """Distribution output hook.""" @@ -82,11 +98,6 @@ def _contiguous_batching_forward_default_impl( Optional[Tuple[torch.Tensor]]]: """default rewrite.""" context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length max_kv_seq_length = context.max_kv_seq_length num_heads = self.num_heads // world_size @@ -174,29 +185,14 @@ def __rotary_emb_fn(query_states, key_states, value_states): query_states, key_states, value_states = __rotary_emb_fn( query_states, key_states, value_states) - fill_kv_cache( + attn_output = self.attn_fwd( + query_states, key_states, value_states, past_key_value[0], past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, - ) - - attn_output = query_states - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + context.attn_meta, + inplace=True, ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size) From 68936c9edfeb3bc9cbe60065d611872d27f15264 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 17 Jul 2024 15:16:38 +0800 Subject: [PATCH 02/89] move to backend --- lmdeploy/pytorch/attention/backends/__init__.py | 1 + lmdeploy/pytorch/attention/{ => backends}/ascend.py | 0 lmdeploy/pytorch/attention/{ => backends}/base.py | 0 lmdeploy/pytorch/attention/{ => backends}/triton.py | 0 lmdeploy/pytorch/attention/layer.py | 2 +- lmdeploy/pytorch/attention/selector.py | 4 ++-- 6 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 lmdeploy/pytorch/attention/backends/__init__.py rename lmdeploy/pytorch/attention/{ => backends}/ascend.py (100%) rename lmdeploy/pytorch/attention/{ => backends}/base.py (100%) rename lmdeploy/pytorch/attention/{ => backends}/triton.py (100%) diff --git a/lmdeploy/pytorch/attention/backends/__init__.py b/lmdeploy/pytorch/attention/backends/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/lmdeploy/pytorch/attention/backends/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/pytorch/attention/ascend.py b/lmdeploy/pytorch/attention/backends/ascend.py similarity index 100% rename from lmdeploy/pytorch/attention/ascend.py rename to lmdeploy/pytorch/attention/backends/ascend.py diff --git a/lmdeploy/pytorch/attention/base.py b/lmdeploy/pytorch/attention/backends/base.py similarity index 100% rename from lmdeploy/pytorch/attention/base.py rename to lmdeploy/pytorch/attention/backends/base.py diff --git a/lmdeploy/pytorch/attention/triton.py b/lmdeploy/pytorch/attention/backends/triton.py similarity index 100% rename from lmdeploy/pytorch/attention/triton.py rename to lmdeploy/pytorch/attention/backends/triton.py diff --git a/lmdeploy/pytorch/attention/layer.py b/lmdeploy/pytorch/attention/layer.py index c0b18648c..439c18749 100644 --- a/lmdeploy/pytorch/attention/layer.py +++ b/lmdeploy/pytorch/attention/layer.py @@ -2,7 +2,7 @@ import torch from torch import nn -from .base import AttentionMetadata +from .backends.base import AttentionMetadata class Attention(nn.Module): diff --git a/lmdeploy/pytorch/attention/selector.py b/lmdeploy/pytorch/attention/selector.py index 5f4045a18..7fa80430d 100644 --- a/lmdeploy/pytorch/attention/selector.py +++ b/lmdeploy/pytorch/attention/selector.py @@ -11,10 +11,10 @@ def get_attn_backend(): device_type = device_ctx.device_type if device_type == 'cuda': - from .triton import TritonAttentionBackend + from .backends.triton import TritonAttentionBackend return TritonAttentionBackend if device_type == 'ascend': - from .ascend import AscendAttentionBackend + from .backends.ascend import AscendAttentionBackend return AscendAttentionBackend else: logger = get_logger('lmdeploy') From ccdb3ea9241604821e9f4499e639b4bb1a565c67 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 19 Jul 2024 19:59:49 +0800 Subject: [PATCH 03/89] add base layer --- lmdeploy/pytorch/attention/__init__.py | 6 - .../pytorch/attention/backends/__init__.py | 1 - lmdeploy/pytorch/attention/backends/ascend.py | 55 --- lmdeploy/pytorch/backends/__init__.py | 3 + lmdeploy/pytorch/backends/activation.py | 17 + lmdeploy/pytorch/backends/apply_rotary_emb.py | 17 + lmdeploy/pytorch/backends/ascend/__init__.py | 2 + lmdeploy/pytorch/backends/ascend/attention.py | 10 + .../ascend/layer_backend.py} | 50 ++- .../base.py => backends/attention.py} | 59 +-- lmdeploy/pytorch/backends/base.py | 59 +++ lmdeploy/pytorch/backends/cuda/__init__.py | 2 + .../pytorch/backends/cuda/apply_rotary_emb.py | 25 ++ .../triton.py => backends/cuda/attention.py} | 68 +-- .../pytorch/backends/cuda/layer_backend.py | 76 ++++ lmdeploy/pytorch/backends/cuda/norm.py | 31 ++ lmdeploy/pytorch/backends/default/__init__.py | 2 + .../pytorch/backends/default/activation.py | 23 + .../backends/default/apply_rotary_emb.py | 43 ++ .../pytorch/backends/default/layer_backend.py | 59 +++ lmdeploy/pytorch/backends/default/linear.py | 26 ++ lmdeploy/pytorch/backends/default/norm.py | 33 ++ .../backends/default/rotary_embedding.py | 92 ++++ lmdeploy/pytorch/backends/linear.py | 19 + lmdeploy/pytorch/backends/norm.py | 19 + lmdeploy/pytorch/backends/rotary_embedding.py | 30 ++ .../{attention => backends}/selector.py | 10 +- lmdeploy/pytorch/engine/cache_engine.py | 6 +- lmdeploy/pytorch/engine/devices/__init__.py | 43 -- .../engine/devices/base_device_utils.py | 24 - lmdeploy/pytorch/engine/devices/cuda.py | 27 -- lmdeploy/pytorch/engine/devices/dipu.py | 13 - lmdeploy/pytorch/engine/model_agent.py | 4 +- .../kernels/cuda/apply_rotary_pos_emb.py | 35 +- lmdeploy/pytorch/layers/__init__.py | 10 + lmdeploy/pytorch/layers/activation.py | 16 + lmdeploy/pytorch/layers/apply_rotary_emb.py | 16 + .../layer.py => layers/attention.py} | 11 +- lmdeploy/pytorch/layers/linear.py | 312 +++++++++++++ lmdeploy/pytorch/layers/norm.py | 17 + lmdeploy/pytorch/layers/rotary_embedding.py | 30 ++ lmdeploy/pytorch/models/llama.py | 417 +++++------------- lmdeploy/pytorch/models/module_map.py | 4 +- 43 files changed, 1226 insertions(+), 596 deletions(-) delete mode 100644 lmdeploy/pytorch/attention/__init__.py delete mode 100644 lmdeploy/pytorch/attention/backends/__init__.py delete mode 100644 lmdeploy/pytorch/attention/backends/ascend.py create mode 100644 lmdeploy/pytorch/backends/__init__.py create mode 100644 lmdeploy/pytorch/backends/activation.py create mode 100644 lmdeploy/pytorch/backends/apply_rotary_emb.py create mode 100644 lmdeploy/pytorch/backends/ascend/__init__.py create mode 100644 lmdeploy/pytorch/backends/ascend/attention.py rename lmdeploy/pytorch/{engine/devices/ascend.py => backends/ascend/layer_backend.py} (58%) rename lmdeploy/pytorch/{attention/backends/base.py => backends/attention.py} (65%) create mode 100644 lmdeploy/pytorch/backends/base.py create mode 100644 lmdeploy/pytorch/backends/cuda/__init__.py create mode 100644 lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py rename lmdeploy/pytorch/{attention/backends/triton.py => backends/cuda/attention.py} (73%) create mode 100644 lmdeploy/pytorch/backends/cuda/layer_backend.py create mode 100644 lmdeploy/pytorch/backends/cuda/norm.py create mode 100644 lmdeploy/pytorch/backends/default/__init__.py create mode 100644 lmdeploy/pytorch/backends/default/activation.py create mode 100644 lmdeploy/pytorch/backends/default/apply_rotary_emb.py create mode 100644 lmdeploy/pytorch/backends/default/layer_backend.py create mode 100644 lmdeploy/pytorch/backends/default/linear.py create mode 100644 lmdeploy/pytorch/backends/default/norm.py create mode 100644 lmdeploy/pytorch/backends/default/rotary_embedding.py create mode 100644 lmdeploy/pytorch/backends/linear.py create mode 100644 lmdeploy/pytorch/backends/norm.py create mode 100644 lmdeploy/pytorch/backends/rotary_embedding.py rename lmdeploy/pytorch/{attention => backends}/selector.py (69%) delete mode 100644 lmdeploy/pytorch/engine/devices/__init__.py delete mode 100644 lmdeploy/pytorch/engine/devices/base_device_utils.py delete mode 100644 lmdeploy/pytorch/engine/devices/cuda.py delete mode 100644 lmdeploy/pytorch/engine/devices/dipu.py create mode 100644 lmdeploy/pytorch/layers/__init__.py create mode 100644 lmdeploy/pytorch/layers/activation.py create mode 100644 lmdeploy/pytorch/layers/apply_rotary_emb.py rename lmdeploy/pytorch/{attention/layer.py => layers/attention.py} (80%) create mode 100644 lmdeploy/pytorch/layers/linear.py create mode 100644 lmdeploy/pytorch/layers/norm.py create mode 100644 lmdeploy/pytorch/layers/rotary_embedding.py diff --git a/lmdeploy/pytorch/attention/__init__.py b/lmdeploy/pytorch/attention/__init__.py deleted file mode 100644 index dbe52bb1c..000000000 --- a/lmdeploy/pytorch/attention/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# attention module is modified from: -# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ -from .selector import get_attn_backend - -__all__ = ['get_attn_backend'] diff --git a/lmdeploy/pytorch/attention/backends/__init__.py b/lmdeploy/pytorch/attention/backends/__init__.py deleted file mode 100644 index ef101fec6..000000000 --- a/lmdeploy/pytorch/attention/backends/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/pytorch/attention/backends/ascend.py b/lmdeploy/pytorch/attention/backends/ascend.py deleted file mode 100644 index 572a698eb..000000000 --- a/lmdeploy/pytorch/attention/backends/ascend.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple, Type - -import torch - -from .base import AttentionBackend, AttentionImpl, AttentionMetadata - - -class AscendAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - raise 'ascend' - - @staticmethod - def get_impl_cls() -> Type['AttentionImpl']: - return AscendAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type['AttentionMetadata']: - return AscendAttentionMetadata - - @staticmethod - def get_k_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - return ( - block_size, - num_heads, - head_size, - ) - - @staticmethod - def get_v_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - return ( - block_size, - num_heads, - head_size, - ) - - -class AscendAttentionMetadata(AttentionMetadata): - pass - - -class AscendAttentionImpl(AttentionImpl[AscendAttentionMetadata]): - pass diff --git a/lmdeploy/pytorch/backends/__init__.py b/lmdeploy/pytorch/backends/__init__.py new file mode 100644 index 000000000..dc8538228 --- /dev/null +++ b/lmdeploy/pytorch/backends/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import LayerType # noqa: F401 +from .selector import get_backend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/activation.py b/lmdeploy/pytorch/backends/activation.py new file mode 100644 index 000000000..b6478705f --- /dev/null +++ b/lmdeploy/pytorch/backends/activation.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class SiluAndMulImpl(ABC): + + @abstractmethod + def forward(self, x): + raise NotImplementedError + + +class SiluAndMulBuilder(ABC): + + @staticmethod + @abstractmethod + def build(inplace: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/apply_rotary_emb.py b/lmdeploy/pytorch/backends/apply_rotary_emb.py new file mode 100644 index 000000000..9e6955bfd --- /dev/null +++ b/lmdeploy/pytorch/backends/apply_rotary_emb.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class ApplyRotaryEmbImpl(ABC): + + @abstractmethod + def forward(self, query, key, cos, sin, inplace: bool = True): + raise NotImplementedError + + +class ApplyRotaryEmbBuilder(ABC): + + @staticmethod + @abstractmethod + def build(): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/ascend/__init__.py b/lmdeploy/pytorch/backends/ascend/__init__.py new file mode 100644 index 000000000..72e28d247 --- /dev/null +++ b/lmdeploy/pytorch/backends/ascend/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_backend import AscendLayerBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/ascend/attention.py b/lmdeploy/pytorch/backends/ascend/attention.py new file mode 100644 index 000000000..ee83af2c8 --- /dev/null +++ b/lmdeploy/pytorch/backends/ascend/attention.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..base import AttentionImpl, AttentionMetadata + + +class AscendAttentionMetadata(AttentionMetadata): + pass + + +class AscendAttentionImpl(AttentionImpl[AscendAttentionMetadata]): + pass diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/backends/ascend/layer_backend.py similarity index 58% rename from lmdeploy/pytorch/engine/devices/ascend.py rename to lmdeploy/pytorch/backends/ascend/layer_backend.py index a09fa5f65..adbf8b379 100644 --- a/lmdeploy/pytorch/engine/devices/ascend.py +++ b/lmdeploy/pytorch/backends/ascend/layer_backend.py @@ -1,12 +1,56 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + import torch -from .dipu import DIPUDeviceUtils +from ..base import LayerType +from ..default import DefaultLayersBackend + + +class AscendLayersBackend(DefaultLayersBackend): + + @staticmethod + def get_name() -> str: + raise 'ascend' + + @classmethod + def get_layer_impl_builder(cls, layer_type: LayerType): + if layer_type == LayerType.Attention: + from .attention import AscendAttentionImpl + return AscendAttentionImpl + else: + return super().get_layer_impl_builder(layer_type) + @staticmethod + def get_attention_metadata_cls(): + from .attention import AscendAttentionMetadata + return AscendAttentionMetadata -class ASCENDDeviceUtils(DIPUDeviceUtils): + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) - device = 'ascend' + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) @classmethod def update_step_context(cls, step_context): diff --git a/lmdeploy/pytorch/attention/backends/base.py b/lmdeploy/pytorch/backends/attention.py similarity index 65% rename from lmdeploy/pytorch/attention/backends/base.py rename to lmdeploy/pytorch/backends/attention.py index 359397167..2745f4fde 100644 --- a/lmdeploy/pytorch/attention/backends/base.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -1,51 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -# modify from: -# https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/abstract.py from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, Tuple, Type, TypeVar +from typing import Generic, TypeVar import torch -class AttentionBackend(ABC): - - @staticmethod - @abstractmethod - def get_name() -> str: - raise NotImplementedError - - @staticmethod - @abstractmethod - def get_impl_cls() -> Type['AttentionImpl']: - raise NotImplementedError - - @staticmethod - @abstractmethod - def get_metadata_cls() -> Type['AttentionMetadata']: - raise NotImplementedError - - @staticmethod - @abstractmethod - def get_k_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - raise NotImplementedError - - @staticmethod - @abstractmethod - def get_v_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - raise NotImplementedError - - @dataclass class AttentionMetadata: is_decoding: bool @@ -101,3 +61,20 @@ def forward( attn_metadata: T, ) -> torch.Tensor: raise NotImplementedError + + +class AttentionBuilder(ABC, Generic[T]): + + @staticmethod + @abstractmethod + def build( + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + **kwargs, + ) -> AttentionImpl[T]: + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py new file mode 100644 index 000000000..a46dd6fc1 --- /dev/null +++ b/lmdeploy/pytorch/backends/base.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/abstract.py +from abc import ABC, abstractmethod +from enum import Enum, auto +from typing import Tuple + +import torch + + +class LayerType(Enum): + Attention = auto() + Linear = auto() + RotaryEmbedding = auto() + ApplyRotaryEmb = auto() + SiluAndMul = auto() + RMSNorm = auto() + + +class LayersBackend(ABC): + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @classmethod + @abstractmethod + def get_layer_impl_builder(cls, layer_type: LayerType): + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_attention_metadata_cls(): + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @classmethod + def update_step_context(cls, step_context): + return step_context diff --git a/lmdeploy/pytorch/backends/cuda/__init__.py b/lmdeploy/pytorch/backends/cuda/__init__.py new file mode 100644 index 000000000..4f598c098 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_backend import CudaLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py b/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py new file mode 100644 index 000000000..fd90ba82b --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb + +from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl + + +class TritonApplyRotaryEmbImpl(ApplyRotaryEmbImpl): + + def forward(self, query, key, cos, sin, inplace: bool = True): + if inplace: + q_embed = query + k_embed = key + else: + q_embed = torch.empty_like(query) + k_embed = torch.empty_like(key) + return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed) + + +class TritonApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): + + @staticmethod + def build(): + return TritonApplyRotaryEmbImpl() diff --git a/lmdeploy/pytorch/attention/backends/triton.py b/lmdeploy/pytorch/backends/cuda/attention.py similarity index 73% rename from lmdeploy/pytorch/attention/backends/triton.py rename to lmdeploy/pytorch/backends/cuda/attention.py index b193dcc64..92d06c761 100644 --- a/lmdeploy/pytorch/attention/backends/triton.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -1,50 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple, Type - import torch -from .base import AttentionBackend, AttentionImpl, AttentionMetadata - - -class TritonAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - raise 'triton' - - @staticmethod - def get_impl_cls() -> Type['AttentionImpl']: - return TritonAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type['AttentionMetadata']: - return TritonAttentionMetadata - - @staticmethod - def get_k_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - return ( - block_size, - num_heads, - head_size, - ) - - @staticmethod - def get_v_block_shape( - block_size: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - ) -> Tuple[int, ...]: - return ( - block_size, - num_heads, - head_size, - ) +from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata class TritonAttentionMetadata(AttentionMetadata): @@ -132,3 +89,26 @@ def forward( ) return attn_output + + +class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]): + + @staticmethod + def build( + num_heads: int, + head_size: int, + scale: float = None, + num_kv_heads: int = None, + v_head_size: int = None, + alibi_scale: float = None, + sliding_window: int = None, + **kwargs, + ) -> TritonAttentionImpl: + return TritonAttentionImpl(num_heads, + head_size, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_size=v_head_size, + alibi_scale=alibi_scale, + sliding_window=sliding_window, + **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py new file mode 100644 index 000000000..6c69726e8 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch + +from ..base import LayerType +from ..default import DefaultLayersBackend + + +class CudaLayersBackend(DefaultLayersBackend): + + @staticmethod + def get_name() -> str: + raise 'cuda' + + @classmethod + def get_layer_impl_builder(cls, layer_type: LayerType): + if layer_type == LayerType.Attention: + from .attention import TritonAttentionBuilder + return TritonAttentionBuilder + elif layer_type == LayerType.ApplyRotaryEmb: + from .apply_rotary_emb import TritonApplyRotaryEmbBuilder + return TritonApplyRotaryEmbBuilder + elif layer_type == LayerType.RMSNorm: + from .norm import TritonRMSNormBuilder + return TritonRMSNormBuilder + else: + return super().get_layer_impl_builder(layer_type) + + @staticmethod + def get_attention_metadata_cls(): + from .attention import TritonAttentionMetadata + return TritonAttentionMetadata + + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + @classmethod + def update_step_context(cls, step_context): + """update step context.""" + attn_meta_cls = cls.get_attention_metadata_cls() + attn_meta = attn_meta_cls( + step_context.is_decoding, + step_context.block_offsets, + q_start_loc=step_context.q_start_loc, + q_seqlens=step_context.q_seq_length, + kv_seqlens=step_context.kv_seq_length, + max_q_seqlen=step_context.max_q_seq_length, + max_kv_seqlen=step_context.max_kv_seq_length, + ) + + step_context.attn_meta = attn_meta + return step_context diff --git a/lmdeploy/pytorch/backends/cuda/norm.py b/lmdeploy/pytorch/backends/cuda/norm.py new file mode 100644 index 000000000..83edcd748 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/norm.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from lmdeploy.pytorch.kernels.cuda import rms_norm + +from ..norm import RMSNormBuilder, RMSNormImpl + + +class TritonRMSNormImpl(RMSNormImpl, nn.Module): + + def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(weight.clone()) + self.eps = eps + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + if residual is not None: + x = x + residual + residual = x + x = rms_norm(x, self.weight, self.eps) + if residual is None: + return x + return x, residual + + +class TritonRMSNormBuilder(RMSNormBuilder): + + @staticmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + return TritonRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/default/__init__.py b/lmdeploy/pytorch/backends/default/__init__.py new file mode 100644 index 000000000..00aa5e8c3 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_backend import DefaultLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/default/activation.py b/lmdeploy/pytorch/backends/default/activation.py new file mode 100644 index 000000000..6d3c49af1 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/activation.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from ..activation import SiluAndMulBuilder, SiluAndMulImpl + + +class DefaultSiluAndMulImpl(SiluAndMulImpl, nn.Module): + + def __init__(self, inplace: bool): + super().__init__() + self.inplace = inplace + self.silu = nn.SiLU(inplace) + + def forward(self, x): + gate, up = x.chunk(2, -1) + return self.silu(gate) * up + + +class DefaultSiluAndMulBuilder(SiluAndMulBuilder): + + @staticmethod + def build(inplace: bool = False): + return DefaultSiluAndMulImpl(inplace) diff --git a/lmdeploy/pytorch/backends/default/apply_rotary_emb.py b/lmdeploy/pytorch/backends/default/apply_rotary_emb.py new file mode 100644 index 000000000..fa41123f6 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/apply_rotary_emb.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + half_size = x.shape[-1] // 2 + x1 = x[..., :half_size] + x2 = x[..., half_size:] + out = torch.empty_like(x) + out[..., :half_size] = -x2 + out[..., half_size:] = x1 + return out + + +class DefaultApplyRotaryEmbImpl(ApplyRotaryEmbImpl): + + def forward(self, query, key, cos, sin, inplace: bool = True): + unsqueeze_dim = -2 + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + if inplace: + q_embed = query + k_embed = key + q_sin = rotate_half(query) * sin + q_embed.mul_(cos) + q_embed.add_(q_sin) + k_sin = rotate_half(key) * sin + k_embed.mul_(cos) + k_embed.add_(k_sin) + else: + q_embed = (query * cos) + (rotate_half(query) * sin) + k_embed = (key * cos) + (rotate_half(key) * sin) + return q_embed, k_embed + + +class DefaultApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): + + @staticmethod + def build(): + return DefaultApplyRotaryEmbImpl() diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py new file mode 100644 index 000000000..01f3235d8 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch + +from ..base import LayersBackend, LayerType + + +class DefaultLayersBackend(LayersBackend): + + @staticmethod + def get_name() -> str: + raise 'default' + + @classmethod + def get_layer_impl_builder(cls, layer_type: LayerType): + if layer_type == LayerType.Linear: + from .linear import DefaultLinearBuilder + return DefaultLinearBuilder + elif layer_type == LayerType.RotaryEmbedding: + from .rotary_embedding import DefaultRotaryEmbeddingBuilder + return DefaultRotaryEmbeddingBuilder + elif layer_type == LayerType.ApplyRotaryEmb: + from .apply_rotary_emb import DefaultApplyRotaryEmbBuilder + return DefaultApplyRotaryEmbBuilder + elif layer_type == LayerType.SiluAndMul: + from .activation import DefaultSiluAndMulBuilder + return DefaultSiluAndMulBuilder + elif layer_type == LayerType.RMSNorm: + from .norm import DefaultRMSNormBuilder + return DefaultRMSNormBuilder + else: + raise RuntimeError(f'{layer_type} not supported.') + + @staticmethod + def get_k_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) + + @staticmethod + def get_v_block_shape( + block_size: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + ) -> Tuple[int, ...]: + return ( + block_size, + num_heads, + head_size, + ) diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py new file mode 100644 index 000000000..ddd6547c6 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import distributed as dist +from torch import nn + +from ..linear import LinearBuilder, LinearImpl + + +class DefaultLinearImpl(LinearImpl): + + def __init__(self, mod: nn.Module, all_reduce: bool = False): + super().__init__() + self.mod = mod + self.all_reduce = all_reduce + + def forward(self, x): + out = self.mod(x) + if self.all_reduce: + dist.all_reduce(out) + return out + + +class DefaultLinearBuilder(LinearBuilder): + + @staticmethod + def build(mod: nn.Module, all_reduce: bool = False): + return DefaultLinearImpl(mod, all_reduce) diff --git a/lmdeploy/pytorch/backends/default/norm.py b/lmdeploy/pytorch/backends/default/norm.py new file mode 100644 index 000000000..cb7f3b36f --- /dev/null +++ b/lmdeploy/pytorch/backends/default/norm.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..norm import RMSNormBuilder, RMSNormImpl + + +class DefaultRMSNormImpl(RMSNormImpl, nn.Module): + + def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(weight.clone()) + self.eps = eps + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + input_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = self.weight * x.to(input_dtype) + if residual is None: + return x + return x, residual + + +class DefaultRMSNormBuilder(RMSNormBuilder): + + @staticmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + return DefaultRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py new file mode 100644 index 000000000..734c2f600 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..rotary_embedding import (EmbeddingType, RotaryEmbeddingBuilder, + RotaryEmbeddingImpl) + + +class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + position_ids = position_ids.float() / self.scaling_factor + inv_freq_expanded = self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :] + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0, + max_position_embeddings: int = 2048): + super().__init__(dim, base, scaling_factor) + self.max_position_embeddings = max_position_embeddings + + def forward(self, x, position_ids): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / + self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / + (self.dim - 2)) + inv_freq = 1.0 / (base**(torch.arange( + 0, self.dim, 2, dtype=torch.int64).float().to(x.device) / + self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): + + @staticmethod + def build( + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + emb_type: EmbeddingType = EmbeddingType.Default, + ): + if emb_type in (EmbeddingType.Default, EmbeddingType.LinearScaling): + return RotaryEmbeddingImpl(dim, base, scaling_factor) + elif emb_type == EmbeddingType.DynamicNTKScaling: + return LlamaDynamicNTKScalingRotaryEmbedding( + dim, base, scaling_factor, max_position_embeddings) + else: + raise NotImplementedError( + f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py new file mode 100644 index 000000000..97843a0ec --- /dev/null +++ b/lmdeploy/pytorch/backends/linear.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +from torch import nn + + +class LinearImpl(ABC, nn.Module): + + @abstractmethod + def forward(self, x): + raise NotImplementedError + + +class LinearBuilder(ABC): + + @staticmethod + @abstractmethod + def build(mod: nn.Module, all_reduce: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py new file mode 100644 index 000000000..a0008fb85 --- /dev/null +++ b/lmdeploy/pytorch/backends/norm.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +import torch + + +class RMSNormImpl(ABC): + + @abstractmethod + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + raise NotImplementedError + + +class RMSNormBuilder(ABC): + + @staticmethod + @abstractmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py new file mode 100644 index 000000000..5bf5cada4 --- /dev/null +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from enum import Enum, auto + + +class EmbeddingType(Enum): + Default = auto() + LinearScaling = auto() + DynamicNTKScaling = auto() + + +class RotaryEmbeddingImpl(ABC): + + @abstractmethod + def forward(self, x, position_ids): + raise NotImplementedError + + +class RotaryEmbeddingBuilder(ABC): + + @staticmethod + @abstractmethod + def build( + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + emb_type: EmbeddingType = EmbeddingType.Default, + ): + raise NotImplementedError diff --git a/lmdeploy/pytorch/attention/selector.py b/lmdeploy/pytorch/backends/selector.py similarity index 69% rename from lmdeploy/pytorch/attention/selector.py rename to lmdeploy/pytorch/backends/selector.py index 7fa80430d..e7680048d 100644 --- a/lmdeploy/pytorch/attention/selector.py +++ b/lmdeploy/pytorch/backends/selector.py @@ -3,7 +3,7 @@ from lmdeploy.utils import get_logger -def get_attn_backend(): +def get_backend(): """get attention backend.""" device_mgr = get_device_manager() device_ctx = device_mgr.current_context() @@ -11,11 +11,11 @@ def get_attn_backend(): device_type = device_ctx.device_type if device_type == 'cuda': - from .backends.triton import TritonAttentionBackend - return TritonAttentionBackend + from .cuda import CudaLayersBackend + return CudaLayersBackend if device_type == 'ascend': - from .backends.ascend import AscendAttentionBackend - return AscendAttentionBackend + from .ascend import AscendLayersBackend + return AscendLayersBackend else: logger = get_logger('lmdeploy') logger.warning(f'Unsupported device type: {device_type}') diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 82a4d082a..a0689dab1 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -4,7 +4,7 @@ import torch -from lmdeploy.pytorch.attention import get_attn_backend +from lmdeploy.pytorch.backends import get_backend from lmdeploy.utils import get_logger from ..config import CacheConfig, ModelConfig @@ -86,7 +86,7 @@ def _get_key_block_shape_impl(cls, world_size: int = 1, local: bool = True): """get single block shape.""" - attn_backend = get_attn_backend() + attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads if local and not model_config.multi_query_attention: @@ -104,7 +104,7 @@ def _get_value_block_shape_impl(cls, world_size: int = 1, local: bool = True): """get single block shape.""" - attn_backend = get_attn_backend() + attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads if local and not model_config.multi_query_attention: diff --git a/lmdeploy/pytorch/engine/devices/__init__.py b/lmdeploy/pytorch/engine/devices/__init__.py deleted file mode 100644 index 754119d1f..000000000 --- a/lmdeploy/pytorch/engine/devices/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import importlib - -from lmdeploy.pytorch.devices import get_device_manager -from lmdeploy.utils import get_logger - -from .base_device_utils import BaseDeviceUtils - -logger = get_logger('lmdeploy') - -CURRENT_DEVICE_UTILS = None - - -def _device_utils_callback(*args, **kwargs): - """callback.""" - global CURRENT_DEVICE_UTILS - CURRENT_DEVICE_UTILS = None - - -get_device_manager().register_context_callback(_device_utils_callback) - - -def get_current_device_utils() -> BaseDeviceUtils: - """get device utils.""" - global CURRENT_DEVICE_UTILS - if CURRENT_DEVICE_UTILS is not None: - return CURRENT_DEVICE_UTILS - - current_context = get_device_manager().current_context() - device_type = current_context.device_type - loaded_utils = BaseDeviceUtils._sub_classes - if device_type not in loaded_utils: - try: - importlib.import_module(f'{__name__}.{device_type}') - assert device_type in loaded_utils - except ImportError: - logger.debug('Failed to import device utils for ' - f'device: {device_type}. ') - importlib.import_module(f'{__name__}.cuda') - loaded_utils[device_type] = loaded_utils['cuda'] - - CURRENT_DEVICE_UTILS = loaded_utils[device_type] - return CURRENT_DEVICE_UTILS diff --git a/lmdeploy/pytorch/engine/devices/base_device_utils.py b/lmdeploy/pytorch/engine/devices/base_device_utils.py deleted file mode 100644 index 84c7e4ee0..000000000 --- a/lmdeploy/pytorch/engine/devices/base_device_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABC - - -class BaseDeviceUtils(ABC): - - _sub_classes = dict() - device = None - - def __init_subclass__(cls) -> None: - super().__init_subclass__() - BaseDeviceUtils.register_builder(cls) - - @classmethod - def register_builder(cls, sub_cls): - """register builder.""" - if sub_cls not in BaseDeviceUtils._sub_classes: - BaseDeviceUtils._sub_classes[sub_cls.device] = sub_cls - - @classmethod - def update_step_context(cls, step_context): - """update step context.""" - raise NotImplementedError('`update_step_context` of ' - f'<{cls}> not implemented.') diff --git a/lmdeploy/pytorch/engine/devices/cuda.py b/lmdeploy/pytorch/engine/devices/cuda.py deleted file mode 100644 index 12d06fd6b..000000000 --- a/lmdeploy/pytorch/engine/devices/cuda.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.pytorch.attention import get_attn_backend - -from .base_device_utils import BaseDeviceUtils - - -class CUDADeviceUtils(BaseDeviceUtils): - - device = 'cuda' - - @classmethod - def update_step_context(cls, step_context): - """update step context.""" - attn_backend = get_attn_backend() - attn_meta_cls = attn_backend.get_metadata_cls() - attn_meta = attn_meta_cls( - step_context.is_decoding, - step_context.block_offsets, - q_start_loc=step_context.q_start_loc, - q_seqlens=step_context.q_seq_length, - kv_seqlens=step_context.kv_seq_length, - max_q_seqlen=step_context.max_q_seq_length, - max_kv_seqlen=step_context.max_kv_seq_length, - ) - - step_context.attn_meta = attn_meta - return step_context diff --git a/lmdeploy/pytorch/engine/devices/dipu.py b/lmdeploy/pytorch/engine/devices/dipu.py deleted file mode 100644 index d2cc9c424..000000000 --- a/lmdeploy/pytorch/engine/devices/dipu.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base_device_utils import BaseDeviceUtils - - -class DIPUDeviceUtils(BaseDeviceUtils): - - device = 'dipu' - - @classmethod - def update_step_context(cls, step_context): - """update step context.""" - raise NotImplementedError('`update_step_context` of ' - f'<{cls}> not implemented.') diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 8b301bc8e..8e5d5f5fe 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -11,6 +11,7 @@ from torch import multiprocessing as mp from lmdeploy.pytorch.accel import LoadNoInit +from lmdeploy.pytorch.backends import get_backend from lmdeploy.utils import get_logger from ..adapter.adapter import (AdapterWeightMap, SchedulerAdapter, @@ -22,7 +23,6 @@ from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine -from .devices import get_current_device_utils logger = get_logger('lmdeploy') @@ -430,7 +430,7 @@ def new( local_adapter_ids=inputs.local_adapter_ids, adapter_params=adapter_params) - ret = get_current_device_utils().update_step_context(ret) + ret = get_backend().update_step_context(ret) return ret @classmethod diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index a327f682a..0330aa94f 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -40,7 +40,6 @@ def apply_rotary_pos_emb_qk_kernel( K, COS, SIN, - POS, Q_EMB, K_EMB, seq_len, @@ -66,7 +65,8 @@ def apply_rotary_pos_emb_qk_kernel( seq_block_id = tl.program_id(0) pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK) - pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1) + # pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1) + pos_ids = pos_offset feat_size = half_size * 2 feat_offset_l = tl.arange(0, BLOCK_N) @@ -75,11 +75,13 @@ def apply_rotary_pos_emb_qk_kernel( half_size)[None, :] cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :] cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :] - pos_ids_mask = pos_ids[:, None] >= 0 - cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask) - cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask) - sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask) - sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask) + # pos_ids_mask = pos_ids[:, None] >= 0 + pos_ids_mask = pos_ids[:, None] < seq_len + q_elem_type = Q.dtype.element_ty + cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask).to(q_elem_type) + cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask).to(q_elem_type) + sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask).to(q_elem_type) + sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask).to(q_elem_type) q_ptr = Q + pos_offset * stride_qs qe_ptr = Q_EMB + pos_offset * stride_qes @@ -124,8 +126,6 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, - position_ids: Tensor = None, - position_ids_1d: Tensor = None, q_embed: Tensor = None, k_embed: Tensor = None): """Apply rotary positional embedding on query and key. @@ -135,29 +135,23 @@ def apply_rotary_pos_emb(q: Tensor, k (Tensor): Key state. cos (Tensor): cosine matrix (seq_len, dim). sin (Tensor): sine matrix (seq_len, dim). - position_ids (Tensor): Position ids of q and k. - position_ids_1d (Tensor): 1d Position ids. q_embed (Tensor): output q, can be same as q k_embed (Tensor): output k, can be same as k Returns: Tuple[Tensor, Tensor]: Embedded query and key. """ - if cos.device != q.device or cos.dtype != q.dtype: - cos = cos.to(device=q.device, dtype=q.dtype) - if sin.device != q.device or sin.dtype != q.dtype: - sin = sin.to(device=q.device, dtype=q.dtype) - if position_ids_1d is None: - seq_length = position_ids[..., -1] + 1 - position_ids_1d = [ids[:l] for ids, l in zip(position_ids, seq_length)] - position_ids_1d = torch.cat(position_ids_1d) + if cos.device != q.device: + cos = cos.to(device=q.device) + if sin.device != q.device: + sin = sin.to(device=q.device) if q_embed is None: q_embed = torch.empty_like(q) if k_embed is None: k_embed = torch.empty_like(k) - seq_len = position_ids_1d.size(-1) + seq_len = cos.numel() // cos.size(-1) BLOCK = 32 half_size = q.size(-1) // 2 BLOCK_N = triton.next_power_of_2(half_size) @@ -172,7 +166,6 @@ def apply_rotary_pos_emb(q: Tensor, k, cos, sin, - position_ids_1d, q_embed, k_embed, seq_len=seq_len, diff --git a/lmdeploy/pytorch/layers/__init__.py b/lmdeploy/pytorch/layers/__init__.py new file mode 100644 index 000000000..229dcd613 --- /dev/null +++ b/lmdeploy/pytorch/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# attention module is modified from: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ +from .activation import SiluAndMul # noqa: F401 +from .attention import Attention # noqa: F401 +from .linear import build_linear, build_merged_linear # noqa: F401 +from .norm import RMSNorm # noqa: F401 +from .rotary_embedding import ApplyRotaryEmb # noqa: F401 +from .rotary_embedding import EmbeddingType # noqa: F401 +from .rotary_embedding import build_rotary_embedding # noqa: F401 diff --git a/lmdeploy/pytorch/layers/activation.py b/lmdeploy/pytorch/layers/activation.py new file mode 100644 index 000000000..9f7e67098 --- /dev/null +++ b/lmdeploy/pytorch/layers/activation.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from ..backends import LayerType, get_backend + + +class SiluAndMul(nn.Module): + + def __init__(self, inplace: bool = True): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.SiluAndMul) + self.impl = builder.build(inplace) + + def forward(self, x): + return self.impl.forward(x) diff --git a/lmdeploy/pytorch/layers/apply_rotary_emb.py b/lmdeploy/pytorch/layers/apply_rotary_emb.py new file mode 100644 index 000000000..2e600ff96 --- /dev/null +++ b/lmdeploy/pytorch/layers/apply_rotary_emb.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from ..backends import LayerType, get_backend + + +class ApplyRotaryEmb(nn.Module): + + def __init__(self): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.ApplyRotaryEmb) + self.impl = builder.build() + + def forward(self, query, key, cos, sin, inplace: bool = True): + return self.impl.forward(query, key, cos, sin, inplace) diff --git a/lmdeploy/pytorch/attention/layer.py b/lmdeploy/pytorch/layers/attention.py similarity index 80% rename from lmdeploy/pytorch/attention/layer.py rename to lmdeploy/pytorch/layers/attention.py index 439c18749..7f82fbfd2 100644 --- a/lmdeploy/pytorch/attention/layer.py +++ b/lmdeploy/pytorch/layers/attention.py @@ -2,7 +2,8 @@ import torch from torch import nn -from .backends.base import AttentionMetadata +from ..backends import LayerType, get_backend +from ..backends.attention import AttentionMetadata class Attention(nn.Module): @@ -18,12 +19,12 @@ def __init__( sliding_window: int = None, **kwargs, ): - from .selector import get_attn_backend super().__init__() - attn_backend = get_attn_backend() - impl_cls = attn_backend.get_impl_cls() + layer_backend = get_backend() + impl_builder = layer_backend.get_layer_impl_builder( + LayerType.Attention) - self.impl = impl_cls( + self.impl = impl_builder.build( num_heads, head_size, scale, diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py new file mode 100644 index 000000000..9bea95799 --- /dev/null +++ b/lmdeploy/pytorch/layers/linear.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import List + +import torch +import torch.distributed as dist +from torch import nn + +from lmdeploy.pytorch.models.q_modules import QLinear +from lmdeploy.utils import get_logger + +from ..backends import LayerType, get_backend + +logger = get_logger('lmdeploy') + +try: + from peft.tuners.lora import Linear as LoRALinear +except ImportError: + logger.debug('load peft.tuners.lora.Linear failed.') + + class LoRALinear: + pass + + +try: + from peft.tuners.lora.awq import AwqLoraLinear +except ImportError: + logger.debug('load peft.tuners.lora.awq.AwqLoraLinear failed.') + + class AwqLoraLinear: + pass + + +try: + from awq.modules.linear.gemm import WQLinear_GEMM +except ImportError: + logger.debug('load awq.modules.linear.gemm.WQLinearGEMM failed.') + + class WQLinear_GEMM: + pass + + +def _get_world_rank(): + """get current world size and rank.""" + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + +@dataclass +class AdapterInfo: + r: dict + lora_A: nn.ModuleDict + lora_B: nn.ModuleDict + scaling: dict + base_slice: slice + + @staticmethod + def from_lora_linear(mod: nn.Module, base_slice: slice = None): + return AdapterInfo( + r=mod.r, + lora_A=mod.lora_A, + lora_B=mod.lora_B, + scaling=mod.scaling, + base_slice=base_slice, + ) + + +class SLoRA(nn.Module): + + def __init__(self, adapter_infos: List[AdapterInfo]): + super().__init__() + + def forward(self, x): + raise NotImplementedError + + +class AwqLinear(nn.Module): + + def __init__(self, + mod: nn.Module, + adapter_infos: List[AdapterInfo] = None): + super().__init__() + + +class W8A8Linear(nn.Module): + + def __init__(self, + mod: nn.Module, + adapter_infos: List[AdapterInfo] = None): + super().__init__() + + +class BaseLinear(nn.Module): + + def __init__(self, + mod: nn.Module, + adapter_infos: List[AdapterInfo] = None, + all_reduce: bool = False): + super().__init__() + layer_backend = get_backend() + + if isinstance(mod, nn.Linear): + impl_builder = layer_backend.get_layer_impl_builder( + LayerType.Linear) + else: + raise NotImplementedError(f'Unsupported linear type: {type(mod)}') + self.impl = impl_builder.build(mod, all_reduce) + + adapter_infos = adapter_infos if adapter_infos is not None else [] + self.adapter = None + if len(adapter_infos) > 0: + self.adapter = SLoRA(adapter_infos) + + def forward(self, x): + out = self.impl.forward(x) + + if self.adapter is not None: + out = self.adapter(out) + return out + + +def _merge_base_linear(*linears: List[nn.Module]): + """merge naive linear.""" + weights = [mod.weight for mod in linears] + bias = [mod.bias for mod in linears] + + in_features = weights[0].size(1) + dtype = weights[0].dtype + device = weights[0].device + for w in weights: + assert w.size(1) == in_features + assert w.dtype == dtype + assert w.device == device + out_features = sum(w.size(0) for w in weights) + + new_weight = torch.cat(weights, dim=0) + new_bias = None + if bias[0] is not None: + assert all(b is not None for b in bias) + new_bias = torch.cat(bias) + has_bias = new_bias is not None + merged_linear = nn.Linear(in_features, + out_features, + bias=has_bias, + dtype=dtype, + device=device) + state_dict = dict(weight=new_weight) + if has_bias: + state_dict['bias'] = new_bias + merged_linear.load_state_dict(state_dict) + return merged_linear + + +def _merge_qlinear(*linears: List[nn.Module]): + """merge qlinear.""" + weights = [mod.weight for mod in linears] + scalings = [mod.scaling for mod in linears] + bias = [mod.bias for mod in linears] + + in_features = weights[0].size(1) + dtype = weights[0].dtype + device = weights[0].device + for w in weights: + assert w.size(1) == in_features + assert w.dtype == dtype + assert w.device == device + out_features = sum(w.size(0) for w in weights) + + new_weight = torch.cat(weights, dim=0) + new_scaling = torch.cat(scalings, dim=0) + new_bias = None + if bias[0] is not None: + assert all(b is not None for b in bias) + new_bias = torch.cat(bias) + has_bias = new_bias is not None + merged_linear = QLinear(in_features, + out_features, + bias=has_bias, + dtype=dtype, + device=device) + state_dict = dict( + weight=new_weight, + scaling=new_scaling, + ) + if has_bias: + state_dict['bias'] = new_bias + merged_linear.load_state_dict(state_dict) + return merged_linear + + +def _merge_awqlinear(*linears: List[nn.Module]): + """merge awqlinear.""" + qweights = [mod.qweight for mod in linears] + scales = [mod.scales for mod in linears] + qzeros = [mod.qzeros for mod in linears] + bias = [mod.bias for mod in linears] + w_bits = [mod.w_bit for mod in linears] + group_sizes = [mod.group_size for mod in linears] + + w_bit = w_bits[0] + group_size = group_sizes[0] + assert all(wb == w_bit for wb in w_bits) + assert all(gs == group_size for gs in group_sizes) + in_features = qweights[0].size(1) + device = qweights[0].device + for w in qweights: + assert w.size(1) == in_features + assert w.device == device + out_features = sum(w.size(0) for w in qweights) + + new_qweight = torch.cat(qweights, dim=0) + new_scales = torch.cat(scales, dim=0) + new_qzeros = torch.cat(qzeros, dim=0) + new_bias = None + if bias[0] is not None: + assert all(b is not None for b in bias) + new_bias = torch.cat(bias) + has_bias = new_bias is not None + merged_linear = WQLinear_GEMM( + w_bit, + group_size, + in_features, + out_features, + bias=has_bias, + dev=device, + ) + state_dict = dict( + qweight=new_qweight, + scales=new_scales, + qzeros=new_qzeros, + ) + if has_bias: + state_dict['bias'] = new_bias + merged_linear.load_state_dict(state_dict) + return merged_linear + + +def build_merged_linear(*linears: List[nn.Module], free_origin=False): + """merge linear.""" + base_layers = [] + out_features = [] + adapter_infos = [] + cum_out_feature = 0 + for mod in linears: + # get base layers + base_layer = getattr(mod, 'base_layer', mod) + base_layers.append(base_layer) + + # get out_feature + if hasattr(base_layer, 'weight'): + weight = base_layer.weight + out_feature = weight.size(0) + else: + scales = base_layer.scales + out_feature = scales.size(1) + slice_start = cum_out_feature + cum_out_feature += out_feature + + # get adapter info + adapter_info = None + if isinstance(mod, (LoRALinear, AwqLoraLinear)): + adapter_slice = slice(slice_start, cum_out_feature) + adapter_info = AdapterInfo.from_lora_linear(mod, adapter_slice) + out_features.append(out_feature) + if adapter_info is not None: + adapter_infos.append(adapter_info) + + # check base layer type + base_type = type(base_layers[0]) + assert all(isinstance(layer, base_type) for layer in base_layers) + + # merge base layer + if base_type == nn.Linear: + base_layer = _merge_base_linear(*base_layers) + elif base_type == WQLinear_GEMM: + base_layer = _merge_awqlinear(*base_layers) + elif base_type == QLinear: + base_layer = _merge_qlinear(*base_layers) + else: + raise NotImplementedError(f'Unknown linear type: {type(mod)}') + ret = build_linear(base_layer, adapter_infos) + if free_origin: + for mod in linears: + mod.to('meta') + return ret + + +def build_linear(mod: nn.Module, + adapter_infos: List[AdapterInfo] = None) -> nn.Module: + """build linear.""" + if isinstance(mod, nn.Linear): + return BaseLinear(mod, adapter_infos) + elif isinstance(mod, WQLinear_GEMM): + return AwqLinear(mod, adapter_infos) + elif isinstance(mod, QLinear): + return W8A8Linear(mod, adapter_infos) + elif isinstance(mod, LoRALinear): + base_layer = mod.base_layer + adapter_info = AdapterInfo.from_lora_linear(mod) + return build_linear(base_layer, [adapter_info]) + elif isinstance(mod, AwqLoraLinear): + base_layer = mod.base_layer + adapter_info = AdapterInfo.from_lora_linear(mod) + return build_linear(base_layer, [adapter_info]) + else: + raise NotImplementedError(f'Unknown linear type: {type(mod)}') diff --git a/lmdeploy/pytorch/layers/norm.py b/lmdeploy/pytorch/layers/norm.py new file mode 100644 index 000000000..c4ce81740 --- /dev/null +++ b/lmdeploy/pytorch/layers/norm.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..backends import LayerType, get_backend + + +class RMSNorm(nn.Module): + + def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.RMSNorm) + self.impl = builder.build(weight, eps) + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + return self.impl.forward(x, residual) diff --git a/lmdeploy/pytorch/layers/rotary_embedding.py b/lmdeploy/pytorch/layers/rotary_embedding.py new file mode 100644 index 000000000..1db92cde3 --- /dev/null +++ b/lmdeploy/pytorch/layers/rotary_embedding.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from ..backends import LayerType, get_backend +from ..backends.rotary_embedding import EmbeddingType + + +def build_rotary_embedding( + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: + backend = get_backend() + + builder = backend.get_layer_impl_builder(LayerType.RotaryEmbedding) + return builder.build(dim, max_position_embeddings, base, scaling_factor, + emb_type) + + +class ApplyRotaryEmb(nn.Module): + + def __init__(self): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.ApplyRotaryEmb) + self.impl = builder.build() + + def forward(self, query, key, cos, sin, inplace: bool = True): + return self.impl.forward(query, key, cos, sin, inplace) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index ed398cc95..210757722 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -3,48 +3,16 @@ import torch import torch.distributed as dist -import transformers -from packaging import version from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from ..kernels import apply_rotary_pos_emb as apply_rotary_pos_emb_old -from ..kernels import (fill_kv_cache, fused_rotary_emb, paged_attention_fwd, - rms_norm) +from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_merged_linear, + build_rotary_embedding) + from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) -TRANSFORMERS_VERSION = version.parse(transformers.__version__) -VERSION_4_38_0 = version.parse('4.38.0') - - -class LlamaRMSNorm(nn.Module): - """Rewrite RMSNorm.""" - - def forward(self, hidden_states): - """forward.""" - # torch.nn.functional.normalize based implementation might leads - # to wrong output - ret = rms_norm(hidden_states, self.weight, self.variance_epsilon) - - return ret - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors.""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" @@ -66,7 +34,32 @@ def _load_weights(self, loader, rank: int, world_size: int, def _update_model_fn(self): """update model.""" - from lmdeploy.pytorch.attention.layer import Attention + + # qkv + self.qkv_proj = build_merged_linear(self.q_proj, self.k_proj, + self.v_proj) + del self.q_proj, self.k_proj, self.v_proj + + # rotary embedding + old_emb = self.rotary_emb + rotary_name = type(old_emb).__name__ + if rotary_name in [ + 'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding' + ]: + emb_type = EmbeddingType.LinearScaling + elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': + emb_type = EmbeddingType.DynamicNTKScaling + self.rotary_emb = build_rotary_embedding( + old_emb.dim, + old_emb.max_position_embeddings, + old_emb.base, + old_emb.scaling_factor, + emb_type, + ) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention world_size = 1 if dist.is_initialized(): world_size = dist.get_world_size() @@ -98,7 +91,6 @@ def _contiguous_batching_forward_default_impl( Optional[Tuple[torch.Tensor]]]: """default rewrite.""" context = self.context.context - max_kv_seq_length = context.max_kv_seq_length num_heads = self.num_heads // world_size num_kv_heads = self.num_key_value_heads // world_size @@ -107,84 +99,37 @@ def _contiguous_batching_forward_default_impl( def __qkv_proj(hidden_states): """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn_old(query_states, key_states, value_states): - """rotary embedding old.""" - if max_kv_seq_length >= self.rotary_emb.max_seq_len_cached: - # create larger cache - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length + 128) - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached - query_states, key_states = apply_rotary_pos_emb_old( - query_states, - key_states, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=query_states, - k_embed=key_states, - context=context) - return query_states, key_states, value_states + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, + -2).unflatten(-1, (-1, head_dim)) + query_states, key_states, value_states = qkv_states.split( + (num_heads, num_kv_heads, num_kv_heads), dim=1) - def __rotary_emb_fn_438_naive(query_states, key_states, value_states): - """rotary embedding transformers>4.38.""" - cos, sin = self.rotary_emb(value_states, - context.position_ids_1d[None]) - cos = cos[0] - sin = sin[0] - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin) return query_states, key_states, value_states - def __rotary_emb_fn_438_fused(query_states, key_states, value_states): - scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) - inv_freq = self.rotary_emb.inv_freq - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - context.position_ids_1d[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None]) - return query_states[0], key_states[0], value_states - - def __rotary_emb_fn_438(query_states, key_states, value_states): - rotary_name = type(self.rotary_emb).__name__ - if rotary_name in [ - 'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding' - ]: - return __rotary_emb_fn_438_fused(query_states, key_states, - value_states) - else: - return __rotary_emb_fn_438_naive(query_states, key_states, - value_states) - - def __rotary_emb_fn(query_states, key_states, value_states): + def __rotary_emb_fn(query_states, key_states): """rotary embedding.""" - if TRANSFORMERS_VERSION >= VERSION_4_38_0: - return __rotary_emb_fn_438(query_states, key_states, - value_states) + if not hasattr(context, '_cos'): + cos, sin = self.rotary_emb(query_states, + context.position_ids_1d[None]) + cos = cos[0] + sin = sin[0] + context._cos = cos + context._sin = sin else: - return __rotary_emb_fn_old(query_states, key_states, - value_states) - - query_states, key_states, value_states = __qkv_proj(hidden_states) + cos = context._cos + sin = context._sin - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) + query_states, key_states = self.apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + inplace=True) + return query_states, key_states + query_states, key_states, value_states = __qkv_proj(hidden_states) + query_states, key_states = __rotary_emb_fn(query_states, key_states) attn_output = self.attn_fwd( query_states, key_states, @@ -226,225 +171,99 @@ def forward( ) -class LlamaAttentionAscend(nn.Module): - """Rewrite module of LlamaAttention.""" +class LlamaMLP(nn.Module): def _load_weights(self, loader, rank: int, world_size: int, device: torch.device): """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear( - getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name, - ) - rowwise_parallelize_linear(self.o_proj, + for mod_name in ['gate_proj', 'up_proj']: + colwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + rowwise_parallelize_linear(self.down_proj, loader, rank=rank, world_size=world_size, - prefix='o_proj') + prefix='down_proj') @classmethod def _distribute_output_fn(cls, outputs, **kwargs): """Distribution output hook.""" - dist.all_reduce(outputs[0]) + dist.all_reduce(outputs) return outputs - def _contiguous_batching_forward_default_impl( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - attention_mask: Optional[torch.Tensor] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """default rewrite.""" - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn_old(query_states, key_states, value_states): - """rotary embedding old.""" - if max_kv_seq_length >= self.rotary_emb.max_seq_len_cached: - # create larger cache - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length + 128) - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached - query_states, key_states = apply_rotary_pos_emb_old( - query_states, - key_states, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=query_states, - k_embed=key_states, - ) - return query_states, key_states, value_states - - def __rotary_emb_fn_438_naive(query_states, key_states, value_states): - """rotary embedding transformers>4.38.""" - cos, sin = self.rotary_emb(value_states, - context.position_ids_1d[None]) - cos = cos[0] - sin = sin[0] - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin) - return query_states, key_states, value_states - - def __rotary_emb_fn_438_fused(query_states, key_states, value_states): - scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) - inv_freq = self.rotary_emb.inv_freq - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - context.position_ids_1d[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None], - context=context, - ) - return query_states[0], key_states[0], value_states - - def __rotary_emb_fn_438(query_states, key_states, value_states): - rotary_name = type(self.rotary_emb).__name__ - if rotary_name in [ - 'LlamaRotaryEmbedding', - 'LlamaLinearScalingRotaryEmbedding', - ]: - return __rotary_emb_fn_438_fused(query_states, key_states, - value_states) - else: - return __rotary_emb_fn_438_naive(query_states, key_states, - value_states) - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding.""" - if TRANSFORMERS_VERSION >= VERSION_4_38_0: - return __rotary_emb_fn_438(query_states, key_states, - value_states) - else: - return __rotary_emb_fn_old(query_states, key_states, - value_states) - - query_states, key_states, value_states = __qkv_proj(hidden_states) + def _update_model_fn(self): + """update model.""" - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) + # gate up + self.gate_up_proj = build_merged_linear(self.gate_proj, self.up_proj) + del self.gate_proj, self.up_proj - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) + # silu and mul + self.act_fn = SiluAndMul(inplace=True) - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, - context=context, - ) + def forward(self, x): + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) - attn_output = query_states - paged_attention_fwd( - query_states, - key_states, - value_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - context=context, - ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) - attn_output = self.o_proj(attn_output) +class PatchedLlamaDecoderLayer(nn.Module): - return attn_output, None, past_key_value + def _update_model_fn(self): + """update model.""" + input_layernorm = self.input_layernorm + self.input_layernorm = RMSNorm(input_layernorm.weight, + input_layernorm.variance_epsilon) + post_attention_layernorm = self.post_attention_layernorm + self.post_attention_layernorm = RMSNorm( + post_attention_layernorm.weight, + post_attention_layernorm.variance_epsilon) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, + past_key_value: Optional[List[torch.FloatTensor]] = None, + residual: Optional[torch.Tensor] = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of LlamaAttention.forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_default_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, - world_size=world_size, + position_ids=position_ids, + past_key_value=past_key_value, ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) -class LlamaMLP(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['gate_proj', 'up_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(self.down_proj, - loader, - rank=rank, - world_size=world_size, - prefix='down_proj') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) + outputs = (hidden_states, residual) return outputs class LlamaModel(nn.Module): + def _update_model_fn(self): + """update model.""" + norm = self.norm + self.norm = RMSNorm(norm.weight, norm.variance_epsilon) + def _continuous_batching_forward( self, input_ids: torch.LongTensor = None, @@ -453,8 +272,6 @@ def _continuous_batching_forward( inputs_embeds: Optional[torch.FloatTensor] = None ) -> Union[Tuple, BaseModelOutputWithPast]: """Rewrite implementation of LlamaModel.forward.""" - output_attentions = False - use_cache = True context = self.context.context # get inputs from context vision_embeddings = context.input_embeddings @@ -471,21 +288,19 @@ def _continuous_batching_forward( attention_mask = None hidden_states = inputs_embeds - + residual = None for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + residual=residual, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 54d581369..13806a8d5 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -14,8 +14,8 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', 'transformers.models.llama.modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.llama.modeling_llama.LlamaRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'transformers.models.llama.modeling_llama.LlamaDecoderLayer': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.PatchedLlamaDecoderLayer', # support modeling rewritten in lmdeploy 'modeling_llama.LlamaAttention': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', From b123e4def4831664017d254a232f4bd59ed3647d Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 20 Jul 2024 18:18:50 +0800 Subject: [PATCH 04/89] finish llama base --- lmdeploy/pytorch/backends/default/linear.py | 6 +- lmdeploy/pytorch/backends/linear.py | 6 +- lmdeploy/pytorch/engine/engine.py | 4 +- lmdeploy/pytorch/engine/model_agent.py | 393 +----------------- lmdeploy/pytorch/layers/linear.py | 25 +- lmdeploy/pytorch/model_inputs.py | 376 +++++++++++++++++ lmdeploy/pytorch/models/llama.py | 348 +++++++--------- lmdeploy/pytorch/models/module_map.py | 7 +- lmdeploy/pytorch/models/patch.py | 140 ++----- .../weight_loader/model_weight_loader.py | 2 +- 10 files changed, 620 insertions(+), 687 deletions(-) create mode 100644 lmdeploy/pytorch/model_inputs.py diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index ddd6547c6..7f0917814 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -2,6 +2,8 @@ from torch import distributed as dist from torch import nn +from lmdeploy.pytorch.model_inputs import StepContextManager + from ..linear import LinearBuilder, LinearImpl @@ -22,5 +24,7 @@ def forward(self, x): class DefaultLinearBuilder(LinearBuilder): @staticmethod - def build(mod: nn.Module, all_reduce: bool = False): + def build(mod: nn.Module, + ctx_mgr: StepContextManager = None, + all_reduce: bool = False): return DefaultLinearImpl(mod, all_reduce) diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index 97843a0ec..af9c88689 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -3,6 +3,8 @@ from torch import nn +from lmdeploy.pytorch.model_inputs import StepContextManager + class LinearImpl(ABC, nn.Module): @@ -15,5 +17,7 @@ class LinearBuilder(ABC): @staticmethod @abstractmethod - def build(mod: nn.Module, all_reduce: bool = False): + def build(mod: nn.Module, + ctx_mgr: StepContextManager = None, + all_reduce: bool = False): raise NotImplementedError diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 8bda7b477..18f4325ad 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -17,10 +17,10 @@ from ..devices import DeviceContext, get_device_manager from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, MessageStatus, SchedulerSequence) +from ..model_inputs import AdapterInfo, ModelInputs, VisionModelInputs from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs -from .model_agent import (AdapterInfo, AutoModelAgent, ModelInputs, - VisionModelInputs) +from .model_agent import AutoModelAgent from .request import Request, RequestManager, RequestType, Response logger = get_logger('lmdeploy') diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 8e5d5f5fe..482d45d49 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -2,7 +2,6 @@ import asyncio import os import warnings -from dataclasses import dataclass, field, fields from datetime import timedelta from typing import Any, Callable, Dict, List @@ -11,14 +10,13 @@ from torch import multiprocessing as mp from lmdeploy.pytorch.accel import LoadNoInit -from lmdeploy.pytorch.backends import get_backend from lmdeploy.utils import get_logger -from ..adapter.adapter import (AdapterWeightMap, SchedulerAdapter, - get_indexed_lora_linears, get_loralinear_info, - update_lora_linears) +from ..adapter.adapter import (AdapterWeightMap, get_indexed_lora_linears, + get_loralinear_info, update_lora_linears) from ..config import CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager +from ..model_inputs import ModelInputs, StepContext from ..models.patch import patch, update_model from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights @@ -112,353 +110,6 @@ def __adjust_block_size(): logger.debug('block num: {}'.format(cache_config.num_gpu_blocks)) -@dataclass -class AdapterInfo: - ranks: torch.LongTensor - scalings: torch.Tensor - rank_offsets: torch.LongTensor - target_modules: List[str] - max_rank_per_target: List[int] - max_rank: int - - @classmethod - def from_adapters(cls, adapters: List[SchedulerAdapter]): - """from adapters.""" - if len(adapters) == 0: - return None - target_modules = adapters[0].target_modules - max_rank = adapters[0].max_rank - ranks = [ada.rank for ada in adapters] - scalings = [ada.scaling for ada in adapters] - rank_offsets = [torch.from_numpy(ada.rank_offset) for ada in adapters] - ranks = torch.tensor(ranks) - scalings = torch.tensor(scalings) - rank_offsets = torch.stack(rank_offsets) - max_rank_per_target = ranks.max(0)[0].tolist() - - return cls( - ranks=ranks, - scalings=scalings, - rank_offsets=rank_offsets, - target_modules=target_modules, - max_rank=max_rank, - max_rank_per_target=max_rank_per_target, - ) - - def split_by_targets(self): - """split by targets.""" - ret = dict() - max_rank = self.max_rank - for idx, target in enumerate(self.target_modules): - r = self.ranks[:, idx] - scaling = self.scalings[:, idx] - r_off_start = idx * max_rank - r_off_end = r_off_start + max_rank - rank_offset = self.rank_offsets[:, r_off_start:r_off_end] - max_rank_per_target = [self.max_rank_per_target[idx]] - ret[target] = AdapterInfo( - r, - scaling, - rank_offset, - target_modules=[target], - max_rank=max_rank_per_target[0], - max_rank_per_target=max_rank_per_target, - ) - return ret - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - out_dict[k] = v - - return AdapterInfo(**out_dict) - - -@dataclass -class VisionModelInputs: - """Vision model inputs.""" - history_lengths: torch.LongTensor = None - history_image_nums: torch.LongTensor = None - history_image_token_lengths: torch.LongTensor = None - input_embeddings: List[List[torch.Tensor]] = None - input_embedding_ranges: List[torch.LongTensor] = None - input_embedding_indexing: torch.BoolTensor = None - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - elif k == 'input_embedding_ranges' and v is not None: - v = [e.to(device) for e in v] - elif k == 'input_embeddings' and v is not None: - v = [[e.to(device) for e in li] for li in v] - out_dict[k] = v - - return VisionModelInputs(**out_dict) - - def get_inputs(self, history_lengths: torch.Tensor, - seq_lengths: torch.Tensor): - """get vision embedding inputs.""" - input_embeddings = None - input_embedding_indexing = None - if self.input_embeddings is not None and len( - self.input_embeddings) > 0: - input_embedding_li = [] - for (his_len, seq_len, embeddings, - emb_ranges) in zip(history_lengths, seq_lengths, - self.input_embeddings, - self.input_embedding_ranges): - for emb, (emb_start, emb_end) in zip(embeddings, emb_ranges): - start = max(emb_start, his_len) - emb_start - end = min(emb_end, his_len + seq_len) - emb_start - if 0 <= start < end: - input_embedding_li.append(emb[start:end]) - # has embeddings - if len(input_embedding_li) > 0: - input_embeddings = torch.cat(input_embedding_li, dim=0) - device = input_embeddings.device - starts = history_lengths - self.history_lengths - ends = starts + seq_lengths - input_embedding_indexing = torch.cat([ - indexing[s:e] for indexing, s, e in zip( - self.input_embedding_indexing, starts, ends) - ], - dim=0) - index_ranges = torch.arange(input_embedding_indexing.numel(), - device=device) - input_embedding_indexing = index_ranges[ - input_embedding_indexing] - return input_embeddings, input_embedding_indexing - - -@dataclass -class ModelInputs: - """Input of the model.""" - input_ids: torch.LongTensor - seq_length: torch.LongTensor - history_lengths: torch.LongTensor - block_offsets: torch.LongTensor - max_q_seq_length: int - max_history_length: int - is_decoding: bool - num_ignored_history: torch.LongTensor - local_adapter_ids: torch.LongTensor = None - adapter_info: AdapterInfo = None - meta: Any = None - vision_inputs: VisionModelInputs = None - - def update(self, input_ids: torch.LongTensor): - """update input ids.""" - assert self.is_decoding - self.history_lengths = self.history_lengths + 1 - self.max_history_length = self.max_history_length + 1 - if input_ids.dim() == 1: - input_ids = input_ids[None, :] - self.input_ids = input_ids - return self - - def split(self, split_size: int, block_size: int): - """split inputs.""" - assert len( - self.seq_length) == 1, ('Can not perform split on batched input.') - assert split_size % block_size == 0, ( - 'split_size should be multi of block_size.') - - input_ids = self.input_ids - if input_ids.numel() < split_size: - return self - - num_blocks = split_size // block_size - overlap = (self.history_lengths[0] % block_size != 0) - max_seq_len = self.seq_length[0].item() - ret = [] - block_start = 0 - for i in range(0, max_seq_len, split_size): - start = i - end = min(max_seq_len, i + split_size) - block_end = block_start + num_blocks - if overlap: - block_end += 1 - - block_offsets = self.block_offsets - inp = ModelInputs( - input_ids=self.input_ids[:, start:end], - seq_length=input_ids.new_tensor([end - start]), - block_offsets=block_offsets, - history_lengths=self.history_lengths + start, - max_q_seq_length=end - start, - max_history_length=self.max_history_length + start, - is_decoding=self.is_decoding, - num_ignored_history=self.num_ignored_history, - local_adapter_ids=self.local_adapter_ids, - adapter_info=self.adapter_info, - meta=self.meta, - vision_inputs=self.vision_inputs, - ) - ret.append(inp) - block_start += num_blocks - - return ret - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - elif isinstance(v, VisionModelInputs): - v = v.to_device(device) - elif isinstance(v, AdapterInfo): - v = v.to_device(device) - out_dict[k] = v - - return ModelInputs(**out_dict) - - -@dataclass -class StepContext: - """context of Model. - - patched model might need extra information to perform inference. This - dataclass provide these infos and tools. - """ - inputs: ModelInputs - block_offsets: torch.LongTensor - position_ids: torch.LongTensor - position_ids_1d: torch.LongTensor - q_start_loc: torch.LongTensor - attention_mask: torch.LongTensor - history_lengths: torch.LongTensor - q_seq_length: torch.LongTensor - kv_seq_length: torch.LongTensor - max_q_seq_length: int - max_kv_seq_length: int - kv_caches: List - is_decoding: bool - world_size: int = 1 - local_adapter_ids: torch.LongTensor = None - adapter_params: Dict[str, AdapterInfo] = None - input_embeddings: torch.Tensor = None - input_embedding_indexing: torch.Tensor = None - - _outputs: Dict = field(default_factory=dict) - - @classmethod - def new( - cls, - inputs: ModelInputs, - world_size: int = 1, - device: str = 'cuda', - kv_caches: List = None, - cache_config: CacheConfig = None, - ): - """build step context. - - Args: - inputs (ModelInputs): packaged model inputs. - world_size (int): The distribution world size. - device (str): The device of the tensors. - """ - q_seq_length = inputs.seq_length - max_q_seq_length = inputs.max_q_seq_length - history_lengths = inputs.history_lengths - - # for vlm - input_embeddings, input_embedding_indexing = None, None - if (inputs.vision_inputs is not None - and inputs.vision_inputs.input_embeddings is not None): - input_embeddings, input_embedding_indexing = \ - inputs.vision_inputs.get_inputs(history_lengths, q_seq_length) - - batch_size = len(q_seq_length) - device = q_seq_length.device - - # q_start_loc and kv_seq_length - if inputs.is_decoding: - q_start_loc = torch.arange(0, batch_size, device=device) - attention_mask = torch.ones_like(q_seq_length)[:, None] - position_ids = history_lengths.unsqueeze(-1) - else: - q_start_loc = q_seq_length.cumsum(0) - q_seq_length - mask_range = torch.arange(max_q_seq_length, device=device)[None, :] - attention_mask = (mask_range < q_seq_length[:, None]).long() - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids += history_lengths.unsqueeze(-1) - - # position ids 1d - position_ids_1d = cls.get_position_ids_1d(position_ids, q_seq_length, - device) - # seq_len + history_length - kv_seq_length = q_seq_length + history_lengths - max_kv_seq_length = max_q_seq_length + inputs.max_history_length - - window_size = getattr(cache_config, 'window_size', 0) - if window_size > 0: - kv_seq_length -= inputs.num_ignored_history - - adapter_params = None - if inputs.adapter_info is not None: - adapter_params = inputs.adapter_info.split_by_targets() - - ret = StepContext(inputs=inputs, - block_offsets=inputs.block_offsets, - position_ids=position_ids, - position_ids_1d=position_ids_1d, - input_embeddings=input_embeddings, - input_embedding_indexing=input_embedding_indexing, - attention_mask=attention_mask, - q_start_loc=q_start_loc, - history_lengths=inputs.history_lengths, - q_seq_length=inputs.seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - max_kv_seq_length=max_kv_seq_length, - kv_caches=kv_caches, - is_decoding=inputs.is_decoding, - world_size=world_size, - local_adapter_ids=inputs.local_adapter_ids, - adapter_params=adapter_params) - - ret = get_backend().update_step_context(ret) - return ret - - @classmethod - def get_position_ids_1d(cls, - position_ids: torch.LongTensor, - seq_length: torch.LongTensor, - device: str = 'cuda'): - """get 1d position_ids.""" - if position_ids.size(0) == 1 or position_ids.size(1) == 1: - position_ids_1d = position_ids.flatten() - else: - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), seq_length.cpu()) - ] - position_ids_1d = torch.cat(position_ids_1d).to(device) - return position_ids_1d - - def set_output(self, key, value): - """set output.""" - self._outputs[key] = value - - def get_output(self, key): - """get output.""" - if key in self._outputs: - return self._outputs[key] - return None - - def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict): """perform cache swapping.""" @@ -476,8 +127,9 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, event.wait() +@torch.inference_mode() def model_forward( - patched_model: torch.nn.Module, + model: torch.nn.Module, inputs: ModelInputs, cache_engine: CacheEngine, world_size: int = 1, @@ -485,7 +137,7 @@ def model_forward( ): """perform model forward.""" stream = stream or torch.cuda.current_stream() - with torch.inference_mode(), torch.cuda.stream(stream): + with torch.cuda.stream(stream): # forward inputs = inputs.to_device('cuda') context = StepContext.new( @@ -494,18 +146,15 @@ def model_forward( kv_caches=cache_engine.gpu_cache, cache_config=cache_engine.cache_config, ) - output = patched_model.patched_forward( - input_ids=inputs.input_ids, - position_ids=context.position_ids, - attention_mask=context.attention_mask, - past_key_values=cache_engine.gpu_cache, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - use_origin=False, - context=context, - ) - return dict(logits=output['logits'], custom_outputs=context._outputs) + ctx_mgr = model.ctx_mgr + with ctx_mgr.context(context): + output = model( + input_ids=inputs.input_ids, + position_ids=context.position_ids, + past_key_values=cache_engine.gpu_cache, + context=context, + ) + return dict(logits=output) def _load_adapters(hf_model: torch.nn.Module, @@ -685,8 +334,7 @@ def _build_model(self, if adapters: _load_adapters(hf_model, adapters) - patched_model = patch(hf_model, _PATCH_ARG_NAMES) - update_model(patched_model) + patched_model = update_model(hf_model) if adapters: _unparam_lora_weight(patched_model) @@ -852,12 +500,7 @@ def _broadcast_config(cache_config): model.eval() model.config.use_cache = True - patched_model = patch( - model, - extra_args=_PATCH_ARG_NAMES, - rank=rank, - world_size=world_size, - ) + patched_model = patch(model) load_model_weights(patched_model, model_path, adapters, @@ -866,7 +509,7 @@ def _broadcast_config(cache_config): device='cuda') if rank == 0: logger.debug('Updating model.') - update_model(patched_model) + patched_model = update_model(patched_model) _update_cache_config(model_config, cache_config, diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index 9bea95799..e3fc20b8d 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from dataclasses import dataclass -from typing import List +from typing import Any, List import torch import torch.distributed as dist @@ -101,6 +101,7 @@ class BaseLinear(nn.Module): def __init__(self, mod: nn.Module, adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, all_reduce: bool = False): super().__init__() layer_backend = get_backend() @@ -110,7 +111,7 @@ def __init__(self, LayerType.Linear) else: raise NotImplementedError(f'Unsupported linear type: {type(mod)}') - self.impl = impl_builder.build(mod, all_reduce) + self.impl = impl_builder.build(mod, ctx_mgr, all_reduce) adapter_infos = adapter_infos if adapter_infos is not None else [] self.adapter = None @@ -241,7 +242,10 @@ def _merge_awqlinear(*linears: List[nn.Module]): return merged_linear -def build_merged_linear(*linears: List[nn.Module], free_origin=False): +def build_merged_linear(*linears: List[nn.Module], + ctx_mgr: Any = None, + all_reduce: bool = False, + free_origin=False): """merge linear.""" base_layers = [] out_features = [] @@ -284,7 +288,10 @@ def build_merged_linear(*linears: List[nn.Module], free_origin=False): base_layer = _merge_qlinear(*base_layers) else: raise NotImplementedError(f'Unknown linear type: {type(mod)}') - ret = build_linear(base_layer, adapter_infos) + ret = build_linear(base_layer, + adapter_infos, + ctx_mgr=ctx_mgr, + all_reduce=all_reduce) if free_origin: for mod in linears: mod.to('meta') @@ -292,10 +299,16 @@ def build_merged_linear(*linears: List[nn.Module], free_origin=False): def build_linear(mod: nn.Module, - adapter_infos: List[AdapterInfo] = None) -> nn.Module: + adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, + all_reduce: bool = False) -> nn.Module: """build linear.""" + if all_reduce: + world_size, rank = _get_world_rank() + all_reduce = world_size > 1 + if isinstance(mod, nn.Linear): - return BaseLinear(mod, adapter_infos) + return BaseLinear(mod, adapter_infos, ctx_mgr, all_reduce) elif isinstance(mod, WQLinear_GEMM): return AwqLinear(mod, adapter_infos) elif isinstance(mod, QLinear): diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py new file mode 100644 index 000000000..eb16993f9 --- /dev/null +++ b/lmdeploy/pytorch/model_inputs.py @@ -0,0 +1,376 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from contextlib import contextmanager +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List + +import torch + +from lmdeploy.pytorch.backends import get_backend + +from .adapter.adapter import SchedulerAdapter +from .config import CacheConfig + + +@dataclass +class AdapterInfo: + ranks: torch.LongTensor + scalings: torch.Tensor + rank_offsets: torch.LongTensor + target_modules: List[str] + max_rank_per_target: List[int] + max_rank: int + + @classmethod + def from_adapters(cls, adapters: List[SchedulerAdapter]): + """from adapters.""" + if len(adapters) == 0: + return None + target_modules = adapters[0].target_modules + max_rank = adapters[0].max_rank + ranks = [ada.rank for ada in adapters] + scalings = [ada.scaling for ada in adapters] + rank_offsets = [torch.from_numpy(ada.rank_offset) for ada in adapters] + ranks = torch.tensor(ranks) + scalings = torch.tensor(scalings) + rank_offsets = torch.stack(rank_offsets) + max_rank_per_target = ranks.max(0)[0].tolist() + + return cls( + ranks=ranks, + scalings=scalings, + rank_offsets=rank_offsets, + target_modules=target_modules, + max_rank=max_rank, + max_rank_per_target=max_rank_per_target, + ) + + def split_by_targets(self): + """split by targets.""" + ret = dict() + max_rank = self.max_rank + for idx, target in enumerate(self.target_modules): + r = self.ranks[:, idx] + scaling = self.scalings[:, idx] + r_off_start = idx * max_rank + r_off_end = r_off_start + max_rank + rank_offset = self.rank_offsets[:, r_off_start:r_off_end] + max_rank_per_target = [self.max_rank_per_target[idx]] + ret[target] = AdapterInfo( + r, + scaling, + rank_offset, + target_modules=[target], + max_rank=max_rank_per_target[0], + max_rank_per_target=max_rank_per_target, + ) + return ret + + def to_device(self, device: str): + """to device.""" + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = v.to(device) + out_dict[k] = v + + return AdapterInfo(**out_dict) + + +@dataclass +class VisionModelInputs: + """Vision model inputs.""" + history_lengths: torch.LongTensor = None + history_image_nums: torch.LongTensor = None + history_image_token_lengths: torch.LongTensor = None + input_embeddings: List[List[torch.Tensor]] = None + input_embedding_ranges: List[torch.LongTensor] = None + input_embedding_indexing: torch.BoolTensor = None + + def to_device(self, device: str): + """to device.""" + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = v.to(device) + elif k == 'input_embedding_ranges' and v is not None: + v = [e.to(device) for e in v] + elif k == 'input_embeddings' and v is not None: + v = [[e.to(device) for e in li] for li in v] + out_dict[k] = v + + return VisionModelInputs(**out_dict) + + def get_inputs(self, history_lengths: torch.Tensor, + seq_lengths: torch.Tensor): + """get vision embedding inputs.""" + input_embeddings = None + input_embedding_indexing = None + if self.input_embeddings is not None and len( + self.input_embeddings) > 0: + input_embedding_li = [] + for (his_len, seq_len, embeddings, + emb_ranges) in zip(history_lengths, seq_lengths, + self.input_embeddings, + self.input_embedding_ranges): + for emb, (emb_start, emb_end) in zip(embeddings, emb_ranges): + start = max(emb_start, his_len) - emb_start + end = min(emb_end, his_len + seq_len) - emb_start + if 0 <= start < end: + input_embedding_li.append(emb[start:end]) + # has embeddings + if len(input_embedding_li) > 0: + input_embeddings = torch.cat(input_embedding_li, dim=0) + device = input_embeddings.device + starts = history_lengths - self.history_lengths + ends = starts + seq_lengths + input_embedding_indexing = torch.cat([ + indexing[s:e] for indexing, s, e in zip( + self.input_embedding_indexing, starts, ends) + ], + dim=0) + index_ranges = torch.arange(input_embedding_indexing.numel(), + device=device) + input_embedding_indexing = index_ranges[ + input_embedding_indexing] + return input_embeddings, input_embedding_indexing + + +@dataclass +class ModelInputs: + """Input of the model.""" + input_ids: torch.LongTensor + seq_length: torch.LongTensor + history_lengths: torch.LongTensor + block_offsets: torch.LongTensor + max_q_seq_length: int + max_history_length: int + is_decoding: bool + num_ignored_history: torch.LongTensor + local_adapter_ids: torch.LongTensor = None + adapter_info: AdapterInfo = None + meta: Any = None + vision_inputs: VisionModelInputs = None + + def update(self, input_ids: torch.LongTensor): + """update input ids.""" + assert self.is_decoding + self.history_lengths = self.history_lengths + 1 + self.max_history_length = self.max_history_length + 1 + if input_ids.dim() == 1: + input_ids = input_ids[None, :] + self.input_ids = input_ids + return self + + def split(self, split_size: int, block_size: int): + """split inputs.""" + assert len( + self.seq_length) == 1, ('Can not perform split on batched input.') + assert split_size % block_size == 0, ( + 'split_size should be multi of block_size.') + + input_ids = self.input_ids + if input_ids.numel() < split_size: + return self + + num_blocks = split_size // block_size + overlap = (self.history_lengths[0] % block_size != 0) + max_seq_len = self.seq_length[0].item() + ret = [] + block_start = 0 + for i in range(0, max_seq_len, split_size): + start = i + end = min(max_seq_len, i + split_size) + block_end = block_start + num_blocks + if overlap: + block_end += 1 + + block_offsets = self.block_offsets + inp = ModelInputs( + input_ids=self.input_ids[:, start:end], + seq_length=input_ids.new_tensor([end - start]), + block_offsets=block_offsets, + history_lengths=self.history_lengths + start, + max_q_seq_length=end - start, + max_history_length=self.max_history_length + start, + is_decoding=self.is_decoding, + num_ignored_history=self.num_ignored_history, + local_adapter_ids=self.local_adapter_ids, + adapter_info=self.adapter_info, + meta=self.meta, + vision_inputs=self.vision_inputs, + ) + ret.append(inp) + block_start += num_blocks + + return ret + + def to_device(self, device: str): + """to device.""" + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = v.to(device) + elif isinstance(v, VisionModelInputs): + v = v.to_device(device) + elif isinstance(v, AdapterInfo): + v = v.to_device(device) + out_dict[k] = v + + return ModelInputs(**out_dict) + + +@dataclass +class StepContext: + """context of Model. + + patched model might need extra information to perform inference. This + dataclass provide these infos and tools. + """ + inputs: ModelInputs + block_offsets: torch.LongTensor + position_ids: torch.LongTensor + q_start_loc: torch.LongTensor + attention_mask: torch.LongTensor + history_lengths: torch.LongTensor + q_seq_length: torch.LongTensor + kv_seq_length: torch.LongTensor + max_q_seq_length: int + max_kv_seq_length: int + kv_caches: List + is_decoding: bool + world_size: int = 1 + local_adapter_ids: torch.LongTensor = None + adapter_params: Dict[str, AdapterInfo] = None + input_embeddings: torch.Tensor = None + input_embedding_indexing: torch.Tensor = None + + _outputs: Dict = field(default_factory=dict) + + @classmethod + def new( + cls, + inputs: ModelInputs, + world_size: int = 1, + kv_caches: List = None, + cache_config: CacheConfig = None, + ): + """build step context. + + Args: + inputs (ModelInputs): packaged model inputs. + world_size (int): The distribution world size. + device (str): The device of the tensors. + """ + q_seq_length = inputs.seq_length + max_q_seq_length = inputs.max_q_seq_length + history_lengths = inputs.history_lengths + + # for vlm + input_embeddings, input_embedding_indexing = None, None + if (inputs.vision_inputs is not None + and inputs.vision_inputs.input_embeddings is not None): + input_embeddings, input_embedding_indexing = \ + inputs.vision_inputs.get_inputs(history_lengths, q_seq_length) + + batch_size = len(q_seq_length) + device = q_seq_length.device + + # q_start_loc and kv_seq_length + if inputs.is_decoding: + q_start_loc = torch.arange(0, batch_size, device=device) + attention_mask = torch.ones_like(q_seq_length)[:, None] + position_ids = history_lengths.unsqueeze(-1) + else: + q_start_loc = q_seq_length.cumsum(0) - q_seq_length + mask_range = torch.arange(max_q_seq_length, device=device)[None, :] + attention_mask = (mask_range < q_seq_length[:, None]).long() + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids += history_lengths.unsqueeze(-1) + + # position ids 1d + position_ids = cls.get_position_ids_1d(position_ids, + q_seq_length)[None] + # seq_len + history_length + kv_seq_length = q_seq_length + history_lengths + max_kv_seq_length = max_q_seq_length + inputs.max_history_length + + window_size = getattr(cache_config, 'window_size', 0) + if window_size > 0: + kv_seq_length -= inputs.num_ignored_history + + adapter_params = None + if inputs.adapter_info is not None: + adapter_params = inputs.adapter_info.split_by_targets() + + ret = StepContext(inputs=inputs, + block_offsets=inputs.block_offsets, + position_ids=position_ids, + input_embeddings=input_embeddings, + input_embedding_indexing=input_embedding_indexing, + attention_mask=attention_mask, + q_start_loc=q_start_loc, + history_lengths=inputs.history_lengths, + q_seq_length=inputs.seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + max_kv_seq_length=max_kv_seq_length, + kv_caches=kv_caches, + is_decoding=inputs.is_decoding, + world_size=world_size, + local_adapter_ids=inputs.local_adapter_ids, + adapter_params=adapter_params) + + ret = get_backend().update_step_context(ret) + return ret + + @classmethod + def get_position_ids_1d(cls, position_ids: torch.LongTensor, + seq_length: torch.LongTensor): + """get 1d position_ids.""" + if position_ids.size(0) == 1 or position_ids.size(1) == 1: + position_ids_1d = position_ids.flatten() + else: + device = position_ids.device + position_ids_1d = [ + ids[:l] for ids, l in zip(position_ids.cpu(), seq_length.cpu()) + ] + position_ids_1d = torch.cat(position_ids_1d).to(device) + return position_ids_1d + + +class StepContextManager: + + def __init__(self): + self._current_ctx = None + + @staticmethod + def build_context( + inputs: ModelInputs, + world_size: int = 1, + kv_caches: List = None, + cache_config: CacheConfig = None, + ): + """build context.""" + return StepContext.new( + inputs, + world_size, + kv_caches, + cache_config, + ) + + @contextmanager + def context(self, ctx: StepContext): + """context context.""" + self._current_ctx = ctx + yield ctx + self._current_ctx = None + + def current_context(self): + """get current_context.""" + return self._current_ctx diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 210757722..8ed675816 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -7,8 +7,10 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_merged_linear, + RMSNorm, SiluAndMul, build_linear, + build_merged_linear, build_rotary_embedding) +from lmdeploy.pytorch.model_inputs import StepContextManager from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -17,119 +19,82 @@ class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" - def _load_weights(self, loader, rank: int, world_size: int, + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + self.ctx_mgr = ctx_mgr + self.num_heads = origin.num_heads // world_size + self.num_kv_heads = origin.num_key_value_heads // world_size + self.head_dim = origin.head_dim + + # qkv + self.qkv_proj = build_merged_linear(origin.q_proj, + origin.k_proj, + origin.v_proj, + ctx_mgr=ctx_mgr, + free_origin=True) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + self.num_heads, + self.head_dim, + num_kv_heads=self.num_kv_heads, + v_head_size=self.head_dim, + ) + + self.o_proj = build_linear(origin.o_proj, + ctx_mgr=ctx_mgr, + all_reduce=world_size > 1) + + @staticmethod + def _load_weights(mod, loader, rank: int, world_size: int, device: torch.device): """load weights.""" for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), + colwise_parallelize_linear(getattr(mod, mod_name), loader, rank=rank, world_size=world_size, prefix=mod_name) - rowwise_parallelize_linear(self.o_proj, + rowwise_parallelize_linear(mod.o_proj, loader, rank=rank, world_size=world_size, prefix='o_proj') - def _update_model_fn(self): - """update model.""" - - # qkv - self.qkv_proj = build_merged_linear(self.q_proj, self.k_proj, - self.v_proj) - del self.q_proj, self.k_proj, self.v_proj - - # rotary embedding - old_emb = self.rotary_emb - rotary_name = type(old_emb).__name__ - if rotary_name in [ - 'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding' - ]: - emb_type = EmbeddingType.LinearScaling - elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': - emb_type = EmbeddingType.DynamicNTKScaling - self.rotary_emb = build_rotary_embedding( - old_emb.dim, - old_emb.max_position_embeddings, - old_emb.base, - old_emb.scaling_factor, - emb_type, - ) - - self.apply_rotary_pos_emb = ApplyRotaryEmb() - - # attention - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_size = self.head_dim - self.attn_fwd = Attention( - num_heads, - head_size, - num_kv_heads=num_kv_heads, - v_head_size=head_size, - ) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_default_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - attention_mask: Optional[torch.Tensor] = None, - world_size: int = 1, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """default rewrite.""" - context = self.context.context - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - qkv_states = self.qkv_proj(hidden_states) - # (-1, heads, head_dim) - qkv_states = qkv_states.flatten(0, - -2).unflatten(-1, (-1, head_dim)) - query_states, key_states, value_states = qkv_states.split( - (num_heads, num_kv_heads, num_kv_heads), dim=1) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states): - """rotary embedding.""" - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb(query_states, - context.position_ids_1d[None]) - cos = cos[0] - sin = sin[0] - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - - query_states, key_states = self.apply_rotary_pos_emb(query_states, - key_states, - cos, - sin, - inplace=True) - return query_states, key_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - query_states, key_states = __rotary_emb_fn(query_states, key_states) + """Rewrite of LlamaAttention.forward.""" + context = self.ctx_mgr.current_context() + + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + qkv_states = qkv_states.unflatten(-1, (-1, self.head_dim)) + query_states, key_states, value_states = qkv_states.split(( + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + ), + dim=1) + + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) attn_output = self.attn_fwd( query_states, key_states, @@ -139,85 +104,69 @@ def __rotary_emb_fn(query_states, key_states): context.attn_meta, inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of LlamaAttention.forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_default_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - attention_mask=attention_mask, - world_size=world_size, + +class LlamaMLP(nn.Module): + + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + # gate up + self.gate_up_proj = build_merged_linear( + origin.gate_proj, + origin.up_proj, + ctx_mgr=ctx_mgr, + free_origin=True, ) + # silu and mul + self.act_fn = SiluAndMul(inplace=True) -class LlamaMLP(nn.Module): + # down + self.down_proj = build_linear(origin.down_proj, + ctx_mgr=ctx_mgr, + all_reduce=True) - def _load_weights(self, loader, rank: int, world_size: int, + @staticmethod + def _load_weights(mod: nn.Module, loader, rank: int, world_size: int, device: torch.device): """load weights.""" for mod_name in ['gate_proj', 'up_proj']: - colwise_parallelize_linear(getattr(self, mod_name), + colwise_parallelize_linear(getattr(mod, mod_name), loader, rank=rank, world_size=world_size, prefix=mod_name) - rowwise_parallelize_linear(self.down_proj, + rowwise_parallelize_linear(mod.down_proj, loader, rank=rank, world_size=world_size, prefix='down_proj') - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs - - def _update_model_fn(self): - """update model.""" - - # gate up - self.gate_up_proj = build_merged_linear(self.gate_proj, self.up_proj) - del self.gate_proj, self.up_proj - - # silu and mul - self.act_fn = SiluAndMul(inplace=True) - def forward(self, x): gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.down_proj(act) -class PatchedLlamaDecoderLayer(nn.Module): +class LlamaDecoderLayer(nn.Module): + + def __init__(self, origin: nn.Module, layer_idx: int, + ctx_mgr: StepContextManager): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = LlamaAttention(origin.self_attn, ctx_mgr) + self.mlp = LlamaMLP(origin.mlp, ctx_mgr) - def _update_model_fn(self): - """update model.""" - input_layernorm = self.input_layernorm + # norm + input_layernorm = origin.input_layernorm self.input_layernorm = RMSNorm(input_layernorm.weight, input_layernorm.variance_epsilon) - post_attention_layernorm = self.post_attention_layernorm + post_attention_layernorm = origin.post_attention_layernorm self.post_attention_layernorm = RMSNorm( post_attention_layernorm.weight, post_attention_layernorm.variance_epsilon) @@ -225,13 +174,10 @@ def _update_model_fn(self): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[List[torch.FloatTensor]] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if residual is None: residual = hidden_states @@ -241,10 +187,9 @@ def forward( hidden_states, residual) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, ) @@ -259,20 +204,42 @@ def forward( class LlamaModel(nn.Module): - def _update_model_fn(self): - """update model.""" - norm = self.norm + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + self.ctx_mgr = ctx_mgr + self.embed_tokens = origin.embed_tokens + self.layers = nn.ModuleList([ + LlamaDecoderLayer(layer, idx, ctx_mgr) + for idx, layer in enumerate(origin.layers) + ]) + norm = origin.norm self.norm = RMSNorm(norm.weight, norm.variance_epsilon) - def _continuous_batching_forward( + rotary_emb = origin.layers[0].self_attn.rotary_emb + rotary_name = type(rotary_emb).__name__ + if rotary_name in [ + 'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding' + ]: + emb_type = EmbeddingType.LinearScaling + elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': + emb_type = EmbeddingType.DynamicNTKScaling + self.rotary_emb = build_rotary_embedding( + rotary_emb.dim, + rotary_emb.max_position_embeddings, + rotary_emb.base, + rotary_emb.scaling_factor, + emb_type, + ) + + def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None + inputs_embeds: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - context = self.context.context + """Rewrite of LlamaModel.forward.""" + context = self.ctx_mgr.current_context() # get inputs from context vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -284,48 +251,51 @@ def _continuous_batching_forward( inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) - # Attention mask is not necessary in continuous batching - attention_mask = None hidden_states = inputs_embeds residual = None + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] hidden_states, residual = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, residual=residual, ) hidden_states, _ = self.norm(hidden_states, residual) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - ) + return hidden_states + + +class LlamaForCausalLM(nn.Module): + + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + self.ctx_mgr = ctx_mgr + self.model = LlamaModel(origin.model, ctx_mgr) + self.lm_head = build_linear(origin.lm_head) def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, - position_ids, - past_key_values, - inputs_embeds, + ): + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + return logits diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 13806a8d5..da09d589a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -15,13 +15,16 @@ 'transformers.models.llama.modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', 'transformers.models.llama.modeling_llama.LlamaDecoderLayer': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.PatchedLlamaDecoderLayer', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaDecoderLayer', # support modeling rewritten in lmdeploy 'modeling_llama.LlamaAttention': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', 'modeling_llama.LlamaModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', + 'modeling_llama.LlamaMLP': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', + 'modeling_llama.LlamaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', } # Falcon Models in transformer / on hub diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index b0e6bf411..79c22c6eb 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -3,10 +3,9 @@ import inspect import re from copy import copy -from typing import Any, Dict, Sequence +from typing import Any, Dict import torch -from addict import Addict from lmdeploy.utils import get_logger @@ -112,115 +111,46 @@ def get_origin_mod(self): return model -def _patch(model: torch.nn.Module, - context: Addict, - module_map: Dict[str, str] = None) -> torch.nn.Module: +def _patch(model: torch.nn.Module, module_map: Dict[str, + str]) -> torch.nn.Module: """patch the model with rewrite module. Args: model (Module): model to be patched. - context (Addict): The environment info to patched in model Returns: Module: The patched model """ - if module_map is None: - module_map = MODULE_MAP - - def _recursive_children(context, named_children): + def _recursive_children(named_children): """recursive children.""" for name, child in named_children: - patched_child = _patch(child, context, module_map=module_map) - if patched_child != child: - model.register_module(name, patched_child) + _patch(child, module_map=module_map) - _recursive_children(context, model.named_children()) + _recursive_children(model.named_children()) rewrite_qualname = _find_rewrite_module_qualname(model, module_map=module_map) if rewrite_qualname is not None: cls_type = _class_from_qualname(rewrite_qualname) - model = _update_module_type(model, cls_type, dict(context=context)) + if hasattr(cls_type, '_load_weights'): + setattr(model, '_load_weights', cls_type._load_weights) return model -def _update_model(model: torch.nn.Module): - """Update model after patch and load. - - Args: - model (Module): The model to be updated. - """ - # recursive over children - for _, child in model.named_children(): - _update_model(child) - - if hasattr(model, '_update_model_fn'): - model._update_model_fn() - - -def update_model(model: torch.nn.Module): - """update model.""" - return _update_model(model) - - -def _dist_model(model: torch.nn.Module, rank: int = 0): - """distribute model parameters.""" - - def _register_hooks(): - """register hooks.""" - if hasattr(model, '_distribute_input_fn'): - input_fn = model._distribute_input_fn - model.register_forward_pre_hook( - lambda _, inputs, inputs_dict: input_fn(inputs, inputs_dict), - with_kwargs=True, - ) - - if hasattr(model, '_distribute_output_fn'): - output_fn = model._distribute_output_fn - model.register_forward_hook( - lambda mod, inputs, outputs: output_fn(outputs)) - - for name, child in model.named_children(): - if rank == 0: - logger.debug(f'Distribute module: <{name}>') - new_child = _dist_model(child, rank) - if new_child != child: - model.register_module(name, child) - - _register_hooks() - - return model - - -class PatchedForward: - """patched forward.""" - - def __init__(self, model, context, extra_args): - self._model = model - self._patch_context: Dict = context - self._extra_args: list = extra_args - - def __call__(self, *args, **kwargs): - for arg_name in self._extra_args: - extra_arg = kwargs.pop(arg_name, None) - self._patch_context[arg_name] = extra_arg - - output = self._model(*args, **kwargs) - - self._patch_context.clear() - - return output +def _get_module_map(): + """get module map.""" + module_map = MODULE_MAP.copy() + device_type = get_device_manager().current_context().device_type + if device_type != 'cuda': + device_map = DEVICE_SPECIAL_MODULE_MAP.get(device_type, dict()) + module_map.update(device_map) + return module_map @torch.inference_mode() -def patch( - model: torch.nn.Module, - extra_args: Sequence[str] = None, - rank: int = 0, - world_size: int = 1, -): +def patch(model: torch.nn.Module, ): """Patch the model with rewrite modules. Extra arguments will be patched in forward of model, weights on each rank @@ -228,35 +158,25 @@ def patch( Args: model (Module): Model to be patched. - extra_args (Sequence[str]): Extra arguments of model forward. - rank (int): Distribution rank. - world_size (int): Distribution world size. Returns: Module: The patched model. """ - if rank == 0: - logger.info('Patching model.') - - if extra_args is None: - extra_args = [] - - _patch_context = Addict() + module_map = _get_module_map() + model = _patch(model, module_map=module_map) + return model - module_map = MODULE_MAP.copy() - device_type = get_device_manager().current_context().device_type - if device_type != 'cuda': - device_map = DEVICE_SPECIAL_MODULE_MAP.get(device_type, dict()) - module_map.update(device_map) - model = _patch(model, _patch_context, module_map=module_map) +def update_model(model: torch.nn.Module): + """build model.""" + from lmdeploy.pytorch.model_inputs import StepContextManager + ctx_mgr = StepContextManager() + module_map = _get_module_map() - if world_size > 1: - model = _dist_model(model, rank) + rewrite_qualname = _find_rewrite_module_qualname(model, + module_map=module_map) - patched_forward = PatchedForward(model, - _patch_context, - extra_args=extra_args) - model.patched_forward = patched_forward + if rewrite_qualname is not None: + model_cls = _class_from_qualname(rewrite_qualname) - return model + return model_cls(model, ctx_mgr) diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index 9ecc726ca..9ddbb1169 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -192,7 +192,7 @@ def __load_no_recursive(mod: torch.nn.Module): mod.register_buffer(name, param) if not load_only and hasattr(model, '_load_weights'): - model._load_weights(loader, rank, world_size, device=device) + model._load_weights(model, loader, rank, world_size, device=device) else: __load_no_recursive(model) for name, child in model.named_children(): From 5a09d9f61311d263dc133bb6f23f886726f95f6d Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 22 Jul 2024 20:57:31 +0800 Subject: [PATCH 05/89] add lora and w8a8 --- lmdeploy/pytorch/adapter/adapter.py | 33 ++- lmdeploy/pytorch/backends/base.py | 4 + .../pytorch/backends/cuda/layer_backend.py | 12 + .../backends/cuda/multinomial_sampling.py | 10 + lmdeploy/pytorch/backends/cuda/qmodules.py | 73 ++++++ lmdeploy/pytorch/backends/cuda/slora.py | 234 ++++++++++++++++++ .../pytorch/backends/default/layer_backend.py | 3 + lmdeploy/pytorch/backends/default/linear.py | 13 +- .../backends/default/multinomial_sampling.py | 24 ++ lmdeploy/pytorch/backends/linear.py | 6 +- .../pytorch/backends/multinomial_sampling.py | 22 ++ lmdeploy/pytorch/backends/qmodules.py | 37 +++ lmdeploy/pytorch/backends/slora.py | 55 ++++ lmdeploy/pytorch/engine/logits_process.py | 5 +- lmdeploy/pytorch/layers/__init__.py | 1 - lmdeploy/pytorch/layers/linear.py | 195 +++++++++------ .../pytorch/layers/multinomial_sampling.py | 13 + lmdeploy/pytorch/layers/norm.py | 10 +- lmdeploy/pytorch/models/llama.py | 81 +++--- lmdeploy/pytorch/models/module_map.py | 6 - lmdeploy/pytorch/models/q_modules.py | 3 + 21 files changed, 691 insertions(+), 149 deletions(-) create mode 100644 lmdeploy/pytorch/backends/cuda/multinomial_sampling.py create mode 100644 lmdeploy/pytorch/backends/cuda/qmodules.py create mode 100644 lmdeploy/pytorch/backends/cuda/slora.py create mode 100644 lmdeploy/pytorch/backends/default/multinomial_sampling.py create mode 100644 lmdeploy/pytorch/backends/multinomial_sampling.py create mode 100644 lmdeploy/pytorch/backends/qmodules.py create mode 100644 lmdeploy/pytorch/backends/slora.py create mode 100644 lmdeploy/pytorch/layers/multinomial_sampling.py diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index e50071383..97749be9a 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -32,11 +32,12 @@ def _cache_weight(cache: Tensor, weight: Tensor, rank_offset: Tensor): def _get_named_loralinears(model: torch.nn.Module): """get all named loralinear.""" - from peft.tuners.lora import Linear as LoRALinear named_loralinear: Dict[str, torch.nn.Module] = dict() for name, module in model.named_modules(): - if isinstance(module, LoRALinear): - named_loralinear[name] = module + if getattr(module, 'lora_adapters', None) is None: + continue + for idx, adapter in enumerate(module.lora_adapters): + named_loralinear[f'{name}_{idx}'] = adapter return named_loralinear @@ -83,10 +84,8 @@ def __update_linear(linear, idx, target_name, adapter_names): """update linear.""" linear.layer_idx = idx linear.target_name = target_name - for name in adapter_names: - if name in linear.lora_A: - linear.lora_A.pop(name) - linear.lora_B.pop(name) + linear.adapter_info.lora_A = None + linear.adapter_info.lora_B = None adapter_names = [weight_map.adapter_name for weight_map in weight_maps] @@ -115,13 +114,11 @@ def __post_init__(self): @classmethod def from_loralinear(cls, linear: torch.nn.Module): """create from lora linear.""" - from peft.tuners.lora import Linear as LoRALinear - assert isinstance(linear, LoRALinear) - - ranks = linear.r - scalings = linear.scaling - base_weight = linear.base_layer.weight - out_features, in_features = base_weight.shape + adapter_info = linear.adapter_info + ranks = adapter_info.r + scalings = adapter_info.scaling + out_features = adapter_info.out_features + in_features = adapter_info.in_features return cls( ranks=ranks, scalings=scalings, @@ -209,10 +206,12 @@ def cache_lora_linear(self, lora_linear: Dict[str, torch.nn.Module], rank_offset = self.rank_offset.reshape(-1, self.max_rank) for tidx, target in enumerate(target_modules): linear = lora_linear[target] - if not (name in linear.lora_A and name in linear.lora_B): + adapter_info = linear.adapter_info + if not (name in adapter_info.lora_A + and name in adapter_info.lora_B): continue - linear_a = linear.lora_A[name] - linear_b = linear.lora_B[name] + linear_a = adapter_info.lora_A[name] + linear_b = adapter_info.lora_B[name] weight_a = linear_a.weight weight_b = linear_b.weight assert weight_a is not None diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index a46dd6fc1..d71435eb3 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -15,6 +15,10 @@ class LayerType(Enum): ApplyRotaryEmb = auto() SiluAndMul = auto() RMSNorm = auto() + SLoRA = auto() + LinearW8A8 = auto() + RMSNormW8A8 = auto() + MultinomialSampling = auto() class LayersBackend(ABC): diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index 6c69726e8..7c8e26661 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -24,6 +24,18 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.RMSNorm: from .norm import TritonRMSNormBuilder return TritonRMSNormBuilder + elif layer_type == LayerType.SLoRA: + from .slora import TritonSLoRABuilder + return TritonSLoRABuilder + elif layer_type == LayerType.LinearW8A8: + from .qmodules import TritonLinearW8A8Builder + return TritonLinearW8A8Builder + elif layer_type == LayerType.RMSNormW8A8: + from .qmodules import TritonRMSNormBuilder + return TritonRMSNormBuilder + elif layer_type == LayerType.MultinomialSampling: + from .multinomial_sampling import TritonMultinomialSamplingBuilder + return TritonMultinomialSamplingBuilder else: return super().get_layer_impl_builder(layer_type) diff --git a/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py b/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py new file mode 100644 index 000000000..b91637ce7 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.kernels.cuda import multinomial_sampling + +from ..multinomial_sampling import MultinomialSamplingBuilder + + +class TritonMultinomialSamplingBuilder(MultinomialSamplingBuilder): + + def build(): + return multinomial_sampling diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py new file mode 100644 index 000000000..c41fdf6e2 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.distributed as dist +from torch import nn + +from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import ( + matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant) +from lmdeploy.pytorch.model_inputs import StepContextManager +from lmdeploy.pytorch.models.q_modules import QTensor + +from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, + RMSNormW8A8Impl) + + +class TritonRMSNormW8A8Impl(RMSNormW8A8Impl, nn.Module): + + def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(weight.clone()) + self.eps = eps + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + if residual is not None: + x = x + residual + residual = x + hidden_states_quant, rms_scale = rms_norm_dynamic_quant( + x, self.weight, self.eps) + x = QTensor(hidden_states_quant, rms_scale) + if residual is None: + return x + return x, residual + + +class TritonRMSNormBuilder(RMSNormW8A8Builder): + + @staticmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + return TritonRMSNormW8A8Impl(weight, eps) + + +class TritonLinearW8A8Impl(LinearW8A8Impl, nn.Module): + + def __init__(self, mod: nn.Module): + super().__init__() + self.weight = mod.weight + self.scale = mod.scale + self.bias = mod.bias + + def forward(self, x, all_reduce: bool = False): + if isinstance(x, torch.Tensor): + x = x.contiguous() + input_quant, input_scale = per_token_quant_int8(x, 1e-7) + else: + assert isinstance(x, QTensor) + input_quant, input_scale = x.tensor, x.scale + + out = matmul_kernel_dynamic_quant(input_quant, + self.weight, + input_scale, + self.scale, + output_dtype=torch.float16, + bias=self.bias) + + if all_reduce: + dist.all_reduce(out) + return out + + +class TritonLinearW8A8Builder(LinearW8A8Builder): + + @staticmethod + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + return TritonLinearW8A8Impl(mod) diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py new file mode 100644 index 000000000..8c397fe49 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -0,0 +1,234 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass + +import torch +import torch.distributed as dist + +from lmdeploy.pytorch.kernels.cuda.mbgmm import mbgmm_a, mbgmm_b +from lmdeploy.pytorch.kernels.cuda.mbgmv import mbgmv_a, mbgmv_b +from lmdeploy.pytorch.kernels.rearange_all_gather import rearange_all_gather +from lmdeploy.pytorch.model_inputs import StepContextManager + +from ..slora import AdapterInfo, SLoRABuilder, SLoRAImpl + + +@dataclass +class PackedLoRAInput: + x: torch.Tensor + a_cache: torch.Tensor + b_cache: torch.Tensor + q_start_loc: torch.Tensor + q_seqlens: torch.Tensor + adapter_ids: torch.Tensor + scaling: torch.Tensor + rank_offset: torch.Tensor + ranks: torch.Tensor + max_seq_len: int + max_rank: int + is_decoding: bool + + +class TritonSLoRAImpl(SLoRAImpl): + + def __init__(self, + adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool = True): + self.base_slice = adapter_info.base_slice + self.ctx_mgr = ctx_mgr + self.colwise = colwise + + def _make_packed_lora_input(self, x, target_name: str, layer_idx: int): + context = self.ctx_mgr.current_context() + adapter_param = context.adapter_params[target_name] + + # adapter cache + ranks = adapter_param.ranks + scaling = adapter_param.scalings + rank_offset = adapter_param.rank_offsets + max_rank = adapter_param.max_rank + k_cache, v_cache = context.kv_caches[layer_idx] + cache_len = k_cache.size(0) + a_cache = k_cache.view(cache_len, -1) + b_cache = v_cache.view(cache_len, -1) + + return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), + a_cache=a_cache, + b_cache=b_cache, + q_start_loc=context.q_start_loc, + q_seqlens=context.q_seq_length, + adapter_ids=context.local_adapter_ids, + scaling=scaling, + rank_offset=rank_offset, + ranks=ranks, + max_seq_len=context.max_q_seq_length, + max_rank=max_rank, + is_decoding=context.is_decoding) + + def _forward_rowwise(self, + lora_input: PackedLoRAInput, + base_output: torch.Tensor, + is_tp: bool = True): + """forward_rowwise.""" + sliced_base = base_output[..., self.base_slice] + out_size = sliced_base.size(-1) + if is_tp: + rank = dist.get_rank() + world_size = dist.get_world_size() + out_size //= world_size + if not lora_input.is_decoding: + xa = mbgmm_a(lora_input.x, + lora_input.a_cache, + q_start_loc=lora_input.q_start_loc, + q_seqlens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank) + lora_out = mbgmm_b(xa, + lora_input.b_cache, + q_start_loc=lora_input.q_start_loc, + q_seqlens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + scaling=lora_input.scaling, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + out_size=out_size) + else: + xa = mbgmv_a(lora_input.x, + lora_input.a_cache, + adapter_ids=lora_input.adapter_ids, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank) + lora_out = mbgmv_b(xa, + lora_input.b_cache, + adapter_ids=lora_input.adapter_ids, + scaling=lora_input.scaling, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + out_size=out_size) + + if is_tp: + out_shape = base_output.shape + out = base_output.flatten(0, -2) + slice_off = self.base_slice.start + slice_off = 0 if slice_off is None else slice_off + slice_start = slice_off + rank * out_size + slice_end = slice_start + out_size + out[:, slice_start:slice_end] += lora_out + out = out.reshape(out_shape) + else: + lora_out = lora_out.reshape(sliced_base.shape) + sliced_base.add_(lora_out) + out = base_output + + return out + + def _forward_colwise(self, lora_input: PackedLoRAInput, + base_output: torch.Tensor): + """forward_colwise.""" + + def __gather_xa(xa): + """gather xa.""" + gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1)) + dist.all_gather_into_tensor(gathered_xa, xa) + # TODO: gather would failed when adapters have different ranks. + gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1) + return gathered_xa + + sliced_base = base_output[..., self.base_slice] + out_size = sliced_base.size(-1) + world_size = dist.get_world_size() + + if not lora_input.is_decoding: + xa = mbgmm_a(lora_input.x, + lora_input.a_cache, + q_start_loc=lora_input.q_start_loc, + q_seqlens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + rank_step=world_size) + gathered_xa = __gather_xa(xa) + if len(lora_input.ranks) > 1: + gathered_xa = rearange_all_gather( + gathered_xa, + b_start_loc=lora_input.q_start_loc, + b_seq_lens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + ranks=lora_input.ranks, + world_size=world_size, + max_seq_len=lora_input.max_seq_len, + output=gathered_xa) + lora_out = mbgmm_b(gathered_xa, + lora_input.b_cache, + q_start_loc=lora_input.q_start_loc, + q_seqlens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + scaling=lora_input.scaling, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_seq_len=lora_input.max_seq_len, + max_rank=lora_input.max_rank, + out_size=out_size) + else: + xa = mbgmv_a(lora_input.x, + lora_input.a_cache, + adapter_ids=lora_input.adapter_ids, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + rank_step=world_size) + gathered_xa = __gather_xa(xa) + if len(lora_input.ranks) > 1: + gathered_xa = rearange_all_gather( + gathered_xa, + b_start_loc=lora_input.q_start_loc, + b_seq_lens=lora_input.q_seqlens, + adapter_ids=lora_input.adapter_ids, + ranks=lora_input.ranks, + world_size=world_size, + max_seq_len=lora_input.max_seq_len, + output=gathered_xa) + lora_out = mbgmv_b(gathered_xa, + lora_input.b_cache, + adapter_ids=lora_input.adapter_ids, + scaling=lora_input.scaling, + rank_offset=lora_input.rank_offset, + ranks=lora_input.ranks, + max_rank=lora_input.max_rank, + out_size=out_size) + + lora_out = lora_out.reshape(sliced_base.shape) + sliced_base.add_(lora_out) + output = base_output + + return output + + def forward(self, + x: torch.Tensor, + base_output: torch.Tensor, + target_name: str, + layer_idx: int, + is_tp: bool = True): + lora_input = self._make_packed_lora_input(x, target_name, layer_idx) + if self.colwise and is_tp: + return self._forward_colwise(lora_input, base_output) + else: + return self._forward_rowwise(lora_input, base_output, is_tp) + + +class TritonSLoRABuilder(SLoRABuilder): + + @staticmethod + def build(adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool = True): + return TritonSLoRAImpl(adapter_info, ctx_mgr, colwise) diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index 01f3235d8..b80730f43 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -29,6 +29,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.RMSNorm: from .norm import DefaultRMSNormBuilder return DefaultRMSNormBuilder + elif layer_type == LayerType.MultinomialSampling: + from .multinomial_sampling import DefaultMultinomialSamplingBuilder + return DefaultMultinomialSamplingBuilder else: raise RuntimeError(f'{layer_type} not supported.') diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index 7f0917814..b3a8b84e1 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -9,14 +9,13 @@ class DefaultLinearImpl(LinearImpl): - def __init__(self, mod: nn.Module, all_reduce: bool = False): + def __init__(self, mod: nn.Module): super().__init__() self.mod = mod - self.all_reduce = all_reduce - def forward(self, x): + def forward(self, x, all_reduce: bool = False): out = self.mod(x) - if self.all_reduce: + if all_reduce: dist.all_reduce(out) return out @@ -24,7 +23,5 @@ def forward(self, x): class DefaultLinearBuilder(LinearBuilder): @staticmethod - def build(mod: nn.Module, - ctx_mgr: StepContextManager = None, - all_reduce: bool = False): - return DefaultLinearImpl(mod, all_reduce) + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + return DefaultLinearImpl(mod) diff --git a/lmdeploy/pytorch/backends/default/multinomial_sampling.py b/lmdeploy/pytorch/backends/default/multinomial_sampling.py new file mode 100644 index 000000000..5cffe6c63 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/multinomial_sampling.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..multinomial_sampling import (MultinomialSamplingBuilder, + MultinomialSamplingImpl) + + +class DefaultMultinomialSamplingImpl(MultinomialSamplingImpl): + + def forward(scores: torch.Tensor, + seeds: torch.LongTensor, + offsets: torch.LongTensor, + indices: torch.Tensor = None): + sampled_index = torch.multinomial(scores, + num_samples=1, + replacement=True) + outputs = torch.gather(indices, dim=1, index=sampled_index) + return outputs.view(-1) + + +class DefaultMultinomialSamplingBuilder(MultinomialSamplingBuilder): + + def build(): + return DefaultMultinomialSamplingImpl() diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index af9c88689..5aa1710a7 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -9,7 +9,7 @@ class LinearImpl(ABC, nn.Module): @abstractmethod - def forward(self, x): + def forward(self, x, all_reduce: bool = False): raise NotImplementedError @@ -17,7 +17,5 @@ class LinearBuilder(ABC): @staticmethod @abstractmethod - def build(mod: nn.Module, - ctx_mgr: StepContextManager = None, - all_reduce: bool = False): + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/multinomial_sampling.py b/lmdeploy/pytorch/backends/multinomial_sampling.py new file mode 100644 index 000000000..35dcf091a --- /dev/null +++ b/lmdeploy/pytorch/backends/multinomial_sampling.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +import torch + + +class MultinomialSamplingImpl(ABC): + + @abstractmethod + def forward(scores: torch.Tensor, + seeds: torch.LongTensor, + offsets: torch.LongTensor, + indices: torch.Tensor = None): + raise NotImplementedError + + +class MultinomialSamplingBuilder(ABC): + + @staticmethod + @abstractmethod + def build(): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py new file mode 100644 index 000000000..19bfc4e29 --- /dev/null +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +import torch +from torch import nn + +from lmdeploy.pytorch.model_inputs import StepContextManager + + +class RMSNormW8A8Impl(ABC, nn.Module): + + @abstractmethod + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + raise NotImplementedError + + +class RMSNormW8A8Builder(ABC): + + @staticmethod + @abstractmethod + def build(weight: torch.Tensor, eps: float = 1e-6): + raise NotImplementedError + + +class LinearW8A8Impl(ABC, nn.Module): + + @abstractmethod + def forward(self, x, all_reduce: bool = False): + raise NotImplementedError + + +class LinearW8A8Builder(ABC): + + @staticmethod + @abstractmethod + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/slora.py b/lmdeploy/pytorch/backends/slora.py new file mode 100644 index 000000000..6efd3f960 --- /dev/null +++ b/lmdeploy/pytorch/backends/slora.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch +from torch import nn + +from lmdeploy.pytorch.model_inputs import StepContextManager + + +@dataclass +class AdapterInfo: + r: dict + lora_A: nn.ModuleDict + lora_B: nn.ModuleDict + scaling: dict + base_slice: slice + in_features: int + out_features: int + + @staticmethod + def from_lora_linear(mod: nn.Module, base_slice: slice = None): + if base_slice is None: + base_slice = slice(None) + return AdapterInfo( + r=mod.r, + lora_A=mod.lora_A, + lora_B=mod.lora_B, + scaling=mod.scaling, + base_slice=base_slice, + in_features=mod.in_features, + out_features=mod.out_features, + ) + + +class SLoRAImpl(ABC): + + @abstractmethod + def forward(self, + x: torch.Tensor, + base_output: torch.Tensor, + target_name: str, + layer_idx: int, + is_tp: bool = True): + raise NotImplementedError + + +class SLoRABuilder(ABC): + + @staticmethod + @abstractmethod + def build(adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool = True): + raise NotImplementedError diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 4f8e720f3..fb948bd8a 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -89,7 +89,8 @@ def _multinomial_sampling(scores: torch.Tensor, offsets: torch.LongTensor, indices: torch.LongTensor = None): """sampling.""" - from lmdeploy.pytorch.kernels import multinomial_sampling + from lmdeploy.pytorch.layers.multinomial_sampling import \ + multinomial_sampling return multinomial_sampling(scores, seeds, offsets, indices) @@ -258,7 +259,7 @@ def __call__(self, input_ids: torch.LongTensor, stop_words = sampling_inputs.stop_words if stop_words is not None: - stop_words = stop_words * self.ignore_eos[:, None] + stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1) scores = _process_bad_words(scores, stop_words) return scores diff --git a/lmdeploy/pytorch/layers/__init__.py b/lmdeploy/pytorch/layers/__init__.py index 229dcd613..4e04b0627 100644 --- a/lmdeploy/pytorch/layers/__init__.py +++ b/lmdeploy/pytorch/layers/__init__.py @@ -3,7 +3,6 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import SiluAndMul # noqa: F401 from .attention import Attention # noqa: F401 -from .linear import build_linear, build_merged_linear # noqa: F401 from .norm import RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import EmbeddingType # noqa: F401 diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index e3fc20b8d..ea13b5acc 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from dataclasses import dataclass from typing import Any, List import torch @@ -10,6 +9,7 @@ from lmdeploy.utils import get_logger from ..backends import LayerType, get_backend +from ..backends.slora import AdapterInfo logger = get_logger('lmdeploy') @@ -52,32 +52,24 @@ def _get_world_rank(): return world_size, rank -@dataclass -class AdapterInfo: - r: dict - lora_A: nn.ModuleDict - lora_B: nn.ModuleDict - scaling: dict - base_slice: slice - - @staticmethod - def from_lora_linear(mod: nn.Module, base_slice: slice = None): - return AdapterInfo( - r=mod.r, - lora_A=mod.lora_A, - lora_B=mod.lora_B, - scaling=mod.scaling, - base_slice=base_slice, - ) - - class SLoRA(nn.Module): - def __init__(self, adapter_infos: List[AdapterInfo]): + def __init__(self, + adapter_info: AdapterInfo, + ctx_mgr: Any = None, + colwise: bool = True, + is_tp: bool = True): super().__init__() + self.adapter_info = adapter_info + impl_builder = get_backend().get_layer_impl_builder(LayerType.SLoRA) + self.impl = impl_builder.build(adapter_info, ctx_mgr, colwise=colwise) + self.target_name = None + self.layer_idx = None + self.is_tp = is_tp - def forward(self, x): - raise NotImplementedError + def forward(self, x, base_output=None): + return self.impl.forward(x, base_output, self.target_name, + self.layer_idx, self.is_tp) class AwqLinear(nn.Module): @@ -92,8 +84,19 @@ class W8A8Linear(nn.Module): def __init__(self, mod: nn.Module, - adapter_infos: List[AdapterInfo] = None): + ctx_mgr: Any = None, + colwise: bool = True, + is_tp: bool = False): super().__init__() + impl_builder = get_backend().get_layer_impl_builder( + LayerType.LinearW8A8) + self.impl = impl_builder.build(mod, ctx_mgr) + self.is_tp = is_tp + self.colwise = colwise + + def forward(self, x): + is_tp = False if self.colwise else self.is_tp + return self.impl.forward(x, is_tp) class BaseLinear(nn.Module): @@ -102,27 +105,32 @@ def __init__(self, mod: nn.Module, adapter_infos: List[AdapterInfo] = None, ctx_mgr: Any = None, - all_reduce: bool = False): + colwise: bool = True, + is_tp: bool = False): super().__init__() - layer_backend = get_backend() - - if isinstance(mod, nn.Linear): - impl_builder = layer_backend.get_layer_impl_builder( - LayerType.Linear) - else: - raise NotImplementedError(f'Unsupported linear type: {type(mod)}') - self.impl = impl_builder.build(mod, ctx_mgr, all_reduce) + impl_builder = get_backend().get_layer_impl_builder(LayerType.Linear) + self.impl = impl_builder.build(mod, ctx_mgr) adapter_infos = adapter_infos if adapter_infos is not None else [] - self.adapter = None + self.lora_adapters = None if len(adapter_infos) > 0: - self.adapter = SLoRA(adapter_infos) + self.lora_adapters = nn.ModuleList( + SLoRA(info, ctx_mgr, colwise, is_tp) for info in adapter_infos) - def forward(self, x): - out = self.impl.forward(x) + self.is_tp = is_tp + self.colwise = colwise - if self.adapter is not None: - out = self.adapter(out) + def forward(self, x): + if self.lora_adapters is None: + is_tp = False if self.colwise else self.is_tp + return self.impl.forward(x, is_tp) + + out = self.impl.forward(x, False) + if self.lora_adapters is not None: + for lora_adapter in self.lora_adapters: + out = lora_adapter(x, out) + if self.is_tp: + dist.all_reduce(out) return out @@ -161,7 +169,7 @@ def _merge_base_linear(*linears: List[nn.Module]): def _merge_qlinear(*linears: List[nn.Module]): """merge qlinear.""" weights = [mod.weight for mod in linears] - scalings = [mod.scaling for mod in linears] + scales = [mod.scale for mod in linears] bias = [mod.bias for mod in linears] in_features = weights[0].size(1) @@ -174,7 +182,7 @@ def _merge_qlinear(*linears: List[nn.Module]): out_features = sum(w.size(0) for w in weights) new_weight = torch.cat(weights, dim=0) - new_scaling = torch.cat(scalings, dim=0) + new_scale = torch.cat(scales, dim=0) new_bias = None if bias[0] is not None: assert all(b is not None for b in bias) @@ -187,7 +195,7 @@ def _merge_qlinear(*linears: List[nn.Module]): device=device) state_dict = dict( weight=new_weight, - scaling=new_scaling, + scale=new_scale, ) if has_bias: state_dict['bias'] = new_bias @@ -242,10 +250,67 @@ def _merge_awqlinear(*linears: List[nn.Module]): return merged_linear -def build_merged_linear(*linears: List[nn.Module], - ctx_mgr: Any = None, - all_reduce: bool = False, - free_origin=False): +def build_linear(mod: nn.Module, + adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, + colwise: bool = True, + is_tp: bool = False) -> nn.Module: + """build linear.""" + if is_tp: + world_size, rank = _get_world_rank() + is_tp = world_size > 1 + + if isinstance(mod, nn.Linear): + return BaseLinear(mod, + adapter_infos, + ctx_mgr, + colwise=colwise, + is_tp=is_tp) + elif isinstance(mod, WQLinear_GEMM): + return AwqLinear(mod, adapter_infos) + elif isinstance(mod, QLinear): + return W8A8Linear(mod, ctx_mgr, colwise, is_tp) + elif isinstance(mod, LoRALinear): + base_layer = mod.base_layer + adapter_info = AdapterInfo.from_lora_linear(mod) + return build_linear(base_layer, [adapter_info], + ctx_mgr=ctx_mgr, + colwise=colwise, + is_tp=is_tp) + elif isinstance(mod, AwqLoraLinear): + base_layer = mod.base_layer + adapter_info = AdapterInfo.from_lora_linear(mod) + return build_linear(base_layer, [adapter_info], + ctx_mgr=ctx_mgr, + colwise=colwise, + is_tp=is_tp) + else: + raise NotImplementedError(f'Unknown linear type: {type(mod)}') + + +def build_colwise_linear(mod: nn.Module, + adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, + is_tp: bool = False) -> nn.Module: + return build_linear(mod, adapter_infos, ctx_mgr, colwise=True, is_tp=is_tp) + + +def build_rowwise_linear(mod: nn.Module, + adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, + is_tp: bool = False) -> nn.Module: + return build_linear(mod, + adapter_infos, + ctx_mgr, + colwise=False, + is_tp=is_tp) + + +def build_merged_colwise_linear( + *linears: List[nn.Module], + ctx_mgr: Any = None, + is_tp: bool = False, +): """merge linear.""" base_layers = [] out_features = [] @@ -288,38 +353,8 @@ def build_merged_linear(*linears: List[nn.Module], base_layer = _merge_qlinear(*base_layers) else: raise NotImplementedError(f'Unknown linear type: {type(mod)}') - ret = build_linear(base_layer, - adapter_infos, - ctx_mgr=ctx_mgr, - all_reduce=all_reduce) - if free_origin: - for mod in linears: - mod.to('meta') + ret = build_colwise_linear(base_layer, + adapter_infos, + ctx_mgr=ctx_mgr, + is_tp=is_tp) return ret - - -def build_linear(mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, - all_reduce: bool = False) -> nn.Module: - """build linear.""" - if all_reduce: - world_size, rank = _get_world_rank() - all_reduce = world_size > 1 - - if isinstance(mod, nn.Linear): - return BaseLinear(mod, adapter_infos, ctx_mgr, all_reduce) - elif isinstance(mod, WQLinear_GEMM): - return AwqLinear(mod, adapter_infos) - elif isinstance(mod, QLinear): - return W8A8Linear(mod, adapter_infos) - elif isinstance(mod, LoRALinear): - base_layer = mod.base_layer - adapter_info = AdapterInfo.from_lora_linear(mod) - return build_linear(base_layer, [adapter_info]) - elif isinstance(mod, AwqLoraLinear): - base_layer = mod.base_layer - adapter_info = AdapterInfo.from_lora_linear(mod) - return build_linear(base_layer, [adapter_info]) - else: - raise NotImplementedError(f'Unknown linear type: {type(mod)}') diff --git a/lmdeploy/pytorch/layers/multinomial_sampling.py b/lmdeploy/pytorch/layers/multinomial_sampling.py new file mode 100644 index 000000000..dd00e870c --- /dev/null +++ b/lmdeploy/pytorch/layers/multinomial_sampling.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..backends import LayerType, get_backend + + +def multinomial_sampling(scores: torch.Tensor, + seeds: torch.LongTensor, + offsets: torch.LongTensor, + indices: torch.Tensor = None): + impl_builder = get_backend().get_layer_impl_builder( + LayerType.MultinomialSampling) + return impl_builder.build()(scores, seeds, offsets, indices) diff --git a/lmdeploy/pytorch/layers/norm.py b/lmdeploy/pytorch/layers/norm.py index c4ce81740..1767794f1 100644 --- a/lmdeploy/pytorch/layers/norm.py +++ b/lmdeploy/pytorch/layers/norm.py @@ -7,10 +7,16 @@ class RMSNorm(nn.Module): - def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + def __init__(self, + weight: torch.Tensor, + eps: float = 1e-6, + is_w8a8: bool = False): super().__init__() backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.RMSNorm) + if is_w8a8: + builder = backend.get_layer_impl_builder(LayerType.RMSNormW8A8) + else: + builder = backend.get_layer_impl_builder(LayerType.RMSNorm) self.impl = builder.build(weight, eps) def forward(self, x: torch.Tensor, residual: torch.Tensor = None): diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 8ed675816..a02acc77d 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -7,9 +7,10 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_linear, - build_merged_linear, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.layers.linear import (build_merged_colwise_linear, + build_rowwise_linear) from lmdeploy.pytorch.model_inputs import StepContextManager from ..weight_loader.dist_utils import (colwise_parallelize_linear, @@ -24,17 +25,21 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): world_size = 1 if dist.is_initialized(): world_size = dist.get_world_size() + is_tp = world_size > 1 self.ctx_mgr = ctx_mgr self.num_heads = origin.num_heads // world_size self.num_kv_heads = origin.num_key_value_heads // world_size self.head_dim = origin.head_dim # qkv - self.qkv_proj = build_merged_linear(origin.q_proj, - origin.k_proj, - origin.v_proj, - ctx_mgr=ctx_mgr, - free_origin=True) + self.qkv_proj = build_merged_colwise_linear( + origin.q_proj, + origin.k_proj, + origin.v_proj, + ctx_mgr=ctx_mgr, + is_tp=is_tp, + ) + del origin.q_proj, origin.k_proj, origin.v_proj self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -46,9 +51,11 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): v_head_size=self.head_dim, ) - self.o_proj = build_linear(origin.o_proj, - ctx_mgr=ctx_mgr, - all_reduce=world_size > 1) + self.o_proj = build_rowwise_linear( + origin.o_proj, + ctx_mgr=ctx_mgr, + is_tp=is_tp, + ) @staticmethod def _load_weights(mod, loader, rank: int, world_size: int, @@ -75,17 +82,18 @@ def forward( Optional[Tuple[torch.Tensor]]]: """Rewrite of LlamaAttention.forward.""" context = self.ctx_mgr.current_context() - qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) qkv_states = qkv_states.flatten(0, -2) qkv_states = qkv_states.unflatten(-1, (-1, self.head_dim)) - query_states, key_states, value_states = qkv_states.split(( - self.num_heads, - self.num_kv_heads, - self.num_kv_heads, - ), - dim=1) + query_states, key_states, value_states = qkv_states.split( + ( + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + ), + dim=1, + ) cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -115,21 +123,26 @@ class LlamaMLP(nn.Module): def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): super().__init__() + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + is_tp = world_size > 1 # gate up - self.gate_up_proj = build_merged_linear( + self.gate_up_proj = build_merged_colwise_linear( origin.gate_proj, origin.up_proj, ctx_mgr=ctx_mgr, - free_origin=True, + is_tp=is_tp, ) + del origin.gate_proj, origin.up_proj # silu and mul self.act_fn = SiluAndMul(inplace=True) # down - self.down_proj = build_linear(origin.down_proj, - ctx_mgr=ctx_mgr, - all_reduce=True) + self.down_proj = build_rowwise_linear(origin.down_proj, + ctx_mgr=ctx_mgr, + is_tp=is_tp) @staticmethod def _load_weights(mod: nn.Module, loader, rank: int, world_size: int, @@ -164,12 +177,19 @@ def __init__(self, origin: nn.Module, layer_idx: int, # norm input_layernorm = origin.input_layernorm - self.input_layernorm = RMSNorm(input_layernorm.weight, - input_layernorm.variance_epsilon) + is_w8a8 = hasattr(input_layernorm, 'from_float') + self.input_layernorm = RMSNorm( + input_layernorm.weight, + input_layernorm.variance_epsilon, + is_w8a8=is_w8a8, + ) post_attention_layernorm = origin.post_attention_layernorm + is_w8a8 = hasattr(post_attention_layernorm, 'from_float') self.post_attention_layernorm = RMSNorm( post_attention_layernorm.weight, - post_attention_layernorm.variance_epsilon) + post_attention_layernorm.variance_epsilon, + is_w8a8=is_w8a8, + ) def forward( self, @@ -213,7 +233,10 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): for idx, layer in enumerate(origin.layers) ]) norm = origin.norm - self.norm = RMSNorm(norm.weight, norm.variance_epsilon) + is_w8a8 = hasattr(norm, 'from_float') + self.norm = RMSNorm(norm.weight, + norm.variance_epsilon, + is_w8a8=is_w8a8) rotary_emb = origin.layers[0].self_attn.rotary_emb rotary_name = type(rotary_emb).__name__ @@ -223,11 +246,12 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): emb_type = EmbeddingType.LinearScaling elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': emb_type = EmbeddingType.DynamicNTKScaling + scaling_factor = getattr(rotary_emb, 'scaling_factor', 1.0) self.rotary_emb = build_rotary_embedding( rotary_emb.dim, rotary_emb.max_position_embeddings, rotary_emb.base, - rotary_emb.scaling_factor, + scaling_factor, emb_type, ) @@ -278,7 +302,7 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): super().__init__() self.ctx_mgr = ctx_mgr self.model = LlamaModel(origin.model, ctx_mgr) - self.lm_head = build_linear(origin.lm_head) + self.lm_head = build_rowwise_linear(origin.lm_head) def forward( self, @@ -297,5 +321,4 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return logits diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index da09d589a..bc9d32541 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -10,17 +10,11 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', 'transformers.models.llama.modeling_llama.LlamaAttention': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', 'transformers.models.llama.modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.llama.modeling_llama.LlamaDecoderLayer': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaDecoderLayer', # support modeling rewritten in lmdeploy 'modeling_llama.LlamaAttention': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'modeling_llama.LlamaModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', 'modeling_llama.LlamaMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', 'modeling_llama.LlamaForCausalLM': diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py index 001fab7a6..ea4bd1eba 100644 --- a/lmdeploy/pytorch/models/q_modules.py +++ b/lmdeploy/pytorch/models/q_modules.py @@ -138,6 +138,9 @@ def forward(self, input): if isinstance(input, torch.Tensor): input_quant, input_scale = per_token_quant_int8(input, 1e-7) + print(input_quant) + print(input_scale) + exit() else: assert isinstance(input, QTensor) input_quant, input_scale = input.tensor, input.scale From 4755b1e1775ebeff5526b2a4343839fbe27f6089 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 23 Jul 2024 14:27:53 +0800 Subject: [PATCH 06/89] support awq --- lmdeploy/pytorch/backends/awq_modules.py | 21 +++++ lmdeploy/pytorch/backends/base.py | 1 + lmdeploy/pytorch/backends/cuda/awq_modules.py | 88 +++++++++++++++++++ .../pytorch/backends/cuda/layer_backend.py | 23 ++++- .../pytorch/backends/default/awq_modules.py | 48 ++++++++++ .../pytorch/backends/default/layer_backend.py | 3 + lmdeploy/pytorch/layers/linear.py | 48 ++++++++-- lmdeploy/pytorch/model_inputs.py | 53 ++++++----- lmdeploy/pytorch/tools/make_inputs.py | 1 - tests/pytorch/kernel/test_apply_rotary.py | 19 ++-- 10 files changed, 256 insertions(+), 49 deletions(-) create mode 100644 lmdeploy/pytorch/backends/awq_modules.py create mode 100644 lmdeploy/pytorch/backends/cuda/awq_modules.py create mode 100644 lmdeploy/pytorch/backends/default/awq_modules.py diff --git a/lmdeploy/pytorch/backends/awq_modules.py b/lmdeploy/pytorch/backends/awq_modules.py new file mode 100644 index 000000000..b82af3bd7 --- /dev/null +++ b/lmdeploy/pytorch/backends/awq_modules.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +from torch import nn + +from lmdeploy.pytorch.model_inputs import StepContextManager + + +class LinearW4A16Impl(ABC, nn.Module): + + @abstractmethod + def forward(self, x, all_reduce: bool = False): + raise NotImplementedError + + +class LinearW4A16Builder(ABC): + + @staticmethod + @abstractmethod + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index d71435eb3..0bb8b94b9 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -19,6 +19,7 @@ class LayerType(Enum): LinearW8A8 = auto() RMSNormW8A8 = auto() MultinomialSampling = auto() + LinearW4A16 = auto() class LayersBackend(ABC): diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py new file mode 100644 index 000000000..25c4023a7 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import distributed as dist +from torch import nn + +from lmdeploy.pytorch.model_inputs import StepContextManager + +from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl + + +def wq_gemm_forward( + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, +): + """wq gemm forward.""" + from awq.modules.linear.gemm import awq_ext + out_shape = x.shape[:-1] + (out_features, ) + input_dtype = x.dtype + if input_dtype != torch.float16: + x = x.half() + + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + # TODO: remove event wait if awq kernel set stream + default_stream = torch.cuda.default_stream() + event_def = torch.cuda.Event() + event_def.record() + event_def.wait(default_stream) + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, + 0, 0, False) + event_def = torch.cuda.Event() + event_def.record(default_stream) + event_def.wait() + out = torch.matmul(x, out) + else: + x = x.flatten(0, -2) + if not x.is_contiguous(): + x = x.contiguous() + out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8) + + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + + # always want 3D tensor if tensor is 2D + if len(out.shape) == 2: + out = out.unsqueeze(0) + + if input_dtype != torch.float16: + out = out.to(dtype=input_dtype) + return out + +class AwqLinearW4A16Impl(LinearW4A16Impl): + + def __init__(self, mod: nn.Module): + super().__init__() + from awq.modules.linear.gemm import AWQ_INSTALLED + assert AWQ_INSTALLED + self.qweight = mod.qweight + self.qzeros = mod.qzeros + self.scales = mod.scales + self.w_bit = mod.w_bit + self.group_size = mod.group_size + self.bias = mod.bias + self.in_features = mod.in_features + self.out_features = mod.out_features + + def forward(self, x, all_reduce: bool = False): + out_features = self.scales.size(1) + out = wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, + self.w_bit, self.group_size, self.bias, + out_features) + if all_reduce: + dist.all_reduce(out) + return out + + +class AwqLinearW4A16Builder(LinearW4A16Builder): + + @staticmethod + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + return AwqLinearW4A16Impl(mod) diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index 7c8e26661..f7b561629 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -3,9 +3,13 @@ import torch +from lmdeploy.utils import get_logger + from ..base import LayerType from ..default import DefaultLayersBackend +logger = get_logger('lmdeploy') + class CudaLayersBackend(DefaultLayersBackend): @@ -36,7 +40,18 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.MultinomialSampling: from .multinomial_sampling import TritonMultinomialSamplingBuilder return TritonMultinomialSamplingBuilder + elif layer_type == LayerType.LinearW4A16: + from awq.modules.linear.gemm import AWQ_INSTALLED + if AWQ_INSTALLED: + from .awq_modules import AwqLinearW4A16Builder + return AwqLinearW4A16Builder + else: + logger.debug( + f'Op {layer_type} fallback to default implementation.') + return super().get_layer_impl_builder(layer_type) else: + logger.debug( + f'Op {layer_type} fallback to default implementation.') return super().get_layer_impl_builder(layer_type) @staticmethod @@ -78,10 +93,10 @@ def update_step_context(cls, step_context): step_context.is_decoding, step_context.block_offsets, q_start_loc=step_context.q_start_loc, - q_seqlens=step_context.q_seq_length, - kv_seqlens=step_context.kv_seq_length, - max_q_seqlen=step_context.max_q_seq_length, - max_kv_seqlen=step_context.max_kv_seq_length, + q_seqlens=step_context.q_seqlens, + kv_seqlens=step_context.kv_seqlens, + max_q_seqlen=step_context.max_q_seqlen, + max_kv_seqlen=step_context.max_kv_seqlen, ) step_context.attn_meta = attn_meta diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py new file mode 100644 index 000000000..320d4eaa4 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import distributed as dist +from torch import nn + +from lmdeploy.pytorch.model_inputs import StepContextManager + +from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl + + +class DefaultLinearW4A16Impl(LinearW4A16Impl): + + def __init__(self, mod: nn.Module): + super().__init__() + self.qweight = mod.qweight + self.qzeros = mod.qzeros + self.scales = mod.scales + self.w_bit = mod.w_bit + self.group_size = mod.group_size + self.bias = mod.bias + self.in_features = mod.in_features + self.out_features = mod.out_features + + def forward(self, x, all_reduce: bool = False): + from awq.utils.packing_utils import dequantize_gemm + out_shape = x.shape[:-1] + (self.out_features, ) + input_dtype = x.dtype + if input_dtype != torch.float16: + x = x.half() + + out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size) + out = torch.matmul(x, out) + + out = out + self.bias if self.bias is not None else out + out = out.reshape(out_shape) + + if input_dtype != torch.float16: + out = out.to(dtype=input_dtype) + if all_reduce: + dist.all_reduce(out) + return out + + +class DefaultLinearW4A16Builder(LinearW4A16Builder): + + @staticmethod + def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + return DefaultLinearW4A16Impl(mod) diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index b80730f43..dc8619979 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -32,6 +32,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.MultinomialSampling: from .multinomial_sampling import DefaultMultinomialSamplingBuilder return DefaultMultinomialSamplingBuilder + elif layer_type == LayerType.LinearW4A16: + from .awq_modules import DefaultLinearW4A16Builder + return DefaultLinearW4A16Builder else: raise RuntimeError(f'{layer_type} not supported.') diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index ea13b5acc..7edc06f1b 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -76,8 +76,36 @@ class AwqLinear(nn.Module): def __init__(self, mod: nn.Module, - adapter_infos: List[AdapterInfo] = None): + adapter_infos: List[AdapterInfo] = None, + ctx_mgr: Any = None, + colwise: bool = True, + is_tp: bool = False): super().__init__() + impl_builder = get_backend().get_layer_impl_builder( + LayerType.LinearW4A16) + self.impl = impl_builder.build(mod, ctx_mgr) + + adapter_infos = adapter_infos if adapter_infos is not None else [] + self.lora_adapters = None + if len(adapter_infos) > 0: + self.lora_adapters = nn.ModuleList( + SLoRA(info, ctx_mgr, colwise, is_tp) for info in adapter_infos) + + self.is_tp = is_tp + self.colwise = colwise + + def forward(self, x): + if self.lora_adapters is None: + is_tp = False if self.colwise else self.is_tp + return self.impl.forward(x, is_tp) + + out = self.impl.forward(x, False) + if self.lora_adapters is not None: + for lora_adapter in self.lora_adapters: + out = lora_adapter(x, out) + if self.is_tp: + dist.all_reduce(out) + return out class W8A8Linear(nn.Module): @@ -216,16 +244,16 @@ def _merge_awqlinear(*linears: List[nn.Module]): group_size = group_sizes[0] assert all(wb == w_bit for wb in w_bits) assert all(gs == group_size for gs in group_sizes) - in_features = qweights[0].size(1) + in_features = qweights[0].size(0) device = qweights[0].device for w in qweights: - assert w.size(1) == in_features + assert w.size(0) == in_features assert w.device == device - out_features = sum(w.size(0) for w in qweights) + out_features = sum(s.size(1) for s in scales) - new_qweight = torch.cat(qweights, dim=0) - new_scales = torch.cat(scales, dim=0) - new_qzeros = torch.cat(qzeros, dim=0) + new_qweight = torch.cat(qweights, dim=1) + new_scales = torch.cat(scales, dim=1) + new_qzeros = torch.cat(qzeros, dim=1) new_bias = None if bias[0] is not None: assert all(b is not None for b in bias) @@ -267,7 +295,11 @@ def build_linear(mod: nn.Module, colwise=colwise, is_tp=is_tp) elif isinstance(mod, WQLinear_GEMM): - return AwqLinear(mod, adapter_infos) + return AwqLinear(mod, + adapter_infos, + ctx_mgr, + colwise=colwise, + is_tp=is_tp) elif isinstance(mod, QLinear): return W8A8Linear(mod, ctx_mgr, colwise, is_tp) elif isinstance(mod, LoRALinear): diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index eb16993f9..00141ef35 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -237,11 +237,10 @@ class StepContext: position_ids: torch.LongTensor q_start_loc: torch.LongTensor attention_mask: torch.LongTensor - history_lengths: torch.LongTensor - q_seq_length: torch.LongTensor - kv_seq_length: torch.LongTensor - max_q_seq_length: int - max_kv_seq_length: int + q_seqlens: torch.LongTensor + kv_seqlens: torch.LongTensor + max_q_seqlen: int + max_kv_seqlen: int kv_caches: List is_decoding: bool world_size: int = 1 @@ -267,42 +266,41 @@ def new( world_size (int): The distribution world size. device (str): The device of the tensors. """ - q_seq_length = inputs.seq_length - max_q_seq_length = inputs.max_q_seq_length - history_lengths = inputs.history_lengths + q_seqlens = inputs.seq_length + max_q_seqlen = inputs.max_q_seq_length + history_seqlens = inputs.history_lengths # for vlm input_embeddings, input_embedding_indexing = None, None if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None): input_embeddings, input_embedding_indexing = \ - inputs.vision_inputs.get_inputs(history_lengths, q_seq_length) + inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - batch_size = len(q_seq_length) - device = q_seq_length.device + batch_size = len(q_seqlens) + device = q_seqlens.device - # q_start_loc and kv_seq_length + # q_start_loc and kv_seqlens if inputs.is_decoding: q_start_loc = torch.arange(0, batch_size, device=device) - attention_mask = torch.ones_like(q_seq_length)[:, None] - position_ids = history_lengths.unsqueeze(-1) + attention_mask = torch.ones_like(q_seqlens)[:, None] + position_ids = history_seqlens.unsqueeze(-1) else: - q_start_loc = q_seq_length.cumsum(0) - q_seq_length - mask_range = torch.arange(max_q_seq_length, device=device)[None, :] - attention_mask = (mask_range < q_seq_length[:, None]).long() + q_start_loc = q_seqlens.cumsum(0) - q_seqlens + mask_range = torch.arange(max_q_seqlen, device=device)[None, :] + attention_mask = (mask_range < q_seqlens[:, None]).long() position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids += history_lengths.unsqueeze(-1) + position_ids += history_seqlens.unsqueeze(-1) # position ids 1d - position_ids = cls.get_position_ids_1d(position_ids, - q_seq_length)[None] + position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] # seq_len + history_length - kv_seq_length = q_seq_length + history_lengths - max_kv_seq_length = max_q_seq_length + inputs.max_history_length + kv_seqlens = q_seqlens + history_seqlens + max_kv_seqlen = max_q_seqlen + inputs.max_history_length window_size = getattr(cache_config, 'window_size', 0) if window_size > 0: - kv_seq_length -= inputs.num_ignored_history + kv_seqlens -= inputs.num_ignored_history adapter_params = None if inputs.adapter_info is not None: @@ -315,11 +313,10 @@ def new( input_embedding_indexing=input_embedding_indexing, attention_mask=attention_mask, q_start_loc=q_start_loc, - history_lengths=inputs.history_lengths, - q_seq_length=inputs.seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - max_kv_seq_length=max_kv_seq_length, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, kv_caches=kv_caches, is_decoding=inputs.is_decoding, world_size=world_size, diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py index 892a44b6d..82fa3b735 100644 --- a/lmdeploy/pytorch/tools/make_inputs.py +++ b/lmdeploy/pytorch/tools/make_inputs.py @@ -139,7 +139,6 @@ def __fill_kv_caches(kv_caches, past_key_values, block_offsets): return StepContext.new( inputs=model_inputs, world_size=world_size, - device=device, kv_caches=kv_caches, ) diff --git a/tests/pytorch/kernel/test_apply_rotary.py b/tests/pytorch/kernel/test_apply_rotary.py index e13c71d4e..0050a59ef 100644 --- a/tests/pytorch/kernel/test_apply_rotary.py +++ b/tests/pytorch/kernel/test_apply_rotary.py @@ -72,9 +72,15 @@ def cached_sin(self, max_seqlen, feature_dim, dtype): yield torch.rand(max_seqlen, feature_dim, dtype=dtype, device='cuda') @pytest.fixture - def gt(self, q_states, k_states, cached_cos, cached_sin, position_ids_1d): - cos = cached_cos[position_ids_1d, None, :] - sin = cached_sin[position_ids_1d, None, :] + def cos(self, cached_cos, position_ids_1d): + yield cached_cos[position_ids_1d, None, :] + + @pytest.fixture + def sin(self, cached_sin, position_ids_1d): + yield cached_sin[position_ids_1d, None, :] + + @pytest.fixture + def gt(self, q_states, k_states, cos, sin, position_ids_1d): q_embed = q_states * cos + _rotate_half(q_states) * sin k_embed = k_states * cos + _rotate_half(k_states) * sin @@ -86,11 +92,8 @@ def gt(self, q_states, k_states, cached_cos, cached_sin, position_ids_1d): indirect=True) @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)], indirect=True) - def test_apply_rotary(self, q_states, k_states, cached_cos, cached_sin, - position_ids_1d, gt): - q_embed, k_embed = apply_rotary_pos_emb(q_states, k_states, cached_cos, - cached_sin, None, - position_ids_1d) + def test_apply_rotary(self, q_states, k_states, cos, sin, gt): + q_embed, k_embed = apply_rotary_pos_emb(q_states, k_states, cos, sin) q_gt, k_gt = gt rtol = None From 60df32f38024806783d12887d7b3ced4b7391a91 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 23 Jul 2024 18:04:19 +0800 Subject: [PATCH 07/89] add add_rms_norm kernel --- lmdeploy/pytorch/backends/cuda/awq_modules.py | 9 +- lmdeploy/pytorch/backends/cuda/norm.py | 9 +- .../pytorch/backends/default/awq_modules.py | 3 +- lmdeploy/pytorch/backends/norm.py | 2 +- lmdeploy/pytorch/kernels/cuda/rms_norm.py | 99 ++++++++++++++++--- 5 files changed, 95 insertions(+), 27 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index 25c4023a7..a26fceb06 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -33,8 +33,8 @@ def wq_gemm_forward( event_def = torch.cuda.Event() event_def.record() event_def.wait(default_stream) - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, - 0, 0, False) + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, + False) event_def = torch.cuda.Event() event_def.record(default_stream) event_def.wait() @@ -56,6 +56,7 @@ def wq_gemm_forward( out = out.to(dtype=input_dtype) return out + class AwqLinearW4A16Impl(LinearW4A16Impl): def __init__(self, mod: nn.Module): @@ -74,8 +75,8 @@ def __init__(self, mod: nn.Module): def forward(self, x, all_reduce: bool = False): out_features = self.scales.size(1) out = wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, - self.w_bit, self.group_size, self.bias, - out_features) + self.w_bit, self.group_size, self.bias, + out_features) if all_reduce: dist.all_reduce(out) return out diff --git a/lmdeploy/pytorch/backends/cuda/norm.py b/lmdeploy/pytorch/backends/cuda/norm.py index 83edcd748..bee9ed73a 100644 --- a/lmdeploy/pytorch/backends/cuda/norm.py +++ b/lmdeploy/pytorch/backends/cuda/norm.py @@ -15,13 +15,12 @@ def __init__(self, weight: torch.Tensor, eps: float = 1e-6): self.eps = eps def forward(self, x: torch.Tensor, residual: torch.Tensor = None): - if residual is not None: - x = x + residual - residual = x - x = rms_norm(x, self.weight, self.eps) if residual is None: + x = rms_norm(x, self.weight, self.eps) return x - return x, residual + else: + x, residual = rms_norm(x, self.weight, self.eps, residual=residual) + return x, residual class TritonRMSNormBuilder(RMSNormBuilder): diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index 320d4eaa4..1a6ed490a 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -28,7 +28,8 @@ def forward(self, x, all_reduce: bool = False): if input_dtype != torch.float16: x = x.half() - out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size) + out = dequantize_gemm(self.qweight, self.qzeros, self.scales, + self.w_bit, self.group_size) out = torch.matmul(x, out) out = out + self.bias if self.bias is not None else out diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py index a0008fb85..21bf01a96 100644 --- a/lmdeploy/pytorch/backends/norm.py +++ b/lmdeploy/pytorch/backends/norm.py @@ -15,5 +15,5 @@ class RMSNormBuilder(ABC): @staticmethod @abstractmethod - def build(weight: torch.Tensor, eps: float = 1e-6): + def build(weight: torch.Tensor, eps: float = 1e-6, inplace: bool = False): raise NotImplementedError diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py index b02f4f7c7..0a171f762 100644 --- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py +++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py @@ -7,6 +7,17 @@ from .triton_utils import get_kernel_meta, wrap_jit_func +@triton.jit +def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr): + """compute rms norm.""" + xf = x.to(tl.float32) + + var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS) + out = xf / tl.sqrt(var + eps) + out = (w * out).to(x.dtype) + return out + + @wrap_jit_func(type_hint=dict( input=Tensor, weight=Tensor, @@ -28,20 +39,56 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr, x_ptr = input + prog_id * input_row_stride x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) - xf = x.to(tl.float32) + out = _compute_rms_norm(x, w, eps, N_COLS) - var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS) - out = xf / tl.sqrt(var + eps) - out = (w * out).to(x.dtype) + out_ptr = output + prog_id * input_row_stride + tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) + + +@wrap_jit_func(type_hint=dict( + input=Tensor, + weight=Tensor, + residual=Tensor, + output=Tensor, + out_residual=Tensor, + input_row_stride=int, + residual_row_stride=int, + eps=float, + N_COLS=torch.int32, + BLOCK_N=torch.int32, +)) +@triton.jit +def add_rms_norm_kernel(input, weight, residual, output, out_residual, + input_row_stride: tl.constexpr, + residual_row_stride: tl.constexpr, eps: tl.constexpr, + N_COLS: tl.constexpr, BLOCK_N: tl.constexpr): + """rms norm kernel.""" + prog_id = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + + w = tl.load(weight + offsets, mask=offsets < N_COLS) + + x_ptr = input + prog_id * input_row_stride + x = tl.load(x_ptr + offsets, mask=offsets < N_COLS) + + res_ptr = residual + prog_id * residual_row_stride + res = tl.load(res_ptr + offsets, mask=offsets < N_COLS) + + new_x = x + res + out = _compute_rms_norm(new_x, w, eps, N_COLS) out_ptr = output + prog_id * input_row_stride tl.store(out_ptr + offsets, out, mask=offsets < N_COLS) + out_res_ptr = out_residual + prog_id * residual_row_stride + tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS) def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6, - out: Tensor = None): + residual: Tensor = None, + out: Tensor = None, + out_residual: Tensor = None): """rms norm.""" feat_size = weight.shape[0] @@ -55,18 +102,38 @@ def rms_norm(hidden_states: Tensor, kernel_meta = get_kernel_meta(hidden_states) grid = (seq_len, ) - rms_norm_kernel[grid](hidden_states, - weight, - out, - input_row_stride=input_stride, - eps=eps, - N_COLS=feat_size, - BLOCK_N=BLOCK_N, - num_warps=4, - num_stages=2, - **kernel_meta) - return out + if residual is None: + rms_norm_kernel[grid](hidden_states, + weight, + out, + input_row_stride=input_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + **kernel_meta) + return out + else: + if out_residual is None: + out_residual = torch.empty_like(hidden_states) + + res_stride = residual.stride(-2) + add_rms_norm_kernel[grid](hidden_states, + weight, + residual, + out, + out_residual, + input_row_stride=input_stride, + residual_row_stride=res_stride, + eps=eps, + N_COLS=feat_size, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + **kernel_meta) + return out, out_residual if __name__ == '__main__': From 67aba315cc060c4b9ccbfc85235b16c8013ae90d Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Jul 2024 13:35:04 +0800 Subject: [PATCH 08/89] optimize step context --- lmdeploy/pytorch/backends/attention.py | 2 -- lmdeploy/pytorch/backends/cuda/attention.py | 2 +- .../pytorch/backends/cuda/layer_backend.py | 8 ++--- lmdeploy/pytorch/engine/engine.py | 3 -- lmdeploy/pytorch/engine/model_agent.py | 7 ++-- lmdeploy/pytorch/model_inputs.py | 33 +++---------------- 6 files changed, 12 insertions(+), 43 deletions(-) diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 2745f4fde..84e605c25 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -13,8 +13,6 @@ class AttentionMetadata: q_start_loc: torch.Tensor = None q_seqlens: torch.Tensor = None kv_seqlens: torch.Tensor = None - max_q_seqlen: int = 0 - max_kv_seqlen: int = 0 T = TypeVar('T', bound=AttentionMetadata) diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 92d06c761..5751cbb74 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -52,7 +52,7 @@ def forward( q_start_loc = attn_metadata.q_start_loc q_seqlens = attn_metadata.q_seqlens kv_seqlens = attn_metadata.kv_seqlens - max_q_seqlen = attn_metadata.max_q_seqlen + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) # fill kv cache self.fill_kv_cache( diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index f7b561629..707153afa 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -89,14 +89,14 @@ def get_v_block_shape( def update_step_context(cls, step_context): """update step context.""" attn_meta_cls = cls.get_attention_metadata_cls() + q_seqlens = step_context.q_seqlens + q_start_loc = q_seqlens.cumsum(0) - q_seqlens attn_meta = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, - q_start_loc=step_context.q_start_loc, - q_seqlens=step_context.q_seqlens, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, kv_seqlens=step_context.kv_seqlens, - max_q_seqlen=step_context.max_q_seqlen, - max_kv_seqlen=step_context.max_kv_seqlen, ) step_context.attn_meta = attn_meta diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 18f4325ad..b8ed66cf1 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -374,7 +374,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList, else: seq_length = self._seq_length_buf[:batch_size] max_q_seq_length = seq_length.max().item() - max_history_length = history_lengths.max().item() # TODO: get block offsets is slow when block_size = 1 block_offsets = self.scheduler.get_block_tables(messages) @@ -456,8 +455,6 @@ def __get_vlm_embeddings(): seq_length=seq_length, history_lengths=history_lengths, block_offsets=block_offsets, - max_q_seq_length=max_q_seq_length, - max_history_length=max_history_length, is_decoding=is_decoding, num_ignored_history=num_ignored_history, local_adapter_ids=local_adapter_ids, diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 482d45d49..4afbb827b 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -16,7 +16,7 @@ get_loralinear_info, update_lora_linears) from ..config import CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager -from ..model_inputs import ModelInputs, StepContext +from ..model_inputs import ModelInputs from ..models.patch import patch, update_model from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights @@ -140,13 +140,12 @@ def model_forward( with torch.cuda.stream(stream): # forward inputs = inputs.to_device('cuda') - context = StepContext.new( + ctx_mgr = model.ctx_mgr + context = ctx_mgr.build_context( inputs=inputs, world_size=world_size, kv_caches=cache_engine.gpu_cache, - cache_config=cache_engine.cache_config, ) - ctx_mgr = model.ctx_mgr with ctx_mgr.context(context): output = model( input_ids=inputs.input_ids, diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 00141ef35..b51df7b0d 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -8,7 +8,6 @@ from lmdeploy.pytorch.backends import get_backend from .adapter.adapter import SchedulerAdapter -from .config import CacheConfig @dataclass @@ -146,8 +145,6 @@ class ModelInputs: seq_length: torch.LongTensor history_lengths: torch.LongTensor block_offsets: torch.LongTensor - max_q_seq_length: int - max_history_length: int is_decoding: bool num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None @@ -159,7 +156,6 @@ def update(self, input_ids: torch.LongTensor): """update input ids.""" assert self.is_decoding self.history_lengths = self.history_lengths + 1 - self.max_history_length = self.max_history_length + 1 if input_ids.dim() == 1: input_ids = input_ids[None, :] self.input_ids = input_ids @@ -194,8 +190,6 @@ def split(self, split_size: int, block_size: int): seq_length=input_ids.new_tensor([end - start]), block_offsets=block_offsets, history_lengths=self.history_lengths + start, - max_q_seq_length=end - start, - max_history_length=self.max_history_length + start, is_decoding=self.is_decoding, num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, @@ -232,15 +226,11 @@ class StepContext: patched model might need extra information to perform inference. This dataclass provide these infos and tools. """ - inputs: ModelInputs block_offsets: torch.LongTensor position_ids: torch.LongTensor - q_start_loc: torch.LongTensor attention_mask: torch.LongTensor q_seqlens: torch.LongTensor kv_seqlens: torch.LongTensor - max_q_seqlen: int - max_kv_seqlen: int kv_caches: List is_decoding: bool world_size: int = 1 @@ -257,7 +247,6 @@ def new( inputs: ModelInputs, world_size: int = 1, kv_caches: List = None, - cache_config: CacheConfig = None, ): """build step context. @@ -267,8 +256,8 @@ def new( device (str): The device of the tensors. """ q_seqlens = inputs.seq_length - max_q_seqlen = inputs.max_q_seq_length history_seqlens = inputs.history_lengths + device = q_seqlens.device # for vlm input_embeddings, input_embedding_indexing = None, None @@ -277,16 +266,12 @@ def new( input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - batch_size = len(q_seqlens) - device = q_seqlens.device - # q_start_loc and kv_seqlens if inputs.is_decoding: - q_start_loc = torch.arange(0, batch_size, device=device) attention_mask = torch.ones_like(q_seqlens)[:, None] position_ids = history_seqlens.unsqueeze(-1) else: - q_start_loc = q_seqlens.cumsum(0) - q_seqlens + max_q_seqlen = q_seqlens.max().item() mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() position_ids = attention_mask.long().cumsum(-1) - 1 @@ -296,27 +281,19 @@ def new( position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] # seq_len + history_length kv_seqlens = q_seqlens + history_seqlens - max_kv_seqlen = max_q_seqlen + inputs.max_history_length - - window_size = getattr(cache_config, 'window_size', 0) - if window_size > 0: - kv_seqlens -= inputs.num_ignored_history + kv_seqlens -= inputs.num_ignored_history adapter_params = None if inputs.adapter_info is not None: adapter_params = inputs.adapter_info.split_by_targets() - ret = StepContext(inputs=inputs, - block_offsets=inputs.block_offsets, + ret = StepContext(block_offsets=inputs.block_offsets, position_ids=position_ids, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, attention_mask=attention_mask, - q_start_loc=q_start_loc, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, - max_q_seqlen=max_q_seqlen, - max_kv_seqlen=max_kv_seqlen, kv_caches=kv_caches, is_decoding=inputs.is_decoding, world_size=world_size, @@ -351,14 +328,12 @@ def build_context( inputs: ModelInputs, world_size: int = 1, kv_caches: List = None, - cache_config: CacheConfig = None, ): """build context.""" return StepContext.new( inputs, world_size, kv_caches, - cache_config, ) @contextmanager From 4312826a3609d848b612fd9b187a2d6c984ed73f Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Jul 2024 14:20:33 +0800 Subject: [PATCH 09/89] attn meta as input --- lmdeploy/pytorch/backends/cuda/layer_backend.py | 4 ++-- lmdeploy/pytorch/engine/model_agent.py | 2 +- lmdeploy/pytorch/model_inputs.py | 1 + lmdeploy/pytorch/models/llama.py | 12 +++++++++--- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index 707153afa..65016c398 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -91,7 +91,7 @@ def update_step_context(cls, step_context): attn_meta_cls = cls.get_attention_metadata_cls() q_seqlens = step_context.q_seqlens q_start_loc = q_seqlens.cumsum(0) - q_seqlens - attn_meta = attn_meta_cls( + attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, q_start_loc=q_start_loc, @@ -99,5 +99,5 @@ def update_step_context(cls, step_context): kv_seqlens=step_context.kv_seqlens, ) - step_context.attn_meta = attn_meta + step_context.attn_metadata = attn_metadata return step_context diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 4afbb827b..2f969977b 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -151,7 +151,7 @@ def model_forward( input_ids=inputs.input_ids, position_ids=context.position_ids, past_key_values=cache_engine.gpu_cache, - context=context, + attn_metadata=context.attn_metadata, ) return dict(logits=output) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index b51df7b0d..17037386a 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -238,6 +238,7 @@ class StepContext: adapter_params: Dict[str, AdapterInfo] = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None + attn_metadata = None _outputs: Dict = field(default_factory=dict) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index a02acc77d..a1e00551a 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -78,10 +78,10 @@ def forward( hidden_states: torch.Tensor, rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Rewrite of LlamaAttention.forward.""" - context = self.ctx_mgr.current_context() qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) qkv_states = qkv_states.flatten(0, -2) @@ -109,7 +109,7 @@ def forward( value_states, past_key_value[0], past_key_value[1], - context.attn_meta, + attn_metadata, inplace=True, ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) @@ -197,6 +197,7 @@ def forward( rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: if residual is None: @@ -211,6 +212,7 @@ def forward( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + attn_metadata=attn_metadata, ) # Fully Connected @@ -260,6 +262,7 @@ def forward( input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """Rewrite of LlamaModel.forward.""" @@ -289,6 +292,7 @@ def forward( rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, residual=residual, + attn_metadata=attn_metadata, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -309,6 +313,7 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, **kwargs, ): @@ -316,6 +321,7 @@ def forward( input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values, + attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) From ef092e5297aee0377e7ef594d1b30dbb4391c5ce Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jul 2024 14:20:28 +0800 Subject: [PATCH 10/89] add cuda graph support --- lmdeploy/messages.py | 1 + lmdeploy/pytorch/backends/base.py | 10 + lmdeploy/pytorch/backends/cuda/awq_modules.py | 2 +- .../pytorch/backends/cuda/graph_runner.py | 262 ++++++++++++++++++ .../pytorch/backends/cuda/layer_backend.py | 11 + lmdeploy/pytorch/backends/cuda/slora.py | 5 +- lmdeploy/pytorch/backends/graph_runner.py | 29 ++ lmdeploy/pytorch/config.py | 8 + lmdeploy/pytorch/engine/engine.py | 37 ++- lmdeploy/pytorch/engine/model_agent.py | 52 +++- lmdeploy/pytorch/model_inputs.py | 37 +-- lmdeploy/pytorch/models/llama.py | 43 ++- 12 files changed, 447 insertions(+), 50 deletions(-) create mode 100644 lmdeploy/pytorch/backends/cuda/graph_runner.py create mode 100644 lmdeploy/pytorch/backends/graph_runner.py diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index cd5f6caea..ec79e7906 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -212,6 +212,7 @@ class PytorchEngineConfig: thread_safe: bool = False enable_prefix_caching: bool = False device_type: str = 'cuda' + eager_mode: bool = False download_dir: str = None revision: str = None diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index 0bb8b94b9..e740474ec 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -7,6 +7,8 @@ import torch +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig + class LayerType(Enum): Attention = auto() @@ -62,3 +64,11 @@ def get_v_block_shape( @classmethod def update_step_context(cls, step_context): return step_context + + @staticmethod + def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + device: torch.device): + from .graph_runner import GraphRunner + return GraphRunner(model) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index a26fceb06..ac0d5714d 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -25,7 +25,7 @@ def wq_gemm_forward( if input_dtype != torch.float16: x = x.half() - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: # TODO: remove event wait if awq kernel set stream diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py new file mode 100644 index 000000000..57da57321 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -0,0 +1,262 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List + +import torch + +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.model_inputs import StepContext +from lmdeploy.utils import get_logger + +from ..graph_runner import GraphRunner + +logger = get_logger('lmdeploy') + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +class CUDASingleGraphRunner: + """cuda single graph runner.""" + + def __init__( + self, + model: torch.nn.Module, + max_batches: int, + max_tokens: int, + num_blocks: int, + is_decoding: bool, + device: torch.device, + ): + self.model = model + self.ctx_mgr = model.ctx_mgr + self.device = device + self.max_batches = max_batches + self.max_tokens = max_tokens + self.num_blocks = num_blocks + self.is_decoding = is_decoding + self._graph: torch.cuda.CUDAGraph = None + + self.input_buffers = dict() + self.output_buffers = dict() + self.make_buffers() + + def make_buffers(self): + """make cache step context.""" + max_batches = self.max_batches + max_tokens = self.max_tokens + num_blocks = self.num_blocks + device = self.device + is_decoding = self.is_decoding + + if is_decoding: + self.input_buffers['input_ids'] = torch.zeros(1, + max_batches, + dtype=torch.int64, + device=device) + self.input_buffers['position_ids'] = torch.zeros((1, max_batches), + dtype=torch.int64, + device=device) + else: + self.input_buffers['input_ids'] = torch.zeros(1, + max_tokens, + dtype=torch.int64, + device=device) + self.input_buffers['position_ids'] = torch.zeros((1, max_tokens), + dtype=torch.int64, + device=device) + + self.input_buffers['block_offsets'] = torch.zeros( + (max_batches, num_blocks), dtype=torch.int64, device=device) + self.input_buffers['q_start_loc'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) + self.input_buffers['q_seqlens'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) + self.input_buffers['kv_seqlens'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) + self.input_buffers['local_adapter_ids'] = torch.zeros( + max_batches, dtype=torch.int64, device=device) - 1 + + def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, + past_key_values: List, attn_metadata: Any, + inputs_embeds: torch.Tensor, **kwargs): + """fill input.""" + is_decoding = self.is_decoding + block_offsets = attn_metadata.block_offsets + q_start_loc = attn_metadata.q_start_loc + q_seqlens = attn_metadata.q_seqlens + kv_seqlens = attn_metadata.kv_seqlens + + batch_size, num_blocks = block_offsets.size() + num_tokens = input_ids.size(-1) + + # fill buffer + self.input_buffers['input_ids'][:, :num_tokens] = input_ids + self.input_buffers['position_ids'][:, :num_tokens] = position_ids + self.input_buffers[ + 'block_offsets'][:batch_size, :num_blocks] = block_offsets + self.input_buffers['q_seqlens'][:batch_size] = q_seqlens + self.input_buffers['kv_seqlens'][:batch_size] = kv_seqlens + self.input_buffers['q_start_loc'][:batch_size] = q_start_loc + if inputs_embeds is not None: + emb_size = inputs_embeds.size(-1) + if 'inputs_embeds' not in self.input_buffers: + max_num_tokens = self.input_buffers['input_ids'].size(-1) + self.input_buffers['inputs_embeds'] = inputs_embeds.new_zeros( + 1, max_num_tokens, emb_size) + self.input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds + + # create inputs + new_batch_size = next_power_of_2(batch_size) + attn_metadata.block_offsets = self.input_buffers[ + 'block_offsets'][:new_batch_size] + attn_metadata.q_start_loc = self.input_buffers[ + 'q_start_loc'][:new_batch_size] + attn_metadata.q_seqlens = self.input_buffers[ + 'q_seqlens'][:new_batch_size] + attn_metadata.kv_seqlens = self.input_buffers[ + 'kv_seqlens'][:new_batch_size] + + new_inputs = dict( + past_key_values=past_key_values, + attn_metadata=attn_metadata, + ) + + if is_decoding: + new_inputs['input_ids'] = self.input_buffers[ + 'input_ids'][:, :new_batch_size] + new_inputs['position_ids'] = self.input_buffers[ + 'position_ids'][:, :new_batch_size] + else: + new_inputs['input_ids'] = self.input_buffers['input_ids'] + new_inputs['position_ids'] = self.input_buffers['position_ids'] + + if inputs_embeds is not None: + if is_decoding: + new_inputs['inputs_embeds'] = self.input_buffers[ + 'inputs_embeds'][:, :new_batch_size] + else: + new_inputs['inputs_embeds'] = self.input_buffers[ + 'inputs_embeds'] + + new_inputs.update(kwargs) + return new_inputs + + def capture(self, **kwargs): + """capture graph.""" + padded_kwargs = self._fill_inputs(**kwargs) + + # warmup + output = self.model(**padded_kwargs) + + self._graph = torch.cuda.CUDAGraph() + current_stream = torch.cuda.current_stream() + with torch.cuda.graph(self._graph, stream=current_stream): + output = self.model(**padded_kwargs) + + self.output_buffers['logits'] = output + + def forward(self, **kwargs): + """forward.""" + num_tokens = kwargs['input_ids'].size(-1) + assert self._graph is not None + self._fill_inputs(**kwargs) + self._graph.replay() + + output = self.output_buffers['logits'][:, :num_tokens].clone() + return output + + +class CUDAGraphRunner(GraphRunner): + """cuda graph runner.""" + + def __init__(self, model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, backend_config: BackendConfig, + device: torch.device): + self.model = model + self.ctx_mgr = model.ctx_mgr + self.device = device + self.model_config = model_config + self.cache_config = cache_config + self.backend_config = backend_config + self.max_batches = cache_config.max_batches + self.max_tokens = cache_config.max_prefill_token_num + self.num_blocks = cache_config.num_gpu_blocks + + self.enable_graph = self.check_enable_graph() + + self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() + + def check_enable_graph(self): + """check enable graph.""" + if self.backend_config.eager_mode: + return False + + # do not support awq + hf_config = self.model_config.hf_config + quantization_config = getattr(hf_config, 'quantization_config', dict()) + quant_method = quantization_config.get('quant_method', None) + if quant_method == 'awq': + logger.warning('AWQ model does not support cuda graph yet.') + return False + return True + + def get_graph_key(self, input_ids: torch.Tensor, + position_ids: torch.Tensor, past_key_values: List, + attn_metadata: Any, inputs_embeds: torch.Tensor, + **kwargs): + """get graph key.""" + context = self.ctx_mgr.current_context() + is_decoding = context.is_decoding + block_offsets = attn_metadata.block_offsets + batch_size, _ = block_offsets.size() + new_batch_size = next_power_of_2(batch_size) + return (new_batch_size, is_decoding) + + def __call__(self, **kwargs): + """call.""" + if not self.enable_graph: + return self.model(**kwargs) + + graph_key = self.get_graph_key(**kwargs) + max_batches = graph_key[0] + is_decoding = graph_key[1] + if graph_key not in self._runner_map: + runner = CUDASingleGraphRunner(self.model, + max_batches=max_batches, + max_tokens=self.max_tokens, + num_blocks=self.num_blocks, + is_decoding=is_decoding, + device=self.device) + runner.capture(**kwargs) + self._runner_map[graph_key] = runner + else: + runner = self._runner_map[graph_key] + + output = runner.forward(**kwargs) + return output + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + return self.model.prepare_inputs_for_generation( + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + context=context, + ) diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index 65016c398..1ee3a7316 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -3,6 +3,7 @@ import torch +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.utils import get_logger from ..base import LayerType @@ -101,3 +102,13 @@ def update_step_context(cls, step_context): step_context.attn_metadata = attn_metadata return step_context + + @staticmethod + def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + device: torch.device): + """build graph runner.""" + from .graph_runner import CUDAGraphRunner + return CUDAGraphRunner(model, model_config, cache_config, + backend_config, device) diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py index 8c397fe49..c5e75a17a 100644 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -51,17 +51,18 @@ def _make_packed_lora_input(self, x, target_name: str, layer_idx: int): cache_len = k_cache.size(0) a_cache = k_cache.view(cache_len, -1) b_cache = v_cache.view(cache_len, -1) + max_q_seq_length = x.numel() // x.size(-1) return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), a_cache=a_cache, b_cache=b_cache, q_start_loc=context.q_start_loc, - q_seqlens=context.q_seq_length, + q_seqlens=context.q_seqlens, adapter_ids=context.local_adapter_ids, scaling=scaling, rank_offset=rank_offset, ranks=ranks, - max_seq_len=context.max_q_seq_length, + max_seq_len=max_q_seq_length, max_rank=max_rank, is_decoding=context.is_decoding) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py new file mode 100644 index 000000000..21ef31f10 --- /dev/null +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch + +from lmdeploy.pytorch.model_inputs import StepContext + + +class GraphRunner: + """graph runner.""" + + def __init__(self, model: torch.nn.Module, **kwargs): + self.model = model + + def __call__(self, **kwargs): + return self.model(**kwargs) + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + return self.model.prepare_inputs_for_generation( + past_key_values, + inputs_embeds, + context, + ) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9359374a9..de437b249 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -28,6 +28,13 @@ def _update_torch_dtype(config: 'ModelConfig', default: str = 'float16'): return config +@dataclass +class BackendConfig: + """backend config.""" + eager_mode: bool = True + device_type: str = 'cuda' + + @dataclass class SchedulerConfig: """Config of scheduler.""" @@ -44,6 +51,7 @@ class SchedulerConfig: class CacheConfig: """Config of key value cache.""" + max_batches: int block_size: int num_cpu_blocks: int num_gpu_blocks: int diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b8ed66cf1..68b31b73c 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -13,7 +13,7 @@ from ..adapter.adapter import AdapterManager, SchedulerAdapter from ..check_env import check_adapters, check_env, check_model -from ..config import CacheConfig, SchedulerConfig +from ..config import BackendConfig, CacheConfig, SchedulerConfig from ..devices import DeviceContext, get_device_manager from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, MessageStatus, SchedulerSequence) @@ -125,6 +125,7 @@ def __init__(self, # block_size = 1 to enable unified paging adapters = engine_config.adapters cache_config = CacheConfig( + max_batches=engine_config.max_batch_size, block_size=engine_config.block_size, num_cpu_blocks=engine_config.num_cpu_blocks, num_gpu_blocks=engine_config.num_gpu_blocks, @@ -138,10 +139,20 @@ def __init__(self, engine_config.revision) self.model_path = model_path + backend_config = BackendConfig( + eager_mode=engine_config.eager_mode, + device_type=engine_config.device_type, + ) + if (engine_config.adapters is not None + and not backend_config.eager_mode): + logger.warning('LoRA adapter require eager_mode=True') + backend_config.eager_mode = True + with get_device_manager().context(self.device_context): self.model_agent = AutoModelAgent.from_pretrained( model_path, cache_config=cache_config, + backend_config=backend_config, trust_remote_code=trust_remote_code, adapters=adapters, tp=tp) @@ -158,6 +169,7 @@ def __init__(self, self.scheduler_config = scheduler_config self.cache_config = cache_config + self.backend_config = backend_config self.stream = torch.cuda.Stream() self.req_manager = self._bind_request_manager() @@ -359,8 +371,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList, token_ids = [msg.token_ids for msg in messages] - meta = messages[0].meta - if isinstance(token_ids[0], int): token_ids = [token_ids] @@ -451,16 +461,17 @@ def __get_vlm_embeddings(): input_embedding_indexing=input_embedding_indexing, input_embedding_ranges=input_embedding_ranges) - return ModelInputs(input_ids=input_ids, - seq_length=seq_length, - history_lengths=history_lengths, - block_offsets=block_offsets, - is_decoding=is_decoding, - num_ignored_history=num_ignored_history, - local_adapter_ids=local_adapter_ids, - adapter_info=adapter_info, - vision_inputs=vision_embedding_inputs, - meta=meta) + return ModelInputs( + input_ids=input_ids, + seq_length=seq_length, + history_lengths=history_lengths, + block_offsets=block_offsets, + is_decoding=is_decoding, + num_ignored_history=num_ignored_history, + local_adapter_ids=local_adapter_ids, + adapter_info=adapter_info, + vision_inputs=vision_embedding_inputs, + ) def _batch_stopping_criteria(self, token_ids: torch.Tensor, stop_words: torch.Tensor, diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 2f969977b..a0a54e26e 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -14,7 +14,8 @@ from ..adapter.adapter import (AdapterWeightMap, get_indexed_lora_linears, get_loralinear_info, update_lora_linears) -from ..config import CacheConfig, ModelConfig +from ..backends import get_backend +from ..config import BackendConfig, CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager from ..model_inputs import ModelInputs from ..models.patch import patch, update_model @@ -147,12 +148,11 @@ def model_forward( kv_caches=cache_engine.gpu_cache, ) with ctx_mgr.context(context): - output = model( - input_ids=inputs.input_ids, - position_ids=context.position_ids, + input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, - attn_metadata=context.attn_metadata, + context=context, ) + output = model(**input_dict) return dict(logits=output) @@ -267,12 +267,14 @@ def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, def from_pretrained(cls, pretrained_model_name_or_path: str, cache_config: CacheConfig, + backend_config: BackendConfig, trust_remote_code: bool, adapters: Dict[str, str] = None, tp: int = 1): """from pretrained.""" return build_model_agent(pretrained_model_name_or_path, cache_config=cache_config, + backend_config=backend_config, trust_remote_code=trust_remote_code, adapters=adapters, tp=tp) @@ -294,19 +296,31 @@ def __init__(self, model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + backend_config: BackendConfig, adapters: Dict[str, str] = None, trust_remote_code: bool = True): super().__init__(model_config=model_config, cache_config=cache_config) torch_dtype = model_config.dtype + device = 'cuda' + self.backend_config = backend_config self.patched_model = self._build_model( model_path, torch_dtype=torch_dtype, adapters=adapters, + device=device, trust_remote_code=trust_remote_code) _update_cache_config(model_config, cache_config) + backend = get_backend() + self.patched_model = backend.build_graph_runner( + self.patched_model, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + device=device) + self.cache_engine = CacheEngine(cache_config, model_config) self.stream = torch.cuda.Stream() @@ -314,9 +328,9 @@ def _build_model(self, model_path: str, torch_dtype: torch.dtype, adapters: Dict[str, str] = None, + device: torch.device = 'cuda', trust_remote_code: bool = True): """build patched model.""" - device = 'cuda' with LoadNoInit(), warnings.catch_warnings(): warnings.simplefilter('ignore') hf_model = self.model_config.auto_model_cls.from_pretrained( @@ -443,6 +457,7 @@ def _tp_build_model( model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, trust_remote_code=True, @@ -514,6 +529,15 @@ def _broadcast_config(cache_config): cache_config, gpu_id=rank, world_size=world_size) + + backend = get_backend() + patched_model = backend.build_graph_runner( + patched_model, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + device='cuda') + cache_config = _broadcast_config(cache_config) cache_engine = CacheEngine(cache_config, model_config, @@ -580,6 +604,7 @@ def _tp_model_loop( model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, trust_remote_code=True, @@ -602,6 +627,7 @@ def _tp_model_loop( model_path, model_config, cache_config, + backend_config, adapters, world_size=world_size, trust_remote_code=trust_remote_code) @@ -706,6 +732,7 @@ def __init__(self, model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + backend_config: BackendConfig, world_size: int, adapters: Dict[str, str] = None, trust_remote_code: bool = True) -> None: @@ -728,10 +755,12 @@ def __signal_term_handler(sig, frame): self.mp_ctx = mp.get_context('spawn') self.world_size = world_size + self.backend_config = backend_config self._start_sub_process(model_path, model_config=model_config, cache_config=cache_config, + backend_config=backend_config, adapters=adapters, world_size=world_size, trust_remote_code=trust_remote_code) @@ -740,6 +769,7 @@ def __signal_term_handler(sig, frame): model_path=model_path, model_config=model_config, cache_config=cache_config, + backend_config=backend_config, adapters=adapters, world_size=world_size, trust_remote_code=trust_remote_code, @@ -750,7 +780,9 @@ def __signal_term_handler(sig, frame): self.stream = torch.cuda.Stream() def _start_sub_process(self, model_path: str, model_config: ModelConfig, - cache_config: CacheConfig, adapters: Dict[str, str], + cache_config: CacheConfig, + backend_config: BackendConfig, adapters: Dict[str, + str], world_size: int, trust_remote_code: bool): """Start tensor parallel sub process.""" port = _find_available_port() @@ -770,6 +802,7 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, (model_path, ), dict(model_config=model_config, cache_config=cache_config, + backend_config=backend_config, adapters=adapters, world_size=world_size, trust_remote_code=trust_remote_code), @@ -800,6 +833,7 @@ def _build_model( model_path: str, model_config: ModelConfig, cache_config: CacheConfig, + backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, trust_remote_code=True, @@ -812,6 +846,7 @@ def _build_model( model_path=model_path, model_config=model_config, cache_config=cache_config, + backend_config=backend_config, adapters=adapters, world_size=world_size, trust_remote_code=trust_remote_code, @@ -890,6 +925,7 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, def build_model_agent(model_path: str, cache_config: CacheConfig, + backend_config: BackendConfig, trust_remote_code: bool, adapters: Dict[str, str] = None, tp: int = 1): @@ -900,12 +936,14 @@ def build_model_agent(model_path: str, model_agent = BaseModelAgent(model_path, model_config=model_config, cache_config=cache_config, + backend_config=backend_config, adapters=adapters, trust_remote_code=trust_remote_code) else: model_agent = TPModelAgent(model_path, model_config=model_config, cache_config=cache_config, + backend_config=backend_config, world_size=tp, adapters=adapters, trust_remote_code=trust_remote_code) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 17037386a..63829f790 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -149,7 +149,6 @@ class ModelInputs: num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None adapter_info: AdapterInfo = None - meta: Any = None vision_inputs: VisionModelInputs = None def update(self, input_ids: torch.LongTensor): @@ -194,7 +193,6 @@ def split(self, split_size: int, block_size: int): num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, adapter_info=self.adapter_info, - meta=self.meta, vision_inputs=self.vision_inputs, ) ret.append(inp) @@ -226,11 +224,13 @@ class StepContext: patched model might need extra information to perform inference. This dataclass provide these infos and tools. """ + input_ids: torch.LongTensor block_offsets: torch.LongTensor position_ids: torch.LongTensor attention_mask: torch.LongTensor q_seqlens: torch.LongTensor kv_seqlens: torch.LongTensor + q_start_loc: torch.LongTensor kv_caches: List is_decoding: bool world_size: int = 1 @@ -238,7 +238,7 @@ class StepContext: adapter_params: Dict[str, AdapterInfo] = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None - attn_metadata = None + attn_metadata: Any = None _outputs: Dict = field(default_factory=dict) @@ -267,7 +267,7 @@ def new( input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - # q_start_loc and kv_seqlens + # kv_seqlens if inputs.is_decoding: attention_mask = torch.ones_like(q_seqlens)[:, None] position_ids = history_seqlens.unsqueeze(-1) @@ -277,6 +277,7 @@ def new( attention_mask = (mask_range < q_seqlens[:, None]).long() position_ids = attention_mask.long().cumsum(-1) - 1 position_ids += history_seqlens.unsqueeze(-1) + q_start_loc = q_seqlens.cumsum(0) - q_seqlens # position ids 1d position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] @@ -288,18 +289,22 @@ def new( if inputs.adapter_info is not None: adapter_params = inputs.adapter_info.split_by_targets() - ret = StepContext(block_offsets=inputs.block_offsets, - position_ids=position_ids, - input_embeddings=input_embeddings, - input_embedding_indexing=input_embedding_indexing, - attention_mask=attention_mask, - q_seqlens=q_seqlens, - kv_seqlens=kv_seqlens, - kv_caches=kv_caches, - is_decoding=inputs.is_decoding, - world_size=world_size, - local_adapter_ids=inputs.local_adapter_ids, - adapter_params=adapter_params) + ret = StepContext( + input_ids=inputs.input_ids, + block_offsets=inputs.block_offsets, + position_ids=position_ids, + input_embeddings=input_embeddings, + input_embedding_indexing=input_embedding_indexing, + attention_mask=attention_mask, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + q_start_loc=q_start_loc, + kv_caches=kv_caches, + is_decoding=inputs.is_decoding, + world_size=world_size, + local_adapter_ids=inputs.local_adapter_ids, + adapter_params=adapter_params, + ) ret = get_backend().update_step_context(ret) return ret diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index a1e00551a..0a8ea55d5 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -11,7 +11,7 @@ build_rotary_embedding) from lmdeploy.pytorch.layers.linear import (build_merged_colwise_linear, build_rowwise_linear) -from lmdeploy.pytorch.model_inputs import StepContextManager +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -266,19 +266,9 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """Rewrite of LlamaModel.forward.""" - context = self.ctx_mgr.current_context() - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) - hidden_states = inputs_embeds residual = None cos, sin = self.rotary_emb(hidden_states, position_ids) @@ -302,6 +292,8 @@ def forward( class LlamaForCausalLM(nn.Module): + support_cuda_graph = True + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): super().__init__() self.ctx_mgr = ctx_mgr @@ -328,3 +320,32 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() return logits + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + # get inputs from context + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) From 9fefda5a2fe44a383daeeff23ae9735dcde90154 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jul 2024 15:22:35 +0800 Subject: [PATCH 11/89] disable one of mha kernel --- lmdeploy/pytorch/kernels/cuda/pagedattention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 64eba4b87..5adecdeb4 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -795,7 +795,8 @@ def _get_block_d(Lk): SPLIT_K = 4 block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K) acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32) - if kv_group_num <= 2 or shared_kv: + if False: + # TODO: make sure if we need to keep this kernel BLOCK_DMODEL = triton.next_power_of_2(Lk) if shared_kv: BLOCK_DV = BLOCK_DMODEL From 839f0be646077f7c2c4b9db3278454dfe740f499 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Jul 2024 16:10:35 +0800 Subject: [PATCH 12/89] share graph pool --- .../pytorch/backends/cuda/graph_runner.py | 12 ++++++-- lmdeploy/pytorch/engine/model_agent.py | 29 +++++++++++++++++-- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 57da57321..f7e958c40 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import torch @@ -35,6 +35,7 @@ def __init__( max_tokens: int, num_blocks: int, is_decoding: bool, + pool: Tuple[int, int], device: torch.device, ): self.model = model @@ -44,6 +45,7 @@ def __init__( self.max_tokens = max_tokens self.num_blocks = num_blocks self.is_decoding = is_decoding + self.pool = pool self._graph: torch.cuda.CUDAGraph = None self.input_buffers = dict() @@ -157,13 +159,15 @@ def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, def capture(self, **kwargs): """capture graph.""" padded_kwargs = self._fill_inputs(**kwargs) + current_stream = torch.cuda.current_stream() # warmup output = self.model(**padded_kwargs) self._graph = torch.cuda.CUDAGraph() - current_stream = torch.cuda.current_stream() - with torch.cuda.graph(self._graph, stream=current_stream): + with torch.cuda.graph(self._graph, + pool=self.pool, + stream=current_stream): output = self.model(**padded_kwargs) self.output_buffers['logits'] = output @@ -197,6 +201,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, self.enable_graph = self.check_enable_graph() + self.graph_pool_handle = torch.cuda.graph_pool_handle() self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() def check_enable_graph(self): @@ -239,6 +244,7 @@ def __call__(self, **kwargs): max_tokens=self.max_tokens, num_blocks=self.num_blocks, is_decoding=is_decoding, + pool=self.graph_pool_handle, device=self.device) runner.capture(**kwargs) self._runner_map[graph_key] = runner diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index a0a54e26e..12c7b8dc6 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio +import atexit import os import warnings from datetime import timedelta @@ -553,7 +554,7 @@ def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): """get input tensor parallel.""" # broadcast meta info if rank != 0: - inputs = [None, None, None] + inputs = [None, None, None, None] with torch.cuda.stream(stream): dist.broadcast_object_list(inputs) @@ -639,9 +640,12 @@ def _tp_model_loop( weight_map=None) while True: - inputs, swap_in_map, swap_out_map = _broadcast_inputs( + inputs, swap_in_map, swap_out_map, exit_flag = _broadcast_inputs( rank, None, stream) + if exit_flag: + break + cache_swapping(cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) @@ -814,6 +818,8 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, _check_context_alive(self.mp_context) rank = 0 + # Please see [Note Exit By Sending Exit Flag] + atexit.register(_exit_by_sending_exit_flag, rank, self) try: dist.init_process_group('nccl', rank=rank, @@ -877,7 +883,8 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, """forward impl.""" _check_context_alive(self.mp_context) rank = 0 - _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map], + exit_flag = False + _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map, exit_flag], self.stream) cache_swapping(self.cache_engine, swap_in_map=swap_in_map, @@ -923,6 +930,22 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, return output +# [Note] Exit By Sending Exit Flag +# the registration of this function in atexit should be called +# after importing torch.multiprocessing +def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent): + # send exit_flag to all subprocess relying on all subprocess are alive + # and wait at _broadcast_inputs + exit_flag = True + _broadcast_inputs(rank, [None, None, None, exit_flag], agent.stream) + agent.stream.synchronize() + + # Tricky, extra sleep for subprocess releasing resources + import time + time.sleep(1) + return + + def build_model_agent(model_path: str, cache_config: CacheConfig, backend_config: BackendConfig, From 3345181b4fc11a245ceefbb52ae2d7b4dda49d30 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Jul 2024 16:39:27 +0800 Subject: [PATCH 13/89] del graph --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 4 ++++ lmdeploy/pytorch/engine/model_agent.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index f7e958c40..3adc09f45 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -182,6 +182,10 @@ def forward(self, **kwargs): output = self.output_buffers['logits'][:, :num_tokens].clone() return output + def __del__(self): + """del.""" + del self._graph + class CUDAGraphRunner(GraphRunner): """cuda graph runner.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 12c7b8dc6..bd95dc4b1 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -940,6 +940,8 @@ def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent): _broadcast_inputs(rank, [None, None, None, exit_flag], agent.stream) agent.stream.synchronize() + del agent.patched_model + # Tricky, extra sleep for subprocess releasing resources import time time.sleep(1) From 6746e676b4345d44b0d0355916f331bbf8db73b9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Jul 2024 20:45:17 +0800 Subject: [PATCH 14/89] update docstring --- lmdeploy/pytorch/backends/activation.py | 4 ++++ lmdeploy/pytorch/backends/apply_rotary_emb.py | 13 ++++++++++++- lmdeploy/pytorch/backends/attention.py | 8 ++++++++ lmdeploy/pytorch/backends/awq_modules.py | 4 ++++ lmdeploy/pytorch/backends/base.py | 12 ++++++++++++ .../pytorch/backends/cuda/apply_rotary_emb.py | 14 ++++++++++++-- lmdeploy/pytorch/backends/cuda/attention.py | 8 ++++++++ lmdeploy/pytorch/backends/cuda/awq_modules.py | 4 ++++ lmdeploy/pytorch/backends/cuda/layer_backend.py | 6 ++++++ .../backends/cuda/multinomial_sampling.py | 2 ++ lmdeploy/pytorch/backends/cuda/norm.py | 4 ++++ lmdeploy/pytorch/backends/cuda/qmodules.py | 8 ++++++++ lmdeploy/pytorch/backends/cuda/slora.py | 6 ++++++ lmdeploy/pytorch/backends/default/activation.py | 4 ++++ .../pytorch/backends/default/apply_rotary_emb.py | 14 ++++++++++++-- lmdeploy/pytorch/backends/default/awq_modules.py | 4 ++++ .../pytorch/backends/default/layer_backend.py | 3 +++ lmdeploy/pytorch/backends/default/linear.py | 4 ++++ .../backends/default/multinomial_sampling.py | 4 ++++ lmdeploy/pytorch/backends/default/norm.py | 4 ++++ .../pytorch/backends/default/rotary_embedding.py | 5 +++++ lmdeploy/pytorch/backends/graph_runner.py | 1 + lmdeploy/pytorch/backends/linear.py | 4 ++++ .../pytorch/backends/multinomial_sampling.py | 4 ++++ lmdeploy/pytorch/backends/norm.py | 4 ++++ lmdeploy/pytorch/backends/qmodules.py | 8 ++++++++ lmdeploy/pytorch/backends/rotary_embedding.py | 5 +++++ lmdeploy/pytorch/backends/slora.py | 5 +++++ lmdeploy/pytorch/layers/activation.py | 6 ++++-- lmdeploy/pytorch/layers/apply_rotary_emb.py | 16 ---------------- lmdeploy/pytorch/layers/attention.py | 4 ++++ lmdeploy/pytorch/layers/linear.py | 10 ++++++++++ lmdeploy/pytorch/layers/multinomial_sampling.py | 1 + lmdeploy/pytorch/layers/norm.py | 2 ++ lmdeploy/pytorch/layers/rotary_embedding.py | 12 ++++++++++-- 35 files changed, 192 insertions(+), 25 deletions(-) delete mode 100644 lmdeploy/pytorch/layers/apply_rotary_emb.py diff --git a/lmdeploy/pytorch/backends/activation.py b/lmdeploy/pytorch/backends/activation.py index b6478705f..327dcbfcc 100644 --- a/lmdeploy/pytorch/backends/activation.py +++ b/lmdeploy/pytorch/backends/activation.py @@ -3,15 +3,19 @@ class SiluAndMulImpl(ABC): + """silu + multiple residual fused implementation.""" @abstractmethod def forward(self, x): + """forward.""" raise NotImplementedError class SiluAndMulBuilder(ABC): + """silu and mul implementation builder.""" @staticmethod @abstractmethod def build(inplace: bool = False): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/apply_rotary_emb.py b/lmdeploy/pytorch/backends/apply_rotary_emb.py index 9e6955bfd..bc3c7b664 100644 --- a/lmdeploy/pytorch/backends/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/apply_rotary_emb.py @@ -1,17 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from torch import Tensor + class ApplyRotaryEmbImpl(ABC): + """Apply rotary embedding implementation.""" @abstractmethod - def forward(self, query, key, cos, sin, inplace: bool = True): + def forward(self, + query: Tensor, + key: Tensor, + cos: Tensor, + sin: Tensor, + inplace: bool = True): + """forward.""" raise NotImplementedError class ApplyRotaryEmbBuilder(ABC): + """Apply rotary embedding implementation builder.""" @staticmethod @abstractmethod def build(): + """build implementation.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 84e605c25..d7adf9a09 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -8,6 +8,7 @@ @dataclass class AttentionMetadata: + """Base Attention metadata.""" is_decoding: bool block_offsets: torch.Tensor q_start_loc: torch.Tensor = None @@ -19,6 +20,7 @@ class AttentionMetadata: class AttentionImpl(ABC, Generic[T]): + """Attention implementation.""" def __init__( self, @@ -29,6 +31,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, + logical_softcapping: float = None, **kwargs, ) -> None: if scale is None: @@ -47,6 +50,7 @@ def __init__( self.v_head_size = v_head_size self.alibi_scale = alibi_scale self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping @abstractmethod def forward( @@ -58,10 +62,12 @@ def forward( v_cache: torch.Tensor, attn_metadata: T, ) -> torch.Tensor: + """forward.""" raise NotImplementedError class AttentionBuilder(ABC, Generic[T]): + """Attention implementation builder.""" @staticmethod @abstractmethod @@ -73,6 +79,8 @@ def build( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, + logical_softcapping: float = None, **kwargs, ) -> AttentionImpl[T]: + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/awq_modules.py b/lmdeploy/pytorch/backends/awq_modules.py index b82af3bd7..c62a72d20 100644 --- a/lmdeploy/pytorch/backends/awq_modules.py +++ b/lmdeploy/pytorch/backends/awq_modules.py @@ -7,15 +7,19 @@ class LinearW4A16Impl(ABC, nn.Module): + """w4a16 linear implementation.""" @abstractmethod def forward(self, x, all_reduce: bool = False): + """forward.""" raise NotImplementedError class LinearW4A16Builder(ABC): + """w4a16 linear implementation builder.""" @staticmethod @abstractmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index e740474ec..66b27821d 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -11,6 +11,7 @@ class LayerType(Enum): + """Layer type enumerate.""" Attention = auto() Linear = auto() RotaryEmbedding = auto() @@ -25,20 +26,24 @@ class LayerType(Enum): class LayersBackend(ABC): + """Layer backend abstract.""" @staticmethod @abstractmethod def get_name() -> str: + """get backend name.""" raise NotImplementedError @classmethod @abstractmethod def get_layer_impl_builder(cls, layer_type: LayerType): + """get builder of given layer type.""" raise NotImplementedError @staticmethod @abstractmethod def get_attention_metadata_cls(): + """get attention metadata class.""" raise NotImplementedError @staticmethod @@ -49,6 +54,7 @@ def get_k_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get block shape of k.""" raise NotImplementedError @staticmethod @@ -59,10 +65,15 @@ def get_v_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get block shape of v.""" raise NotImplementedError @classmethod def update_step_context(cls, step_context): + """update StepContext for inference. + + attention meta should be built here. + """ return step_context @staticmethod @@ -70,5 +81,6 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig, backend_config: BackendConfig, device: torch.device): + """build graph runner.""" from .graph_runner import GraphRunner return GraphRunner(model) diff --git a/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py b/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py index fd90ba82b..74e72c4ee 100644 --- a/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from torch import Tensor from lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb @@ -7,8 +8,15 @@ class TritonApplyRotaryEmbImpl(ApplyRotaryEmbImpl): - - def forward(self, query, key, cos, sin, inplace: bool = True): + """Apply rotary embedding implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + cos: Tensor, + sin: Tensor, + inplace: bool = True): + """forward.""" if inplace: q_embed = query k_embed = key @@ -19,7 +27,9 @@ def forward(self, query, key, cos, sin, inplace: bool = True): class TritonApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): + """Apply rotary embedding implementation builder.""" @staticmethod def build(): + """build implementation.""" return TritonApplyRotaryEmbImpl() diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 5751cbb74..459e5459f 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -5,10 +5,12 @@ class TritonAttentionMetadata(AttentionMetadata): + """triton attention metadata.""" pass class TritonAttentionImpl(AttentionImpl[TritonAttentionMetadata]): + """triton attention implementation.""" def __init__( self, @@ -19,6 +21,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, + logical_softcapping: float = None, **kwargs, ): super().__init__( @@ -47,6 +50,7 @@ def forward( attn_metadata: TritonAttentionMetadata, inplace: bool = True, ) -> torch.Tensor: + """forward.""" block_offsets = attn_metadata.block_offsets q_start_loc = attn_metadata.q_start_loc @@ -92,6 +96,7 @@ def forward( class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]): + """triton attention builder.""" @staticmethod def build( @@ -102,8 +107,10 @@ def build( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, + logical_softcapping: float = None, **kwargs, ) -> TritonAttentionImpl: + """build.""" return TritonAttentionImpl(num_heads, head_size, scale=scale, @@ -111,4 +118,5 @@ def build( v_head_size=v_head_size, alibi_scale=alibi_scale, sliding_window=sliding_window, + logical_softcapping=logical_softcapping, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index ac0d5714d..913f0a40d 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -58,6 +58,7 @@ def wq_gemm_forward( class AwqLinearW4A16Impl(LinearW4A16Impl): + """awq kernel linear.""" def __init__(self, mod: nn.Module): super().__init__() @@ -73,6 +74,7 @@ def __init__(self, mod: nn.Module): self.out_features = mod.out_features def forward(self, x, all_reduce: bool = False): + """forward.""" out_features = self.scales.size(1) out = wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size, self.bias, @@ -83,7 +85,9 @@ def forward(self, x, all_reduce: bool = False): class AwqLinearW4A16Builder(LinearW4A16Builder): + """awq linear builder.""" @staticmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" return AwqLinearW4A16Impl(mod) diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index 1ee3a7316..f7cbe5fd8 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -13,13 +13,16 @@ class CudaLayersBackend(DefaultLayersBackend): + """cuda layer backend.""" @staticmethod def get_name() -> str: + """backend name.""" raise 'cuda' @classmethod def get_layer_impl_builder(cls, layer_type: LayerType): + """get cuda layer builder.""" if layer_type == LayerType.Attention: from .attention import TritonAttentionBuilder return TritonAttentionBuilder @@ -57,6 +60,7 @@ def get_layer_impl_builder(cls, layer_type: LayerType): @staticmethod def get_attention_metadata_cls(): + """get attention metadata class.""" from .attention import TritonAttentionMetadata return TritonAttentionMetadata @@ -67,6 +71,7 @@ def get_k_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get k block shape.""" return ( block_size, num_heads, @@ -80,6 +85,7 @@ def get_v_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get v block shape.""" return ( block_size, num_heads, diff --git a/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py b/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py index b91637ce7..9429e4cc5 100644 --- a/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py +++ b/lmdeploy/pytorch/backends/cuda/multinomial_sampling.py @@ -5,6 +5,8 @@ class TritonMultinomialSamplingBuilder(MultinomialSamplingBuilder): + """triton multinomial sampling builder.""" def build(): + """build.""" return multinomial_sampling diff --git a/lmdeploy/pytorch/backends/cuda/norm.py b/lmdeploy/pytorch/backends/cuda/norm.py index bee9ed73a..b35a2bef5 100644 --- a/lmdeploy/pytorch/backends/cuda/norm.py +++ b/lmdeploy/pytorch/backends/cuda/norm.py @@ -8,6 +8,7 @@ class TritonRMSNormImpl(RMSNormImpl, nn.Module): + """triton RMS norm implementation.""" def __init__(self, weight: torch.Tensor, eps: float = 1e-6): super().__init__() @@ -15,6 +16,7 @@ def __init__(self, weight: torch.Tensor, eps: float = 1e-6): self.eps = eps def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" if residual is None: x = rms_norm(x, self.weight, self.eps) return x @@ -24,7 +26,9 @@ def forward(self, x: torch.Tensor, residual: torch.Tensor = None): class TritonRMSNormBuilder(RMSNormBuilder): + """triton RMS norm implementation builder.""" @staticmethod def build(weight: torch.Tensor, eps: float = 1e-6): + """build.""" return TritonRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index c41fdf6e2..2dfb2bc2b 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -13,6 +13,7 @@ class TritonRMSNormW8A8Impl(RMSNormW8A8Impl, nn.Module): + """triton RMS norm w8a8 implementation api.""" def __init__(self, weight: torch.Tensor, eps: float = 1e-6): super().__init__() @@ -20,6 +21,7 @@ def __init__(self, weight: torch.Tensor, eps: float = 1e-6): self.eps = eps def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" if residual is not None: x = x + residual residual = x @@ -32,13 +34,16 @@ def forward(self, x: torch.Tensor, residual: torch.Tensor = None): class TritonRMSNormBuilder(RMSNormW8A8Builder): + """triton RMS norm w8a8 implementation builder.""" @staticmethod def build(weight: torch.Tensor, eps: float = 1e-6): + """build.""" return TritonRMSNormW8A8Impl(weight, eps) class TritonLinearW8A8Impl(LinearW8A8Impl, nn.Module): + """triton linear w8a8 implementation.""" def __init__(self, mod: nn.Module): super().__init__() @@ -47,6 +52,7 @@ def __init__(self, mod: nn.Module): self.bias = mod.bias def forward(self, x, all_reduce: bool = False): + """forward.""" if isinstance(x, torch.Tensor): x = x.contiguous() input_quant, input_scale = per_token_quant_int8(x, 1e-7) @@ -67,7 +73,9 @@ def forward(self, x, all_reduce: bool = False): class TritonLinearW8A8Builder(LinearW8A8Builder): + """triton linear w8a8 implementation builder.""" @staticmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" return TritonLinearW8A8Impl(mod) diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py index c5e75a17a..b1a0ba14d 100644 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -14,6 +14,7 @@ @dataclass class PackedLoRAInput: + """packed lora input.""" x: torch.Tensor a_cache: torch.Tensor b_cache: torch.Tensor @@ -29,6 +30,7 @@ class PackedLoRAInput: class TritonSLoRAImpl(SLoRAImpl): + """triton slora implementation.""" def __init__(self, adapter_info: AdapterInfo, @@ -39,6 +41,7 @@ def __init__(self, self.colwise = colwise def _make_packed_lora_input(self, x, target_name: str, layer_idx: int): + """make PackedLoRAInput.""" context = self.ctx_mgr.current_context() adapter_param = context.adapter_params[target_name] @@ -219,6 +222,7 @@ def forward(self, target_name: str, layer_idx: int, is_tp: bool = True): + """forward.""" lora_input = self._make_packed_lora_input(x, target_name, layer_idx) if self.colwise and is_tp: return self._forward_colwise(lora_input, base_output) @@ -227,9 +231,11 @@ def forward(self, class TritonSLoRABuilder(SLoRABuilder): + """triton slora layer builder.""" @staticmethod def build(adapter_info: AdapterInfo, ctx_mgr: StepContextManager, colwise: bool = True): + """build.""" return TritonSLoRAImpl(adapter_info, ctx_mgr, colwise) diff --git a/lmdeploy/pytorch/backends/default/activation.py b/lmdeploy/pytorch/backends/default/activation.py index 6d3c49af1..72c71f280 100644 --- a/lmdeploy/pytorch/backends/default/activation.py +++ b/lmdeploy/pytorch/backends/default/activation.py @@ -5,6 +5,7 @@ class DefaultSiluAndMulImpl(SiluAndMulImpl, nn.Module): + """silu + multiple residual fused implementation.""" def __init__(self, inplace: bool): super().__init__() @@ -12,12 +13,15 @@ def __init__(self, inplace: bool): self.silu = nn.SiLU(inplace) def forward(self, x): + """forward.""" gate, up = x.chunk(2, -1) return self.silu(gate) * up class DefaultSiluAndMulBuilder(SiluAndMulBuilder): + """silu and mul implementation builder.""" @staticmethod def build(inplace: bool = False): + """build.""" return DefaultSiluAndMulImpl(inplace) diff --git a/lmdeploy/pytorch/backends/default/apply_rotary_emb.py b/lmdeploy/pytorch/backends/default/apply_rotary_emb.py index fa41123f6..72723103b 100644 --- a/lmdeploy/pytorch/backends/default/apply_rotary_emb.py +++ b/lmdeploy/pytorch/backends/default/apply_rotary_emb.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from torch import Tensor from ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl @@ -16,8 +17,15 @@ def rotate_half(x): class DefaultApplyRotaryEmbImpl(ApplyRotaryEmbImpl): - - def forward(self, query, key, cos, sin, inplace: bool = True): + """Apply rotary embedding implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + cos: Tensor, + sin: Tensor, + inplace: bool = True): + """forward.""" unsqueeze_dim = -2 cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) @@ -37,7 +45,9 @@ def forward(self, query, key, cos, sin, inplace: bool = True): class DefaultApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder): + """Apply rotary embedding implementation builder.""" @staticmethod def build(): + """build implementation.""" return DefaultApplyRotaryEmbImpl() diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index 1a6ed490a..5709fe816 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -9,6 +9,7 @@ class DefaultLinearW4A16Impl(LinearW4A16Impl): + """w4a16 linear implementation.""" def __init__(self, mod: nn.Module): super().__init__() @@ -22,6 +23,7 @@ def __init__(self, mod: nn.Module): self.out_features = mod.out_features def forward(self, x, all_reduce: bool = False): + """forward.""" from awq.utils.packing_utils import dequantize_gemm out_shape = x.shape[:-1] + (self.out_features, ) input_dtype = x.dtype @@ -43,7 +45,9 @@ def forward(self, x, all_reduce: bool = False): class DefaultLinearW4A16Builder(LinearW4A16Builder): + """w4a16 linear implementation builder.""" @staticmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" return DefaultLinearW4A16Impl(mod) diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index dc8619979..8b445db27 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -14,6 +14,7 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: LayerType): + """get builder of given layer type.""" if layer_type == LayerType.Linear: from .linear import DefaultLinearBuilder return DefaultLinearBuilder @@ -45,6 +46,7 @@ def get_k_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get block shape of k.""" return ( block_size, num_heads, @@ -58,6 +60,7 @@ def get_v_block_shape( head_size: int, dtype: torch.dtype, ) -> Tuple[int, ...]: + """get block shape of v.""" return ( block_size, num_heads, diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index b3a8b84e1..9d69c96d3 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -8,12 +8,14 @@ class DefaultLinearImpl(LinearImpl): + """Linear implementation api.""" def __init__(self, mod: nn.Module): super().__init__() self.mod = mod def forward(self, x, all_reduce: bool = False): + """forward.""" out = self.mod(x) if all_reduce: dist.all_reduce(out) @@ -21,7 +23,9 @@ def forward(self, x, all_reduce: bool = False): class DefaultLinearBuilder(LinearBuilder): + """linear implementation builder.""" @staticmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" return DefaultLinearImpl(mod) diff --git a/lmdeploy/pytorch/backends/default/multinomial_sampling.py b/lmdeploy/pytorch/backends/default/multinomial_sampling.py index 5cffe6c63..bfd616100 100644 --- a/lmdeploy/pytorch/backends/default/multinomial_sampling.py +++ b/lmdeploy/pytorch/backends/default/multinomial_sampling.py @@ -6,11 +6,13 @@ class DefaultMultinomialSamplingImpl(MultinomialSamplingImpl): + """multinomial sampling implementation api.""" def forward(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, indices: torch.Tensor = None): + """forward.""" sampled_index = torch.multinomial(scores, num_samples=1, replacement=True) @@ -19,6 +21,8 @@ def forward(scores: torch.Tensor, class DefaultMultinomialSamplingBuilder(MultinomialSamplingBuilder): + """multinomial sampling implementation builder.""" def build(): + """build.""" return DefaultMultinomialSamplingImpl() diff --git a/lmdeploy/pytorch/backends/default/norm.py b/lmdeploy/pytorch/backends/default/norm.py index cb7f3b36f..5b57b526f 100644 --- a/lmdeploy/pytorch/backends/default/norm.py +++ b/lmdeploy/pytorch/backends/default/norm.py @@ -6,6 +6,7 @@ class DefaultRMSNormImpl(RMSNormImpl, nn.Module): + """RMS norm implementation api.""" def __init__(self, weight: torch.Tensor, eps: float = 1e-6): super().__init__() @@ -13,6 +14,7 @@ def __init__(self, weight: torch.Tensor, eps: float = 1e-6): self.eps = eps def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" input_dtype = x.dtype if residual is not None: x = x + residual @@ -27,7 +29,9 @@ def forward(self, x: torch.Tensor, residual: torch.Tensor = None): class DefaultRMSNormBuilder(RMSNormBuilder): + """RMS norm implementation builder.""" @staticmethod def build(weight: torch.Tensor, eps: float = 1e-6): + """build.""" return DefaultRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 734c2f600..eb8a52f11 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -7,6 +7,7 @@ class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): + """base rotary embedding.""" def __init__(self, dim: int, @@ -21,6 +22,7 @@ def __init__(self, self.register_buffer('inv_freq', inv_freq, persistent=False) def forward(self, x, position_ids): + """forward.""" # x: [bs, num_attention_heads, seq_len, head_size] if self.inv_freq.device != x.device: self.inv_freq = self.inv_freq.to(x.device) @@ -57,6 +59,7 @@ def __init__(self, self.max_position_embeddings = max_position_embeddings def forward(self, x, position_ids): + """forward.""" seq_len = torch.max(position_ids) + 1 if seq_len > self.max_position_embeddings: base = self.base * ((self.scaling_factor * seq_len / @@ -73,6 +76,7 @@ def forward(self, x, position_ids): class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): + """rotary embedding builder.""" @staticmethod def build( @@ -82,6 +86,7 @@ def build( scaling_factor: float = 1.0, emb_type: EmbeddingType = EmbeddingType.Default, ): + """build.""" if emb_type in (EmbeddingType.Default, EmbeddingType.LinearScaling): return RotaryEmbeddingImpl(dim, base, scaling_factor) elif emb_type == EmbeddingType.DynamicNTKScaling: diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 21ef31f10..14a471590 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -13,6 +13,7 @@ def __init__(self, model: torch.nn.Module, **kwargs): self.model = model def __call__(self, **kwargs): + """call graph runner forward.""" return self.model(**kwargs) def prepare_inputs_for_generation( diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index 5aa1710a7..6d5336815 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -7,15 +7,19 @@ class LinearImpl(ABC, nn.Module): + """Linear implementation api.""" @abstractmethod def forward(self, x, all_reduce: bool = False): + """forward.""" raise NotImplementedError class LinearBuilder(ABC): + """linear implementation builder.""" @staticmethod @abstractmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/multinomial_sampling.py b/lmdeploy/pytorch/backends/multinomial_sampling.py index 35dcf091a..5e43e28f2 100644 --- a/lmdeploy/pytorch/backends/multinomial_sampling.py +++ b/lmdeploy/pytorch/backends/multinomial_sampling.py @@ -5,18 +5,22 @@ class MultinomialSamplingImpl(ABC): + """multinomial sampling implementation api.""" @abstractmethod def forward(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, indices: torch.Tensor = None): + """forward.""" raise NotImplementedError class MultinomialSamplingBuilder(ABC): + """multinomial sampling implementation builder.""" @staticmethod @abstractmethod def build(): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py index 21bf01a96..7635fdc8c 100644 --- a/lmdeploy/pytorch/backends/norm.py +++ b/lmdeploy/pytorch/backends/norm.py @@ -5,15 +5,19 @@ class RMSNormImpl(ABC): + """RMS norm implementation api.""" @abstractmethod def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" raise NotImplementedError class RMSNormBuilder(ABC): + """RMS norm implementation builder.""" @staticmethod @abstractmethod def build(weight: torch.Tensor, eps: float = 1e-6, inplace: bool = False): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py index 19bfc4e29..8183a9142 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -8,30 +8,38 @@ class RMSNormW8A8Impl(ABC, nn.Module): + """RMS norm w8a8 implementation api.""" @abstractmethod def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" raise NotImplementedError class RMSNormW8A8Builder(ABC): + """RMS norm w8a8 implementation builder.""" @staticmethod @abstractmethod def build(weight: torch.Tensor, eps: float = 1e-6): + """build.""" raise NotImplementedError class LinearW8A8Impl(ABC, nn.Module): + """linear w8a8 implementation api.""" @abstractmethod def forward(self, x, all_reduce: bool = False): + """forward.""" raise NotImplementedError class LinearW8A8Builder(ABC): + """linear w8a8 implementation builder.""" @staticmethod @abstractmethod def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 5bf5cada4..46b095685 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -4,19 +4,23 @@ class EmbeddingType(Enum): + """rotary embedding type.""" Default = auto() LinearScaling = auto() DynamicNTKScaling = auto() class RotaryEmbeddingImpl(ABC): + """rotary embedding implementation api.""" @abstractmethod def forward(self, x, position_ids): + """forward.""" raise NotImplementedError class RotaryEmbeddingBuilder(ABC): + """rotary embedding implementation builder.""" @staticmethod @abstractmethod @@ -27,4 +31,5 @@ def build( scaling_factor: float = 1.0, emb_type: EmbeddingType = EmbeddingType.Default, ): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/slora.py b/lmdeploy/pytorch/backends/slora.py index 6efd3f960..fc7681fd5 100644 --- a/lmdeploy/pytorch/backends/slora.py +++ b/lmdeploy/pytorch/backends/slora.py @@ -10,6 +10,7 @@ @dataclass class AdapterInfo: + """Adapter information.""" r: dict lora_A: nn.ModuleDict lora_B: nn.ModuleDict @@ -34,6 +35,7 @@ def from_lora_linear(mod: nn.Module, base_slice: slice = None): class SLoRAImpl(ABC): + """slora implementation api.""" @abstractmethod def forward(self, @@ -42,14 +44,17 @@ def forward(self, target_name: str, layer_idx: int, is_tp: bool = True): + """forward.""" raise NotImplementedError class SLoRABuilder(ABC): + """slora implementation builder.""" @staticmethod @abstractmethod def build(adapter_info: AdapterInfo, ctx_mgr: StepContextManager, colwise: bool = True): + """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/layers/activation.py b/lmdeploy/pytorch/layers/activation.py index 9f7e67098..9e278d556 100644 --- a/lmdeploy/pytorch/layers/activation.py +++ b/lmdeploy/pytorch/layers/activation.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch import nn +from torch import Tensor, nn from ..backends import LayerType, get_backend class SiluAndMul(nn.Module): + """Silu and elementwise multiple.""" def __init__(self, inplace: bool = True): super().__init__() @@ -12,5 +13,6 @@ def __init__(self, inplace: bool = True): builder = backend.get_layer_impl_builder(LayerType.SiluAndMul) self.impl = builder.build(inplace) - def forward(self, x): + def forward(self, x: Tensor): + """forward.""" return self.impl.forward(x) diff --git a/lmdeploy/pytorch/layers/apply_rotary_emb.py b/lmdeploy/pytorch/layers/apply_rotary_emb.py deleted file mode 100644 index 2e600ff96..000000000 --- a/lmdeploy/pytorch/layers/apply_rotary_emb.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from torch import nn - -from ..backends import LayerType, get_backend - - -class ApplyRotaryEmb(nn.Module): - - def __init__(self): - super().__init__() - backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.ApplyRotaryEmb) - self.impl = builder.build() - - def forward(self, query, key, cos, sin, inplace: bool = True): - return self.impl.forward(query, key, cos, sin, inplace) diff --git a/lmdeploy/pytorch/layers/attention.py b/lmdeploy/pytorch/layers/attention.py index 7f82fbfd2..e14f5ac8f 100644 --- a/lmdeploy/pytorch/layers/attention.py +++ b/lmdeploy/pytorch/layers/attention.py @@ -7,6 +7,7 @@ class Attention(nn.Module): + """Attention layer.""" def __init__( self, @@ -17,6 +18,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, + logical_softcapping: float = None, **kwargs, ): super().__init__() @@ -32,6 +34,7 @@ def __init__( v_head_size, alibi_scale, sliding_window, + logical_softcapping, **kwargs, ) @@ -45,6 +48,7 @@ def forward( attn_metadata: AttentionMetadata, inplace: bool = True, ) -> torch.Tensor: + """forward.""" return self.impl.forward( query, key, diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index 7edc06f1b..64006580e 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -53,6 +53,7 @@ def _get_world_rank(): class SLoRA(nn.Module): + """SLoRA layer.""" def __init__(self, adapter_info: AdapterInfo, @@ -68,11 +69,13 @@ def __init__(self, self.is_tp = is_tp def forward(self, x, base_output=None): + """forward of loraA@loraB.""" return self.impl.forward(x, base_output, self.target_name, self.layer_idx, self.is_tp) class AwqLinear(nn.Module): + """w4a16 linear.""" def __init__(self, mod: nn.Module, @@ -95,6 +98,7 @@ def __init__(self, self.colwise = colwise def forward(self, x): + """w4a16 forward.""" if self.lora_adapters is None: is_tp = False if self.colwise else self.is_tp return self.impl.forward(x, is_tp) @@ -109,6 +113,7 @@ def forward(self, x): class W8A8Linear(nn.Module): + """w8a8 linear.""" def __init__(self, mod: nn.Module, @@ -123,11 +128,13 @@ def __init__(self, self.colwise = colwise def forward(self, x): + """forward of w8a8.""" is_tp = False if self.colwise else self.is_tp return self.impl.forward(x, is_tp) class BaseLinear(nn.Module): + """linear layer.""" def __init__(self, mod: nn.Module, @@ -149,6 +156,7 @@ def __init__(self, self.colwise = colwise def forward(self, x): + """forward of linear layer.""" if self.lora_adapters is None: is_tp = False if self.colwise else self.is_tp return self.impl.forward(x, is_tp) @@ -324,6 +332,7 @@ def build_colwise_linear(mod: nn.Module, adapter_infos: List[AdapterInfo] = None, ctx_mgr: Any = None, is_tp: bool = False) -> nn.Module: + """build columnwise parallel linear layer.""" return build_linear(mod, adapter_infos, ctx_mgr, colwise=True, is_tp=is_tp) @@ -331,6 +340,7 @@ def build_rowwise_linear(mod: nn.Module, adapter_infos: List[AdapterInfo] = None, ctx_mgr: Any = None, is_tp: bool = False) -> nn.Module: + """build rowwise parallel linear layer.""" return build_linear(mod, adapter_infos, ctx_mgr, diff --git a/lmdeploy/pytorch/layers/multinomial_sampling.py b/lmdeploy/pytorch/layers/multinomial_sampling.py index dd00e870c..562636bee 100644 --- a/lmdeploy/pytorch/layers/multinomial_sampling.py +++ b/lmdeploy/pytorch/layers/multinomial_sampling.py @@ -8,6 +8,7 @@ def multinomial_sampling(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, indices: torch.Tensor = None): + """multinomial sampling op.""" impl_builder = get_backend().get_layer_impl_builder( LayerType.MultinomialSampling) return impl_builder.build()(scores, seeds, offsets, indices) diff --git a/lmdeploy/pytorch/layers/norm.py b/lmdeploy/pytorch/layers/norm.py index 1767794f1..7df0a810a 100644 --- a/lmdeploy/pytorch/layers/norm.py +++ b/lmdeploy/pytorch/layers/norm.py @@ -6,6 +6,7 @@ class RMSNorm(nn.Module): + """RMS Norm with add residual.""" def __init__(self, weight: torch.Tensor, @@ -20,4 +21,5 @@ def __init__(self, self.impl = builder.build(weight, eps) def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" return self.impl.forward(x, residual) diff --git a/lmdeploy/pytorch/layers/rotary_embedding.py b/lmdeploy/pytorch/layers/rotary_embedding.py index 1db92cde3..d3614e6a5 100644 --- a/lmdeploy/pytorch/layers/rotary_embedding.py +++ b/lmdeploy/pytorch/layers/rotary_embedding.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch import nn +from torch import Tensor, nn from ..backends import LayerType, get_backend from ..backends.rotary_embedding import EmbeddingType @@ -11,6 +11,7 @@ def build_rotary_embedding( base: int = 10000, scaling_factor: float = 1.0, emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: + """build rotary embedding op.""" backend = get_backend() builder = backend.get_layer_impl_builder(LayerType.RotaryEmbedding) @@ -19,6 +20,7 @@ def build_rotary_embedding( class ApplyRotaryEmb(nn.Module): + """apply rotary embedding.""" def __init__(self): super().__init__() @@ -26,5 +28,11 @@ def __init__(self): builder = backend.get_layer_impl_builder(LayerType.ApplyRotaryEmb) self.impl = builder.build() - def forward(self, query, key, cos, sin, inplace: bool = True): + def forward(self, + query: Tensor, + key: Tensor, + cos: Tensor, + sin: Tensor, + inplace: bool = True): + """forward.""" return self.impl.forward(query, key, cos, sin, inplace) From e5a790b75e633692e62501fdb122f2a2c2a1df47 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 30 Jul 2024 10:49:48 +0800 Subject: [PATCH 15/89] awq cudagraph --- lmdeploy/pytorch/backends/cuda/awq_modules.py | 12 +- .../pytorch/backends/cuda/graph_runner.py | 14 +-- lmdeploy/pytorch/kernels/cuda/awq_kernels.py | 106 ++++++++++++++++++ 3 files changed, 115 insertions(+), 17 deletions(-) create mode 100644 lmdeploy/pytorch/kernels/cuda/awq_kernels.py diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index 913f0a40d..367622a6e 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -20,6 +20,7 @@ def wq_gemm_forward( ): """wq gemm forward.""" from awq.modules.linear.gemm import awq_ext + from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_dequantize_weights out_shape = x.shape[:-1] + (out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: @@ -28,16 +29,7 @@ def wq_gemm_forward( FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: - # TODO: remove event wait if awq kernel set stream - default_stream = torch.cuda.default_stream() - event_def = torch.cuda.Event() - event_def.record() - event_def.wait(default_stream) - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, - False) - event_def = torch.cuda.Event() - event_def.record(default_stream) - event_def.wait() + out = awq_dequantize_weights(qweight, scales, qzeros) out = torch.matmul(x, out) else: x = x.flatten(0, -2) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 3adc09f45..c59d09a8a 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -213,13 +213,13 @@ def check_enable_graph(self): if self.backend_config.eager_mode: return False - # do not support awq - hf_config = self.model_config.hf_config - quantization_config = getattr(hf_config, 'quantization_config', dict()) - quant_method = quantization_config.get('quant_method', None) - if quant_method == 'awq': - logger.warning('AWQ model does not support cuda graph yet.') - return False + # TODO: should we enable cudagraph awq now? + # hf_config = self.model_config.hf_config + # quantization_config = getattr(hf_config, 'quantization_config', dict()) + # quant_method = quantization_config.get('quant_method', None) + # if quant_method == 'awq': + # logger.warning('AWQ model does not support cuda graph yet.') + # return False return True def get_graph_key(self, input_ids: torch.Tensor, diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py new file mode 100644 index 000000000..5ecd93746 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py @@ -0,0 +1,106 @@ +import triton +import triton.language as tl +from torch import Tensor +import torch +from .triton_utils import get_kernel_meta, wrap_jit_func + + +@triton.jit +def _get_unpacked_order(offs_n, elem_per_int): + """get unpacked order.""" + origin_order = offs_n % elem_per_int + unpacked_order = (origin_order % 2) * 4 + origin_order // 2 + return unpacked_order + +@triton.jit +def _unpack_weight(weight, order): + """unpack weight.""" + weight = weight >> (order * 4) + weight = weight & 0xf + return weight + + +@wrap_jit_func +@triton.jit +def _dequantize_weights_kernel( + QWeight, Scales, Qzeros, Out, + in_size: tl.constexpr, + out_size: tl.constexpr, + group_size: tl.constexpr, + elem_per_int: tl.constexpr, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + stride_sk: tl.constexpr, + stride_sn: tl.constexpr, + stride_zk: tl.constexpr, + stride_zn: tl.constexpr, + stride_ok: tl.constexpr, + stride_on: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """dequantize weight kernel.""" + k_block_id = tl.program_id(0) + n_block_id = tl.program_id(1) + + offs_k = tl.arange(0, BLOCK_K) + k_block_id * BLOCK_K + offs_n = tl.arange(0, BLOCK_N) + n_block_id * BLOCK_N + offs_wk = offs_k + offs_wn = offs_n // elem_per_int + offs_sk = offs_k // group_size + offs_sn = offs_n + offs_zk = offs_sk + offs_zn = offs_wn + mask_k = offs_k < in_size + mask_n = offs_n < out_size + mask = mask_k[:, None] and mask_n[None, :] + unpacked_order = _get_unpacked_order(offs_n, elem_per_int) + + qw_ptr = QWeight + offs_wk[:, None] * stride_wk + offs_wn[None, :] * stride_wn + s_ptr = Scales + offs_sk[:, None] * stride_sk + offs_sn[None, :] * stride_sn + qz_ptr = Qzeros + offs_zk[:, None] * stride_zk + offs_zn[None, :] * stride_zn + + out_dtype = Out.dtype.element_ty + qw = tl.load(qw_ptr, mask=mask) + s = tl.load(s_ptr, mask=mask).to(out_dtype) + qz = tl.load(qz_ptr, mask=mask) + + # unpack w, z + w = _unpack_weight(qw, unpacked_order) + z = _unpack_weight(qz, unpacked_order) + + o = (w - z).to(s.dtype) * s + o_ptr = Out + offs_k[:, None] * stride_ok + offs_n[None, :] * stride_on + tl.store(o_ptr, o, mask=mask) + + +def awq_dequantize_weights(qweight: Tensor, scales: Tensor, qzeros: Tensor): + """dequantize weights of awq.""" + in_size = qweight.size(0) + out_size = scales.size(1) + group_size = in_size // qzeros.size(0) + elem_per_int = out_size // qzeros.size(1) + + output = scales.new_empty(in_size, out_size, dtype=scales.dtype) + + BLOCK_K = 128 + BLOCK_N = 64 + kernel_meta = get_kernel_meta(qweight) + grid = (triton.cdiv(in_size, BLOCK_K), triton.cdiv(out_size, BLOCK_N)) + _dequantize_weights_kernel[grid]( + qweight, scales, qzeros, output, + in_size, out_size, group_size, elem_per_int, + qweight.stride(0), + qweight.stride(1), + scales.stride(0), + scales.stride(1), + qzeros.stride(0), + qzeros.stride(1), + output.stride(0), + output.stride(1), + BLOCK_K=BLOCK_K, + BLOCK_N=BLOCK_N, + **kernel_meta, + ) + + return output From fbc0912669ed135e7d150ace52a63061ade5d55b Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 30 Jul 2024 11:47:22 +0800 Subject: [PATCH 16/89] merge main --- .github/ISSUE_TEMPLATE/1-bug-report.yml | 2 + .github/scripts/action_tools.py | 16 +- .github/scripts/eval_opencompass_config.py | 677 ++---- .github/scripts/test_triton_server.py | 73 - .github/workflows/benchmark.yml | 274 +-- .github/workflows/cuda11.8-whl-release.yml | 9 +- .github/workflows/daily_ete_test.yml | 215 +- .github/workflows/evaluate.yml | 20 +- .github/workflows/pypi.yml | 9 +- .github/workflows/windows-x64-gpu.yml | 9 +- CMakeLists.txt | 23 +- MANIFEST.in | 2 - README.md | 20 +- README_zh-CN.md | 18 +- autotest/config.yaml | 34 +- .../interface/pipeline/test_pipeline_func.py | 64 + .../pipeline/test_pipeline_longtext_func.py | 12 +- .../restful/test_restful_chat_func.py | 198 ++ .../restful/test_restful_completions_v1.py | 33 + .../chat/test_command_chat_hf_pytorch.py | 2 +- .../chat/test_command_chat_hf_turbomind.py | 2 +- .../pipeline/test_pipeline_chat_turbomind.py | 4 + .../restful/test_restful_chat_hf_turbomind.py | 14 +- autotest/utils/get_run_config.py | 5 +- autotest/utils/pipeline_chat.py | 38 +- autotest/utils/restful_return_check.py | 78 +- autotest/utils/run_client_chat.py | 15 +- autotest/utils/run_restful_chat.py | 9 +- benchmark/profile_generation.py | 60 +- benchmark/profile_serving.py | 260 --- benchmark/profile_throughput.py | 6 +- builder/windows/setup_cuda.ps1 | 103 + docs/en/benchmark/profile_triton_server.md | 58 - docs/en/get_started.md | 2 +- docs/en/index.rst | 1 - docs/en/inference/turbomind.md | 2 - docs/en/multi_modal/cogvlm.md | 2 +- docs/en/multi_modal/xcomposer2d5.md | 160 ++ docs/en/serving/api_server_tools.md | 100 +- docs/en/supported_models/supported_models.md | 122 +- docs/zh_cn/benchmark/profile_triton_server.md | 61 - docs/zh_cn/get_started.md | 2 +- docs/zh_cn/index.rst | 1 - docs/zh_cn/inference/turbomind.md | 2 - docs/zh_cn/multi_modal/cogvlm.md | 2 +- docs/zh_cn/multi_modal/xcomposer2d5.md | 160 ++ docs/zh_cn/serving/api_server_tools.md | 100 +- .../supported_models/supported_models.md | 122 +- lmdeploy/cli/serve.py | 25 - lmdeploy/lite/apis/auto_awq.py | 4 +- lmdeploy/lite/apis/calibrate.py | 4 + lmdeploy/lite/apis/smooth_quant.py | 2 + lmdeploy/lite/quantization/awq.py | 9 + lmdeploy/lite/quantization/calibration.py | 65 +- lmdeploy/lite/utils/batch_split.py | 2 + lmdeploy/messages.py | 2 +- lmdeploy/model.py | 299 +-- lmdeploy/pytorch/backends/cuda/awq_modules.py | 4 +- .../pytorch/backends/cuda/graph_runner.py | 3 +- lmdeploy/pytorch/backends/cuda/slora.py | 1 - lmdeploy/pytorch/backends/graph_runner.py | 4 + lmdeploy/pytorch/chat.py | 10 +- lmdeploy/pytorch/check_env/__init__.py | 22 + lmdeploy/pytorch/config.py | 9 + lmdeploy/pytorch/engine/model_agent.py | 35 +- lmdeploy/pytorch/kernels/cuda/awq_kernels.py | 32 +- .../pytorch/kernels/cuda/pagedattention.py | 22 + lmdeploy/pytorch/layers/linear.py | 4 +- lmdeploy/pytorch/models/awq_modules.py | 67 + lmdeploy/pytorch/models/gemma.py | 3 + lmdeploy/pytorch/models/module_map.py | 8 + lmdeploy/pytorch/models/peft.py | 16 +- lmdeploy/pytorch/weight_loader/dist_utils.py | 247 ++- lmdeploy/serve/async_engine.py | 27 + lmdeploy/serve/client.py | 69 - lmdeploy/serve/gradio/app.py | 24 +- .../serve/gradio/triton_server_backend.py | 161 -- lmdeploy/serve/openai/api_client.py | 2 + lmdeploy/serve/openai/api_server.py | 46 +- lmdeploy/serve/turbomind/chatbot.py | 687 ------ lmdeploy/serve/turbomind/service_docker_up.sh | 87 - .../triton_models/interactive/1/placeholder | 0 .../triton_models/interactive/config.pbtxt | 281 --- .../triton_models/postprocessing/1/model.py | 136 -- .../triton_models/postprocessing/config.pbtxt | 41 - .../triton_models/preprocessing/1/model.py | 151 -- .../triton_models/preprocessing/config.pbtxt | 37 - .../triton_models/tokenizer/placeholder | 0 .../triton_models/weights/config.ini | 0 lmdeploy/serve/turbomind/utils.py | 103 - lmdeploy/turbomind/chat.py | 9 +- lmdeploy/turbomind/deploy/converter.py | 101 +- .../turbomind/deploy/source_model/__init__.py | 1 + .../turbomind/deploy/source_model/glm4_awq.py | 93 + .../turbomind/deploy/source_model/internvl.py | 5 + .../turbomind/deploy/source_model/llama.py | 40 +- .../deploy/source_model/xcomposer2.py | 10 +- .../turbomind/deploy/target_model/base.py | 4 + .../turbomind/deploy/target_model/plora_w4.py | 2 + lmdeploy/turbomind/deploy/target_model/w4.py | 21 +- lmdeploy/turbomind/supported_models.py | 5 +- lmdeploy/turbomind/turbomind.py | 73 +- lmdeploy/version.py | 2 +- lmdeploy/vl/model/xcomposer2.py | 68 +- lmdeploy/vl/templates.py | 2 + lmdeploy/vl/tools/__init__.py | 1 + lmdeploy/vl/tools/merge_xcomposer2d5_task.py | 65 + requirements/runtime.txt | 6 +- src/turbomind/kernels/attention/array_ops.h | 28 +- .../kernels/attention/attention_params.h | 4 + .../kernels/attention/attention_universal.h | 3 + .../kernels/attention/kv_cache_utils_v2.cu | 42 +- .../kernels/attention/kv_cache_utils_v2.h | 12 + .../kernels/attention/test_attention.cu | 9 + src/turbomind/models/llama/llama_params.h | 16 +- .../models/llama/unified_attention_layer.cc | 9 +- src/turbomind/triton_backend/CMakeLists.txt | 247 --- .../triton_backend/libfastertransformer.cc | 1914 ----------------- .../libtriton_fastertransformer.ldscript | 30 - .../triton_backend/llama/LlamaTritonModel.cc | 5 + .../llama/LlamaTritonModelInstance.cc | 1 - src/turbomind/triton_backend/triton_utils.hpp | 57 - src/turbomind/utils/allocator.h | 2 + tests/test_lmdeploy/test_model.py | 86 +- 124 files changed, 2791 insertions(+), 6134 deletions(-) delete mode 100644 .github/scripts/test_triton_server.py delete mode 100644 benchmark/profile_serving.py create mode 100644 builder/windows/setup_cuda.ps1 delete mode 100644 docs/en/benchmark/profile_triton_server.md create mode 100644 docs/en/multi_modal/xcomposer2d5.md delete mode 100644 docs/zh_cn/benchmark/profile_triton_server.md create mode 100644 docs/zh_cn/multi_modal/xcomposer2d5.md create mode 100644 lmdeploy/pytorch/models/awq_modules.py delete mode 100644 lmdeploy/serve/client.py delete mode 100644 lmdeploy/serve/gradio/triton_server_backend.py delete mode 100644 lmdeploy/serve/turbomind/chatbot.py delete mode 100644 lmdeploy/serve/turbomind/service_docker_up.sh delete mode 100644 lmdeploy/serve/turbomind/triton_models/interactive/1/placeholder delete mode 100644 lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt delete mode 100644 lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py delete mode 100644 lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt delete mode 100644 lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py delete mode 100644 lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt delete mode 100644 lmdeploy/serve/turbomind/triton_models/tokenizer/placeholder delete mode 100644 lmdeploy/serve/turbomind/triton_models/weights/config.ini delete mode 100644 lmdeploy/serve/turbomind/utils.py create mode 100644 lmdeploy/turbomind/deploy/source_model/glm4_awq.py create mode 100644 lmdeploy/vl/tools/__init__.py create mode 100644 lmdeploy/vl/tools/merge_xcomposer2d5_task.py delete mode 100644 src/turbomind/triton_backend/libfastertransformer.cc delete mode 100644 src/turbomind/triton_backend/libtriton_fastertransformer.ldscript delete mode 100644 src/turbomind/triton_backend/triton_utils.hpp diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml index d9e695673..3095e5b42 100644 --- a/.github/ISSUE_TEMPLATE/1-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -10,6 +10,7 @@ body: options: - label: 1. I have searched related issues but cannot get the expected help. - label: 2. The bug has not been fixed in the latest version. + - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. - type: textarea attributes: label: Describe the bug @@ -31,6 +32,7 @@ body: description: | 1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here. 2. You may add addition that may be helpful for locating the problem, such as + - Which **model** are you using? - How you installed PyTorch \[e.g., pip, conda, source\] - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) placeholder: Environment here. diff --git a/.github/scripts/action_tools.py b/.github/scripts/action_tools.py index 49226f02f..cc6519330 100644 --- a/.github/scripts/action_tools.py +++ b/.github/scripts/action_tools.py @@ -117,7 +117,7 @@ def evaluate(models: List[str], datasets: List[str], workspace: str): print(f'Start evaluating {idx+1}/{num_model} {ori_model} ...') model = ori_model.lower() model_, precision = model.rsplit('_', 1) - do_lite = precision in ['w4a16', 'w4kv8', 'w8a8'] + do_lite = precision in ['4bits', 'kvint4', 'kvint8'] if do_lite: model = model_ engine_type, model_ = model.split('_', 1) @@ -170,10 +170,12 @@ def evaluate(models: List[str], datasets: List[str], workspace: str): if ret != 0: continue csv_files = glob.glob(f'{work_dir}/*/summary/summary_*.csv') - if len(csv_files) != 1: + + if len(csv_files) < 1: logging.error(f'Did not find summary csv file {csv_files}') continue - csv_file = csv_files[0] + else: + csv_file = max(csv_files, key=os.path.getctime) # print csv_txt to screen csv_txt = csv_file.replace('.csv', '.txt') if os.path.exists(csv_txt): @@ -217,14 +219,14 @@ def evaluate(models: List[str], datasets: List[str], workspace: str): header = ','.join(['Model', 'Engine', 'Precision'] + dataset_names) f.write(header + '\n') - f.write(row + '\n') if hf_res_row: f.write(hf_res_row + '\n') + f.write(row + '\n') else: with open(output_csv, 'a') as f: - f.write(row + '\n') if hf_res_row: f.write(hf_res_row + '\n') + f.write(row + '\n') # write to github action summary _append_summary('## Evaluation Results') @@ -248,7 +250,7 @@ def create_model_links(src_dir: str, dst_dir: str): def generate_benchmark_report(report_path: str): # write to github action summary - _append_summary('## Evaluation Results Start') + _append_summary('## Benchmark Results Start') subfolders = [f.path for f in os.scandir(report_path) if f.is_dir()] for dir_path in subfolders: second_subfolders = [ @@ -309,7 +311,7 @@ def generate_benchmark_report(report_path: str): if 'generation' in benchmark_subfolder: add_summary(merged_csv_path) print(merged_df) - _append_summary('## Evaluation Results End') + _append_summary('## Benchmark Results End') if __name__ == '__main__': diff --git a/.github/scripts/eval_opencompass_config.py b/.github/scripts/eval_opencompass_config.py index 7b4bae62e..2057082ca 100644 --- a/.github/scripts/eval_opencompass_config.py +++ b/.github/scripts/eval_opencompass_config.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from mmengine.config import read_base from opencompass.models import (LmdeployPytorchModel, TurboMindModel, TurboMindModelwithChatTemplate) @@ -147,15 +149,6 @@ MAX_SESSION_LEN = 2048 MAX_NEW_TOKENS = 1024 -tb_engine_config_template_max_bs_8 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=8, - rope_scaling_factor=1.0) -tb_engine_config_template_max_bs_16 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=16, - rope_scaling_factor=1.0) -tb_engine_config_template_max_bs_32 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=32, - rope_scaling_factor=1.0) tb_engine_config_template_max_bs_128 = dict(session_len=MAX_SESSION_LEN, max_batch_size=128, rope_scaling_factor=1.0) @@ -164,8 +157,6 @@ tp=2, rope_scaling_factor=1.0) -pt_engine_config_template_max_bs_8 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=8) pt_engine_config_template_max_bs_16 = dict(session_len=MAX_SESSION_LEN, max_batch_size=16) pt_engine_config_template_max_bs_32 = dict(session_len=MAX_SESSION_LEN, @@ -193,29 +184,12 @@ cache_max_entry_count=0.5, max_prefill_token_num=4096, max_batch_size=64) -pt_engine_config_template_max_bs_128_prefill = dict( - session_len=MAX_SESSION_LEN, - cache_max_entry_count=0.5, - max_prefill_token_num=4096, - max_batch_size=128) pt_engine_config_template_max_bs_8_prefill_tp2 = dict( session_len=MAX_SESSION_LEN, cache_max_entry_count=0.5, max_prefill_token_num=4096, max_batch_size=8, tp=2) -pt_engine_config_template_max_bs_64_prefill_tp2 = dict( - session_len=MAX_SESSION_LEN, - cache_max_entry_count=0.5, - max_prefill_token_num=4096, - max_batch_size=64, - tp=2) -pt_engine_config_template_max_bs_128_prefill_tp2 = dict( - session_len=MAX_SESSION_LEN, - cache_max_entry_count=0.5, - max_prefill_token_num=4096, - max_batch_size=128, - tp=2) tb_awq_engine_config_template_max_bs_8 = dict(session_len=MAX_SESSION_LEN, max_batch_size=8, model_format='awq', @@ -224,29 +198,6 @@ max_batch_size=32, model_format='awq', rope_scaling_factor=1.0) -tb_awq_engine_config_template_max_bs_128 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - model_format='awq', - rope_scaling_factor=1.0) - -tb_awq_engine_config_template_max_bs_128_tp2 = dict( - session_len=MAX_SESSION_LEN, - max_batch_size=128, - model_format='awq', - tp=2, - rope_scaling_factor=1.0) - -tb_kvint4_engine_config_template_max_bs_128 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - quant_policy=4, - rope_scaling_factor=1.0) - -tb_kvint4_engine_config_template_max_bs_128_tp2 = dict( - session_len=MAX_SESSION_LEN, - max_batch_size=128, - quant_policy=4, - tp=2, - rope_scaling_factor=1.0) gen_config_template = dict(top_k=1, top_p=0.8, @@ -269,33 +220,15 @@ engine_config_template_max_bs_128 = dict(session_len=MAX_SESSION_LEN, max_batch_size=128) -engine_config_template_max_bs_128_awq = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - model_format='awq') -engine_config_template_max_bs_128_kvint4 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - quant_policy=4) -engine_config_template_max_bs_128_kvint8 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - quant_policy=8) engine_config_template_max_bs_128_tp2 = dict(session_len=MAX_SESSION_LEN, max_batch_size=128, tp=2) -engine_config_template_max_bs_128_awq_tp2 = dict(session_len=MAX_SESSION_LEN, - max_batch_size=128, - model_format='awq', - tp=2) -engine_config_template_max_bs_128_kvint4_tp2 = dict( - session_len=MAX_SESSION_LEN, max_batch_size=128, quant_policy=4, tp=2) -engine_config_template_max_bs_128_kvint8_tp2 = dict( - session_len=MAX_SESSION_LEN, max_batch_size=128, quant_policy=8, tp=2) # ===== Configs for internlm/internlm-chat-7b ===== -# config for internlm-chat-7b tb_internlm_chat_7b = dict(type=TurboMindModel, - abbr='internlm-chat-7b-turbomind', + abbr='tb_internlm_chat_7b', path='internlm/internlm-chat-7b', - engine_config=tb_engine_config_template_max_bs_32, + engine_config=tb_engine_config_template_max_bs_128, gen_config=gen_config_template, max_out_len=MAX_NEW_TOKENS, max_seq_len=MAX_SESSION_LEN, @@ -304,12 +237,12 @@ meta_template=internlm_meta_template, run_cfg=run_cfg_tp1_template, end_str='') +tb_internlm_chat_7b_4bits = deepcopy(tb_internlm_chat_7b) -# config for pt internlm-chat-7b pt_internlm_chat_7b = dict(type=LmdeployPytorchModel, - abbr='internlm-chat-7b-pytorch', + abbr='pt_internlm_chat_7b', path='internlm/internlm-chat-7b', - engine_config=pt_engine_config_template_max_bs_16, + engine_config=pt_engine_config_template_max_bs_128, gen_config=gen_config_template, max_out_len=MAX_NEW_TOKENS, max_seq_len=MAX_SESSION_LEN, @@ -319,65 +252,6 @@ run_cfg=run_cfg_tp1_template, end_str='') -tb_internlm_chat_7b_w4a16 = dict( - type=TurboMindModel, - abbr='internlm-chat-7b-4bits-turbomind', - path='internlm/internlm-chat-7b-inner-4bits', - engine_config=tb_awq_engine_config_template_max_bs_32, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=32, - concurrency=32, - meta_template=internlm_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='') - -# ===== Configs for internlm/internlm-chat-20b ===== -# config for internlm-chat-20b -tb_internlm_chat_20b = dict(type=TurboMindModel, - abbr='internlm-chat-20b-turbomind', - path='internlm/internlm-chat-20b', - engine_config=tb_engine_config_template_max_bs_8, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=8, - concurrency=8, - meta_template=internlm_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='') - -# config for internlm-chat-20b-w4 model -tb_internlm_chat_20b_w4a16 = dict( - type=TurboMindModel, - abbr='internlm-chat-20b-4bits-turbomind', - path='internlm/internlm-chat-20b-inner-4bits', - engine_config=tb_awq_engine_config_template_max_bs_8, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=8, - concurrency=8, - meta_template=internlm_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='') - -# config for internlm-chat-20b -pt_internlm_chat_20b = dict( - type=LmdeployPytorchModel, - abbr='internlm-chat-20b-pytorch', - path='internlm/internlm-chat-20b', - engine_config=pt_engine_config_template_max_bs_8_prefill, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=8, - concurrency=8, - meta_template=internlm_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='') - # ===== Configs for internlm/internlm2-chat-7b ===== tb_internlm2_chat_7b = dict( type=TurboMindModelwithChatTemplate, @@ -392,98 +266,12 @@ stop_words=['', '<|im_end|>'], ) -tb_internlm2_chat_7b_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_7b_w4a16', - path='internlm/internlm2-chat-7b-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) - -tb_internlm2_chat_7b_kvint4 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_7b_kvint4', - path='internlm/internlm2-chat-7b', - engine_config=engine_config_template_max_bs_128_kvint4, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) -tb_internlm2_chat_7b_kvint8 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_7b_kvint8', - path='internlm/internlm2-chat-7b', - engine_config=engine_config_template_max_bs_128_kvint8, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) - -tb_internlm2_5_7b_chat = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_5_7b_chat', - path='internlm/internlm2_5-7b-chat', - engine_config=engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) - -tb_internlm2_5_7b_chat_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_5_7b_chat_w4a16', - path='internlm/internlm2_5-7b-chat-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) +tb_internlm2_chat_7b_4bits = deepcopy(tb_internlm2_chat_7b) +tb_internlm2_chat_7b_kvint4 = deepcopy(tb_internlm2_chat_7b) +tb_internlm2_chat_7b_kvint8 = deepcopy(tb_internlm2_chat_7b) -tb_internlm2_5_chat_7b_kvint4 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_5_7b_chat_kvint4', - path='internlm/internlm2_5-7b-chat', - engine_config=engine_config_template_max_bs_128_kvint4, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) -tb_internlm2_5_chat_7b_kvint8 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_5_7b_chat_kvint8', - path='internlm/internlm2_5-7b-chat', - engine_config=engine_config_template_max_bs_128_kvint8, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['', '<|im_end|>'], -) - -# config for pt internlm-chat-7b pt_internlm2_chat_7b = dict(type=LmdeployPytorchModel, - abbr='internlm2-chat-7b-pytorch', + abbr='pt_internlm2_chat_7b', path='internlm/internlm2-chat-7b', engine_config=pt_engine_config_template_max_bs_64, gen_config=gen_config_template, @@ -495,51 +283,44 @@ run_cfg=run_cfg_tp1_template, end_str='<|im_end|>') -# ===== Configs for internlm/internlm2-chat-20b ===== -tb_internlm2_chat_20b = dict( +# ===== Configs for internlm/internlm2_5_7b_chat ===== +tb_internlm2_5_7b_chat = dict( type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_20b', - path='internlm/internlm2-chat-20b', - engine_config=engine_config_template_max_bs_128_tp2, + abbr='tb_internlm2_5_7b_chat', + path='internlm/internlm2_5-7b-chat', + engine_config=engine_config_template_max_bs_128, gen_config=gen_config_template, max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, batch_size=128, - run_cfg=dict(num_gpus=2), + run_cfg=dict(num_gpus=1), stop_words=['', '<|im_end|>'], ) -tb_internlm2_chat_20b_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_20b_w4a16', - path='internlm/internlm2-chat-20b-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq_tp2, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=2), - stop_words=['', '<|im_end|>'], -) +tb_internlm2_5_7b_chat_4bits = deepcopy(tb_internlm2_5_7b_chat) +tb_internlm2_5_7b_chat_kvint4 = deepcopy(tb_internlm2_5_7b_chat) +tb_internlm2_5_7b_chat_kvint8 = deepcopy(tb_internlm2_5_7b_chat) -tb_internlm2_chat_20b_kvint4 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_7b_kvint4', - path='internlm/internlm2-chat-20b', - engine_config=engine_config_template_max_bs_128_kvint4_tp2, +pt_internlm2_5_7b_chat = dict( + type=LmdeployPytorchModel, + abbr='pt_internlm2_5_7b_chat', + path='internlm/internlm2_5-7b-chat', + engine_config=pt_engine_config_template_max_bs_64, gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=2), - stop_words=['', '<|im_end|>'], -) + max_seq_len=MAX_SESSION_LEN, + batch_size=64, + concurrency=64, + meta_template=internlm2_meta_template, + run_cfg=run_cfg_tp1_template, + end_str='<|im_end|>') -tb_internlm2_chat_20b_kvint8 = dict( +# ===== Configs for internlm/internlm2_chat_20b ===== +tb_internlm2_chat_20b = dict( type=TurboMindModelwithChatTemplate, - abbr='tb_internlm2_chat_7b_kvint8', + abbr='tb_internlm2_chat_20b', path='internlm/internlm2-chat-20b', - engine_config=engine_config_template_max_bs_128_kvint8_tp2, + engine_config=engine_config_template_max_bs_128_tp2, gen_config=gen_config_template, max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, @@ -548,10 +329,13 @@ stop_words=['', '<|im_end|>'], ) -# config for pt internlm-chat-20b +tb_internlm2_chat_20b_4bits = deepcopy(tb_internlm2_chat_20b) +tb_internlm2_chat_20b_kvint4 = deepcopy(tb_internlm2_chat_20b) +tb_internlm2_chat_20b_kvint8 = deepcopy(tb_internlm2_chat_20b) + pt_internlm2_chat_20b = dict( type=LmdeployPytorchModel, - abbr='internlm2-chat-20b-pytorch', + abbr='pt_internlm2_chat_20b', path='internlm/internlm2-chat-20b', engine_config=pt_engine_config_template_max_bs_64_prefill, gen_config=gen_config_template, @@ -563,10 +347,39 @@ run_cfg=run_cfg_tp1_template, end_str='<|im_end|>') +# ===== Configs for Qwen/Qwen1.5-7B-Chat ===== +tb_qwen1_5_7b_chat = dict( + type=TurboMindModelwithChatTemplate, + abbr='tb_qwen1_5_7b_chat', + path='Qwen/Qwen1.5-7B-Chat', + engine_config=engine_config_template_max_bs_128, + gen_config=gen_config_template, + max_seq_len=MAX_SESSION_LEN, + max_out_len=MAX_NEW_TOKENS, + batch_size=128, + run_cfg=dict(num_gpus=1), +) + +tb_qwen1_5_7b_chat_4bits = deepcopy(tb_qwen1_5_7b_chat) +tb_qwen1_5_7b_chat_kvint4 = deepcopy(tb_qwen1_5_7b_chat) +tb_qwen1_5_7b_chat_kvint8 = deepcopy(tb_qwen1_5_7b_chat) + +pt_qwen1_5_7b_chat = dict(type=LmdeployPytorchModel, + abbr='pt_qwen1_5_7b_chat', + path='Qwen/Qwen1.5-7B-Chat', + engine_config=pt_engine_config_template_max_bs_128, + gen_config=gen_config_template, + max_out_len=MAX_NEW_TOKENS, + max_seq_len=MAX_SESSION_LEN, + batch_size=128, + concurrency=128, + meta_template=qwen1_5_meta_template, + run_cfg=run_cfg_tp1_template, + end_str='<|im_end|>') + # ===== Configs for Qwen/Qwen-7B-Chat ===== -# config for qwen-chat-7b turbomind tb_qwen_chat_7b = dict(type=TurboMindModel, - abbr='qwen-7b-chat-turbomind', + abbr='tb_qwen_chat_7b', path='Qwen/Qwen-7B-Chat', engine_config=tb_engine_config_template_max_bs_128, gen_config=qwen_gen_config_template, @@ -578,23 +391,13 @@ run_cfg=run_cfg_tp1_template, end_str='<|im_end|>') -tb_qwen_chat_7b_w4a16 = dict( - type=TurboMindModel, - abbr='qwen-7b-chat-4bits-turbomind', - path='Qwen/Qwen-7B-Chat-inner-4bits', - engine_config=tb_awq_engine_config_template_max_bs_128, - gen_config=qwen_gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - meta_template=qwen_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='<|im_end|>') +tb_qwen_chat_7b_4bits = deepcopy(tb_qwen_chat_7b) +tb_qwen_chat_7b_kvint4 = deepcopy(tb_qwen_chat_7b) +tb_qwen_chat_7b_kvint8 = deepcopy(tb_qwen_chat_7b) # config for qwen-chat-7b pytorch pt_qwen_chat_7b = dict(type=LmdeployPytorchModel, - abbr='qwen-7b-chat-pytorch', + abbr='pt_qwen_chat_7b', path='Qwen/Qwen-7B-Chat', engine_config=pt_engine_config_template_max_bs_64, gen_config=qwen_gen_config_template, @@ -609,7 +412,7 @@ # ===== Configs for meta-llama/Llama-2-7b-chat-hf ===== # config for llama2-chat-7b turbomind tb_llama2_chat_7b = dict(type=TurboMindModel, - abbr='llama-2-7b-chat-turbomind', + abbr='tb_llama2_chat_7b', path='meta-llama/Llama-2-7b-chat-hf', engine_config=tb_engine_config_template_max_bs_128, gen_config=gen_config_template, @@ -621,24 +424,12 @@ run_cfg=run_cfg_tp1_template, end_str='[INST]') -# config for llama2-chat-7b-w4a16 turbomind -tb_llama2_chat_7b_wa416 = dict( - type=TurboMindModel, - abbr='llama-2-7b-chat-4bits-turbomind', - path='meta-llama/Llama-2-7b-chat-hf-inner-4bits', - engine_config=tb_awq_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - meta_template=llama2_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='[INST]') +tb_llama2_chat_7b_4bits = deepcopy(tb_llama2_chat_7b) +tb_llama2_chat_7b_kvint4 = deepcopy(tb_llama2_chat_7b) +tb_llama2_chat_7b_kvint8 = deepcopy(tb_llama2_chat_7b) -# config for llama2-chat-7b pytorch pt_llama2_chat_7b = dict(type=LmdeployPytorchModel, - abbr='llama-2-7b-chat-pytorch', + abbr='pt_llama2_chat_7b', path='meta-llama/Llama-2-7b-chat-hf', engine_config=pt_engine_config_template_max_bs_128, gen_config=gen_config_template, @@ -650,95 +441,44 @@ run_cfg=run_cfg_tp1_template, end_str='[INST]') -# ===== Configs for baichuan-inc/Baichuan2-7B-Chat ===== -# config for baichuan2-chat-7b turbomind -tb_baichuan2_chat_7b = dict(type=TurboMindModel, - abbr='Baichuan2-7B-Chat-turbomind', - path='baichuan-inc/Baichuan2-7B-Chat', - engine_config=tb_engine_config_template_max_bs_16, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=16, - concurrency=16, - meta_template=baichuan2_meta_template, - run_cfg=run_cfg_tp1_template) - -# config for baichuan2-chat-7b pytorch -pt_baichuan2_chat_7b = dict(type=LmdeployPytorchModel, - abbr='baichuan2-7b-chat-hf', - path='baichuan-inc/Baichuan2-7B-Chat', - engine_config=pt_engine_config_template_max_bs_16, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=16, - concurrency=16, - meta_template=baichuan2_meta_template, - run_cfg=run_cfg_tp1_template, - end_str=None) - -# ===== Configs for mistralai/Mistral-7B-Instruct-v0.1 ===== -# config for pt Mistral-7B-Instruct-v0.1 -pt_mistral_chat_7b = dict(type=LmdeployPytorchModel, - abbr='mistral-7b-instruct-v0.1-pytorch', - path='mistralai/Mistral-7B-Instruct-v0.1', - engine_config=pt_engine_config_template_max_bs_16, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=16, - concurrency=16, - meta_template=mistral_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='') - -# ===== Configs for mistralai/Mixtral-8x7B-Instruct-v0.1 ===== -# config for pt Mixtral-8x7B-Instruct-v0.1 -pt_mixtral_chat_8x7b = dict( - type=LmdeployPytorchModel, - abbr='mixtral-8x7b-instruct-v0.1-pytorch', - path='mistralai/Mixtral-8x7B-Instruct-v0.1', - engine_config=pt_engine_config_template_max_bs_8_prefill_tp2, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=8, - concurrency=8, - meta_template=mistral_meta_template, - run_cfg=run_cfg_tp2_template, - end_str='') - -# ===== Configs for Qwen/Qwen1.5-7B-Chat ===== -tb_qwen1_5_7b_chat = dict( +# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct ===== +tb_llama_3_8b_instruct = dict( type=TurboMindModelwithChatTemplate, - abbr='tb_qwen1_5_7b_chat', - path='Qwen/Qwen1.5-7B-Chat', + abbr='tb_llama_3_8b_instruct', + path='meta-llama/Meta-Llama-3-8B-Instruct', engine_config=engine_config_template_max_bs_128, gen_config=gen_config_template, max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, batch_size=128, run_cfg=dict(num_gpus=1), + stop_words=['<|eot_id|>', '<|end_of_text|>'], ) -tb_qwen1_5_7b_chat_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_qwen1_5_7b_chat_w4a16', - path='Qwen/Qwen1.5-7B-Chat-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq, +pt_llama_3_8b_instruct = dict( + type=LmdeployPytorchModel, + abbr='pt_llama_3_8b_instruct', + path='meta-llama/Meta-Llama-3-8B-Instruct', + engine_config=pt_engine_config_template_max_bs_128, gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, + max_seq_len=MAX_SESSION_LEN, batch_size=128, - run_cfg=dict(num_gpus=1), -) + concurrency=128, + meta_template=llama3_meta_template, + run_cfg=run_cfg_tp1_template, + end_str='[INST]') + +tb_llama_3_8b_instruct_4bits = deepcopy(tb_llama_3_8b_instruct) +tb_llama_3_8b_instruct_kvint4 = deepcopy(tb_llama_3_8b_instruct) +tb_llama_3_8b_instruct_kvint8 = deepcopy(tb_llama_3_8b_instruct) -tb_qwen1_5_7b_chat_kvint4 = dict( +# ===== Configs for Qwen/Qwen2-7B-Instruct ===== +tb_qwen2_7b_instruct = dict( type=TurboMindModelwithChatTemplate, - abbr='tb_qwen1_5_7b_chat_kvint4', - path='Qwen/Qwen1.5-7B-Chat', - engine_config=engine_config_template_max_bs_128_kvint4, + abbr='tb_qwen2_7b_instruct', + path='Qwen/Qwen2-7B-Instruct', + engine_config=engine_config_template_max_bs_128, gen_config=gen_config_template, max_seq_len=MAX_SESSION_LEN, max_out_len=MAX_NEW_TOKENS, @@ -746,18 +486,22 @@ run_cfg=dict(num_gpus=1), ) -pt_qwen1_5_7b_chat = dict(type=LmdeployPytorchModel, - abbr='qwen1.5-7b-chat-pytorch', - path='Qwen/Qwen1.5-7B-Chat', - engine_config=pt_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - meta_template=qwen1_5_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='<|im_end|>') +tb_qwen2_7b_instruct_4bits = deepcopy(tb_qwen2_7b_instruct) +tb_qwen2_7b_instruct_kvint4 = deepcopy(tb_qwen2_7b_instruct) +tb_qwen2_7b_instruct_kvint8 = deepcopy(tb_qwen2_7b_instruct) + +for model in [v for k, v in locals().items() if k.endswith('_4bits')]: + model['engine_config']['model_format'] = 'awq' + model['abbr'] = model['abbr'] + '_4bits' + model['path'] = model['path'] + '-inner-4bits' + +for model in [v for k, v in locals().items() if k.endswith('_kvint4')]: + model['engine_config']['quant_policy'] = 4 + model['abbr'] = model['abbr'] + '_kvint4' + +for model in [v for k, v in locals().items() if k.endswith('_kvint8')]: + model['engine_config']['quant_policy'] = 8 + model['abbr'] = model['abbr'] + '_kvint8' pt_qwen1_5_moe_2_7b_chat = dict( type=LmdeployPytorchModel, @@ -773,46 +517,37 @@ run_cfg=run_cfg_tp1_template, end_str='<|im_end|>') -# ===== Configs for Qwen/Qwen2-7B-Instruct ===== -tb_qwen2_7b_instruct = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_qwen2_7b_instruct', - path='Qwen/Qwen2-7B-Instruct', - engine_config=engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), -) - -tb_qwen2_7b_instruct_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_qwen2_7b_instruct_w4a16', - path='Qwen/Qwen2-7B-Instruct-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), -) +# ===== Configs for baichuan-inc/Baichuan2-7B-Chat ===== +# config for baichuan2-chat-7b turbomind +tb_baichuan2_chat_7b = dict(type=TurboMindModel, + abbr='Baichuan2-7B-Chat-turbomind', + path='baichuan-inc/Baichuan2-7B-Chat', + engine_config=tb_engine_config_template_max_bs_128, + gen_config=gen_config_template, + max_out_len=MAX_NEW_TOKENS, + max_seq_len=MAX_SESSION_LEN, + batch_size=16, + concurrency=16, + meta_template=baichuan2_meta_template, + run_cfg=run_cfg_tp1_template) -tb_qwen2_7b_instruct_kvint4 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_qwen2_7b_instruct_kvint4', - path='Qwen/Qwen2-7B-Instruct', - engine_config=engine_config_template_max_bs_128_kvint4, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), -) +# config for baichuan2-chat-7b pytorch +pt_baichuan2_chat_7b = dict(type=LmdeployPytorchModel, + abbr='baichuan2-7b-chat-hf', + path='baichuan-inc/Baichuan2-7B-Chat', + engine_config=pt_engine_config_template_max_bs_16, + gen_config=gen_config_template, + max_out_len=MAX_NEW_TOKENS, + max_seq_len=MAX_SESSION_LEN, + batch_size=16, + concurrency=16, + meta_template=baichuan2_meta_template, + run_cfg=run_cfg_tp1_template, + end_str=None) # ===== Configs for google/gemma-7b-it ===== pt_gemma_chat_7b = dict(type=LmdeployPytorchModel, - abbr='gemma-7b-it-pytorch', + abbr='pt_gemma_chat_7b', path='google/gemma-7b-it', engine_config=pt_engine_config_template_max_bs_16, gen_config=gen_config_template, @@ -824,108 +559,16 @@ run_cfg=run_cfg_tp1_template, end_str='') -# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct ===== -tb_llama_3_8b_instruct = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_llama_3_8b_instruct', - path='meta-llama/Meta-Llama-3-8B-Instruct', - engine_config=engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['<|eot_id|>', '<|end_of_text|>'], -) - -tb_llama_3_8b_instruct_w4a16 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_llama_3_8b_instruct_w4a16', - path='meta-llama/Meta-Llama-3-8B-Instruct-inner-4bits', - engine_config=engine_config_template_max_bs_128_awq, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['<|eot_id|>', '<|end_of_text|>'], -) - -tb_llama_3_8b_instruct_kvint4 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_llama_3_8b_instruct_kvint4', - path='meta-llama/Meta-Llama-3-8B-Instruct', - engine_config=engine_config_template_max_bs_128_kvint4, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['<|eot_id|>', '<|end_of_text|>'], -) - -tb_llama_3_8b_instruct_kvint8 = dict( - type=TurboMindModelwithChatTemplate, - abbr='tb_llama_3_8b_instruct_kvint8', - path='meta-llama/Meta-Llama-3-8B-Instruct', - engine_config=engine_config_template_max_bs_128_kvint8, - gen_config=gen_config_template, - max_seq_len=MAX_SESSION_LEN, - max_out_len=MAX_NEW_TOKENS, - batch_size=128, - run_cfg=dict(num_gpus=1), - stop_words=['<|eot_id|>', '<|end_of_text|>'], -) - -# config for llama-3-8b-instruct pytorch -pt_llama_3_8b_instruct = dict( - type=LmdeployPytorchModel, - abbr='llama-3-8b-instruct-pytorch', - path='meta-llama/Meta-Llama-3-8B-Instruct', - engine_config=pt_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - meta_template=llama3_meta_template, - run_cfg=run_cfg_tp1_template, - end_str='[INST]') - -# config for code llama -tb_codellama_7b_chat = dict(type=TurboMindModel, - abbr='codellama-7b-chat-turbomind', - path='codellama/CodeLlama-7b-Instruct-hf', - engine_config=tb_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='') - -tb_codellama_7b_chat_w4a16 = dict( - type=TurboMindModel, - abbr='codellama-7b-chat-4bits-turbomind', - path='codellama/CodeLlama-7b-Instruct-hf-inner-4bits', - engine_config=tb_awq_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='') - -pt_codellama_7b_chat = dict(type=LmdeployPytorchModel, - abbr='codellama-7b-chat-pytorch', - path='codellama/CodeLlama-7b-Instruct-hf', - engine_config=pt_engine_config_template_max_bs_128, - gen_config=gen_config_template, - max_out_len=MAX_NEW_TOKENS, - max_seq_len=MAX_SESSION_LEN, - batch_size=128, - concurrency=128, - run_cfg=dict(num_gpus=1, num_procs=1), - end_str='') +# ===== Configs for google/gemma2-7b-it ===== +pt_gemma_2_9b_it = dict(type=LmdeployPytorchModel, + abbr='pt_gemma_2_9b_it', + path='google/gemma-2-9b-it', + engine_config=pt_engine_config_template_max_bs_16, + gen_config=gen_config_template, + max_out_len=MAX_NEW_TOKENS, + max_seq_len=MAX_SESSION_LEN, + batch_size=16, + concurrency=16, + meta_template=gemma_meta_template, + run_cfg=run_cfg_tp1_template, + end_str='') diff --git a/.github/scripts/test_triton_server.py b/.github/scripts/test_triton_server.py deleted file mode 100644 index d44fa3c3d..000000000 --- a/.github/scripts/test_triton_server.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import time -from subprocess import PIPE, Popen - -import fire - - -def parse_dialogue(inputs: str): - sep = 'double enter to end input >>>' - dialogues = inputs.strip() - if dialogues.endswith(sep): - dialogues = dialogues[:-len(sep)] - dialogues = dialogues.strip() - dialogues = dialogues.split(sep) - dialogues = [d.strip() for d in dialogues] - return dialogues[1:] - - -def test(port=33337): - cmd = [f'lmdeploy serve triton_client localhost:{port}'] - - test_cases = [ - dict( - prompts='Hello! Please answer in English.', - keywords=['Hello', 'hi'], - ), - dict( - prompts='您好! 请用中文回答。', - keywords=['您好', '你好'], - ), - dict( - prompts='How many days does a week have? ', - keywords=['seven', '7'], - ), - dict( - prompts='一周有多少天?请用中文回答。', - keywords=['七天', '7天'], - ), - ] - - sep = '\n\n' - end = sep + 'exit\n\n\n' - all_pass = True - for cases in test_cases: - quest = cases['prompts'] - keywords = [k.lower() for k in cases['keywords']] - inputs = quest + end - print(f'Test Input prompts: {quest}\nKey words: {keywords}') - time.sleep(5) - - with Popen(cmd, - stdin=PIPE, - stdout=PIPE, - stderr=PIPE, - shell=True, - text=True, - encoding='utf-8') as proc: - out, err = proc.communicate(input=inputs) - print(f'Output: {out}') - if proc.returncode == 0: - out = parse_dialogue(out)[0].lower() - success = any([k in out for k in keywords]) - if not success: - print(f'>>> Failed to output keywords: {out} {keywords}') - all_pass = False - else: - all_pass = False - print(f'Failed to get outputs: {out} {err}') - assert all_pass, 'Tests failed!' - - -if __name__ == '__main__': - fire.Fire(test) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index cef7425e1..00b146d53 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -15,9 +15,9 @@ on: default: 'main' benchmark_type: required: true - description: 'Set benchmark type. Default is "["generation", "throughtput", "api_server", "triton_server"]"' + description: 'Set benchmark type. Default is "["generation", "throughtput", "api_server"]"' type: string - default: "['generation', 'throughput', 'api_server', 'triton_server']" + default: "['generation', 'throughput', 'api_server']" backend: required: true description: 'Set backend testcase filter: turbomind or pytorch or turbomind, pytorch. Default is "["turbomind", "pytorch"]"' @@ -52,7 +52,7 @@ on: required: true description: 'Set models run benchmark' type: string - default: "['internlm/internlm2-chat-20b','internlm/internlm2-chat-20b-inner-4bits','meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-7b-chat-hf-inner-4bits','meta-llama/Meta-Llama-3-8B-Instruct','Qwen/Qwen1.5-32B-Chat']" + default: "['internlm/internlm2-chat-20b','internlm/internlm2-chat-20b-inner-4bits','internlm/internlm2_5-7b-chat','internlm/internlm2_5-7b-chat-inner-4bits','meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-7b-chat-hf-inner-4bits','meta-llama/Meta-Llama-3-8B-Instruct','meta-llama/Meta-Llama-3-8B-Instruct-inner-4bits','Qwen/Qwen2-72B-Instruct']" env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache @@ -62,16 +62,15 @@ env: DATASET_FILE: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json TP_INFO: --tp 1 LOOP_NUM: 1 - TRITON_PTXAS_PATH: /usr/local/cuda/bin/ptxas ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true - + dependency_pkgs: ${{inputs.dependency_pkgs || 'packaging transformers_stream_generator transformers datasets matplotlib jmespath'}} jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: - pyver: [py38] + pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} @@ -115,12 +114,13 @@ jobs: MODEL_PATH: /nvme/qa_test_models/${{matrix.model}} CUDA_VISIBLE_DEVICES: 6,7 container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models + - /mnt/shared:/mnt/shared - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository @@ -136,26 +136,29 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir - python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install -U 'xformers<=0.0.26' --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Check env run: | @@ -172,7 +175,7 @@ jobs: run: | rm -rf ${result_dir} mkdir ${result_dir} - python3 benchmark/profile_generation.py $MODEL_PATH $MAX_ENTRY_COUNT $MODEL_FORMAT $TP_INFO -c 8,256 -ct 128 128 2048 128 -pt 1 128 128 2048 --csv ${result_dir}/generation.csv > ${result_dir}/generation.log + python3 benchmark/profile_generation.py $MODEL_PATH $MAX_ENTRY_COUNT $MODEL_FORMAT $TP_INFO -c 8 256 -ct 128 128 2048 128 -pt 1 128 128 2048 --csv ${result_dir}/generation.csv > ${result_dir}/generation.log - name: Run generation benchmark - longtext turbomind if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(env.LONGTEXT_BENCHMARK, 'true') env: @@ -188,7 +191,7 @@ jobs: run: | rm -rf ${result_dir} mkdir ${result_dir} - python3 benchmark/profile_generation.py $MODEL_PATH $TP_INFO --backend pytorch -c 8,256 -ct 128 128 2048 128 -pt 1 128 128 2048 --csv ${result_dir}/generation.csv > ${result_dir}/generation.log + python3 benchmark/profile_generation.py $MODEL_PATH $TP_INFO --backend pytorch -c 8 256 -ct 128 128 2048 128 -pt 1 128 128 2048 --csv ${result_dir}/generation.csv > ${result_dir}/generation.log - name: Run generation benchmark - longtext pytorch if: (!contains(env.MODEL_FORMAT, 'awq') && contains(fromJSON(github.event.inputs.backend), 'pytorch') && contains(env.LONGTEXT_BENCHMARK, 'true')) env: @@ -227,13 +230,14 @@ jobs: MODEL_PATH: /nvme/qa_test_models/${{matrix.model}} CUDA_VISIBLE_DEVICES: 4,5 container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/qa_test_models/datasets/:/nvme/qa_test_models/datasets/ + - /mnt/shared:/mnt/shared - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository @@ -249,26 +253,29 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir - python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install -U 'xformers<=0.0.26' --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Check env run: | @@ -368,13 +375,14 @@ jobs: MODEL_PATH: /nvme/qa_test_models/${{matrix.model}} CUDA_VISIBLE_DEVICES: 6,7 container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models - /nvme/qa_test_models/datasets/:/nvme/qa_test_models/datasets/ + - /mnt/shared:/mnt/shared - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Clone repository @@ -390,26 +398,29 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir - python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install -U 'xformers<=0.0.26' --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Check env run: | @@ -445,16 +456,16 @@ jobs: if: contains(fromJSON(github.event.inputs.backend), 'turbomind') run: | kill -15 "$restful_pid" - - name: Start restful api turbomind - kvint4 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') + - name: Start restful api turbomind - kvint8 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') run: | - lmdeploy serve api_server $MODEL_PATH $MAX_ENTRY_COUNT $MAX_BATCH_SIZE $MODEL_FORMAT $TP_INFO --quant-policy 4 --log-level ${{inputs.log_level}} > turbomind_kvint4_run.log 2>&1 & + lmdeploy serve api_server $MODEL_PATH $MAX_ENTRY_COUNT $MAX_BATCH_SIZE $MODEL_FORMAT $TP_INFO --quant-policy 8 --log-level ${{inputs.log_level}} > turbomind_kvint8_run.log 2>&1 & echo "restful_pid=$!" >> "$GITHUB_ENV" sleep 180s - - name: Run restful benchmark -kvint4 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') + - name: Run restful benchmark -kvint8 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') env: - result_dir: benchmark-restful-turbomind-kvint4 + result_dir: benchmark-restful-turbomind-kvint8 run: | rm -rf ${result_dir} mkdir ${result_dir} @@ -466,20 +477,20 @@ jobs: python3 benchmark/profile_restful_api.py localhost:23333 $MODEL_PATH $DATASET_FILE --concurrency "$batch" --stream-output True --csv ${result_dir}/restful_csv_batch_"${batch}"_"${i}"th.csv &> ${result_dir}/restful_csv_batch_"${batch}"_"${i}"th.log done done - - name: Kill restful api turbomind - kvint4 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') + - name: Kill restful api turbomind - kvint8 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') run: | kill -15 "$restful_pid" - - name: Start restful api turbomind - kvint8 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') + - name: Start restful api turbomind - kvint4 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') run: | - lmdeploy serve api_server $MODEL_PATH $MAX_ENTRY_COUNT $MAX_BATCH_SIZE $MODEL_FORMAT $TP_INFO --quant-policy 8 --log-level ${{inputs.log_level}} > turbomind_kvint8_run.log 2>&1 & + lmdeploy serve api_server $MODEL_PATH $MAX_ENTRY_COUNT $MAX_BATCH_SIZE $MODEL_FORMAT $TP_INFO --quant-policy 4 --log-level ${{inputs.log_level}} > turbomind_kvint4_run.log 2>&1 & echo "restful_pid=$!" >> "$GITHUB_ENV" sleep 180s - - name: Run restful benchmark -kvint8 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') + - name: Run restful benchmark -kvint4 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') env: - result_dir: benchmark-restful-turbomind-kvint8 + result_dir: benchmark-restful-turbomind-kvint4 run: | rm -rf ${result_dir} mkdir ${result_dir} @@ -491,8 +502,8 @@ jobs: python3 benchmark/profile_restful_api.py localhost:23333 $MODEL_PATH $DATASET_FILE --concurrency "$batch" --stream-output True --csv ${result_dir}/restful_csv_batch_"${batch}"_"${i}"th.csv &> ${result_dir}/restful_csv_batch_"${batch}"_"${i}"th.log done done - - name: Kill restful api turbomind - kvint8 - if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint8') + - name: Kill restful api turbomind - kvint4 + if: contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.kvint_quantization), 'kvint4') run: | kill -15 "$restful_pid" - name: Start restful api pytorch @@ -537,186 +548,9 @@ jobs: mkdir $workdir chmod -R 777 $workdir - - - triton_benchmark: - if: ${{github.event_name == 'schedule' || ((!cancelled() && contains(fromJSON(github.event.inputs.benchmark_type), 'triton_server')) && contains(fromJSON(github.event.inputs.backend), 'turbomind'))}} - runs-on: [self-hosted, linux-a100-2] - timeout-minutes: 120 - env: - WORKDIR: /nvme/qa_test_models/triton_workspace - OFFLINE_PKGS: /nvme/qa_test_models/offline_pkg - MODEL_PATH: /nvme/qa_test_models/autotest_model/workspace_${{matrix.model}} - DEVICE: device=4 - GRPC_PORT: 33337 - strategy: - fail-fast: false - matrix: - model: ${{fromJSON(github.event.inputs.models)}} - steps: - - name: Clone repository - uses: actions/checkout@v3 - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - with: - repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} - ref: ${{github.event.inputs.repo_ref || 'main'}} - - name: Set params - run: | - chmod +x .github/scripts/set_benchmark_param.sh - .github/scripts/set_benchmark_param.sh ${{matrix.model}} - - name: Set params - cuda allocate - if: contains( env.TP_INFO, '--tp 2') - run: | - echo 'DEVICE="device=4,5"' >> "$GITHUB_ENV" - - name: Create test container - run: | - export date_today="$(date +'%H%M%S')" - export CONTAINER_ID=$(docker create \ - --rm \ - --gpus=$DEVICE \ - --shm-size 16g \ - --cap-add=SYS_PTRACE \ - --cap-add=SYS_ADMIN \ - --security-opt seccomp=unconfined \ - --name "lmdeploy-ci-triton-$GITHUB_RUN_ID-$date_today" \ - --workdir /__w/lmdeploy/lmdeploy \ - --env NCCL_LAUNCH_MODE=GROUP \ - --pull never \ - -v $(pwd)/../../:/__w \ - -v ${MODEL_PATH}:${MODEL_PATH} \ - -v ${WORKDIR}:/root/workspace/workdir \ - -v ${OFFLINE_PKGS}:/root/workspace/offline_pkg \ - -v ${HOST_PIP_CACHE_DIR}:/root/.cache/pip \ - -v ${HOST_LOCALTIME}:/etc/localtime:ro \ - -v /nvme/qa_test_models/datasets/:/nvme/qa_test_models/datasets/ \ - openmmlab/lmdeploy:latest tail -f /dev/null \ - ) - docker start $CONTAINER_ID - echo "CONTAINER_ID=$CONTAINER_ID" - echo "CONTAINER_ID=$CONTAINER_ID" >> $GITHUB_ENV - - name: Build lmdeploy from source - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - run: | - docker exec $CONTAINER_ID mkdir build - docker exec --workdir /__w/lmdeploy/lmdeploy/build \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - --env HTTP_PROXY=${{secrets.PROXY}} \ - --env HTTPS_PROXY=${{secrets.PROXY}} \ - --env no_proxy="localhost,127.0.0.1" \ - --env NO_PROXY="localhost,127.0.0.1" \ - $CONTAINER_ID cmake .. \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ - -DCMAKE_INSTALL_PREFIX=/opt/tritonserver \ - -DBUILD_PY_FFI=ON \ - -DBUILD_MULTI_GPU=ON \ - -DCMAKE_CUDA_FLAGS="-lineinfo" \ - -DUSE_NVTX=ON \ - -DSM=80 \ - -DCMAKE_CUDA_ARCHITECTURES=80 \ - -DBUILD_TEST=OFF - docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make -j$(nproc) - docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make install \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - --env HTTP_PROXY=${{secrets.PROXY}} \ - --env HTTPS_PROXY=${{secrets.PROXY}} - - name: Install lmdeploy - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - run: | - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install tritonclient[grpc] protobuf - - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install -r requirements/test.txt - - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install . - - docker exec $CONTAINER_ID lmdeploy check_env - - name: Copy repository - offline - if: ${{inputs.offline_mode}} - run: | - docker exec --workdir /__w/lmdeploy $CONTAINER_ID \ - cp -r /root/workspace/offline_pkg/lmdeploy . - - name: Install lmdeploy - offline - if: ${{inputs.offline_mode}} - run: | - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install tritonclient[grpc] protobuf - - docker exec --workdir /__w/lmdeploy/lmdeploy \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install -r requirements/test.txt - - docker exec --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} $CONTAINER_ID \ - python3 -m pip install /root/workspace/offline_pkg/py38/lmdeploy-latest-cp38-cp38-manylinux2014_x86_64.whl - - docker exec $CONTAINER_ID lmdeploy check_env - - name: Start triton server service - run: | - docker exec --detach $CONTAINER_ID bash -c \ - "tritonserver \ - --model-repository=${MODEL_PATH}/model_repository \ - --allow-http=0 \ - --allow-grpc=1 \ - --grpc-port=${GRPC_PORT} \ - --log-verbose=0 \ - --allow-metrics=1 > run.log 2>&1 ; touch finish.txt" - # wait for triton server to fully start up - sleep 360s - # print triton server log file - cat run.log - python3 -c 'import os; assert not os.path.exists("finish.txt"), "Failed to start tritonserver"' - - name: Run triton benchmark - env: - result_dir: benchmark-triton-turbomind - run: | - rm -rf ${result_dir} - mkdir ${result_dir} -p - batches=($BATCHES) - for batch in "${batches[@]}" - do - for ((i=1; i<=$LOOP_NUM; i++)) - do - docker exec \ - --env no_proxy="localhost,127.0.0.1" \ - --env NO_PROXY="localhost,127.0.0.1" \ - $CONTAINER_ID python3 benchmark/profile_serving.py localhost:${GRPC_PORT} $MODEL_PATH/triton_models/tokenizer $DATASET_FILE --concurrency "$batch" --csv ${result_dir}/triton_csv_batch_"${batch}"_"${i}"th.csv &> ${result_dir}/triton_csv_batch_"${batch}"_"${i}"th.log - done - done - - name: Save reports - if: always() - run: | - mkdir $REPORT_DIR -p && mkdir $REPORT_DIR/${{matrix.model}} -p && mkdir $REPORT_DIR/${{matrix.model}}/triton -p - cp -r benchmark-triton-* $REPORT_DIR/${{matrix.model}}/triton && rm -rf benchmark-triton-* - mv run.log $REPORT_DIR/${{matrix.model}}/triton - echo 'save report to $REPORT_DIR/${{matrix.model}}/triton' - - name: Clear workfile - if: always() - run: | - docker exec --workdir /__w/lmdeploy $CONTAINER_ID chmod -R 777 lmdeploy - docker stop $CONTAINER_ID - export workdir=$(pwd) - cd .. - rm -rf $workdir - mkdir $workdir - - get_result_overview: if: always() && !cancelled() - needs: [generation_benchmark, throughput_benchmark, restful_benchmark, triton_benchmark] + needs: [generation_benchmark, throughput_benchmark, restful_benchmark] timeout-minutes: 5 runs-on: [self-hosted, linux-a100] steps: diff --git a/.github/workflows/cuda11.8-whl-release.yml b/.github/workflows/cuda11.8-whl-release.yml index 7d8d3ab47..a762e32b7 100644 --- a/.github/workflows/cuda11.8-whl-release.yml +++ b/.github/workflows/cuda11.8-whl-release.yml @@ -67,11 +67,12 @@ jobs: - name: Install python packages run: | pip install pybind11 wheel - - uses: Jimver/cuda-toolkit@v0.2.11 + - name: Setup CUDA Toolkit id: cuda-toolkit - with: - cuda: '11.8.0' - use-github-cache: false + shell: pwsh + run: ./builder/windows/setup_cuda.ps1 + env: + INPUT_CUDA_VERSION: '11.8.0' - name: Build wheel run: | mkdir build diff --git a/.github/workflows/daily_ete_test.yml b/.github/workflows/daily_ete_test.yml index 019f128df..09768f86b 100644 --- a/.github/workflows/daily_ete_test.yml +++ b/.github/workflows/daily_ete_test.yml @@ -32,7 +32,7 @@ on: required: true description: 'Dependency packages, you can also set a specific version' type: string - default: 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath' + default: 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath decord' tools_regression: required: true description: 'Whether start a tool regression' @@ -43,24 +43,19 @@ on: description: 'Whether start a restful api regression' type: boolean default: true - triton_regression: - required: true - description: 'Whether start a triton server api regression' - type: boolean - default: true pipeline_regression: required: true description: 'Whether start an interface pipeline regression' type: boolean default: true schedule: - - cron: '00 20 * * 0-4' + - cron: '00 19 * * 0-4' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache + dependency_pkgs: ${{inputs.dependency_pkgs || 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath decord'}} HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai OUTPUT_FOLDER: cuda11.8_dist_${{ github.run_id }} - TRITON_PTXAS_PATH: /usr/local/cuda/bin/ptxas ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true jobs: @@ -68,7 +63,7 @@ jobs: if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}} strategy: matrix: - pyver: [py38, py310] + pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} @@ -103,14 +98,14 @@ jobs: needs: linux-build if: ${{!cancelled() && (github.event_name == 'schedule' || inputs.tools_regression)}} runs-on: [self-hosted, linux-a100] - timeout-minutes: 300 + timeout-minutes: 360 env: REPORT_DIR: /nvme/qa_test_models/test-reports PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA MODELSCOPE_CACHE: /root/modelscope_hub MODELSCOPE_MODULES_CACHE: /root/modelscope_modules container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip @@ -119,6 +114,7 @@ jobs: - /nvme/github-actions/modelscope_modules:/root/modelscope_modules - /nvme/github-actions/resources/lora:/root/lora - /nvme/qa_test_models:/nvme/qa_test_models + - /mnt/shared:/mnt/shared - /nvme/qa_test_models/lmdeploy/autotest:/local_case - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: @@ -135,28 +131,30 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs || 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath'}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl python3 -m pip install -U 'xformers<=0.0.26' --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt pip install /nvme/qa_test_models/offline_pkg/DeepSeek-VL --no-deps - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt pip install /nvme/qa_test_models/offline_pkg/DeepSeek-VL --no-deps - name: Check env @@ -263,161 +261,6 @@ jobs: mkdir $workdir chmod -R 777 $workdir - test_triton: - if: ${{!cancelled() && (github.event_name == 'schedule' || inputs.triton_regression)}} - runs-on: [self-hosted, linux-a100-2] - needs: test_tools - timeout-minutes: 30 - env: - HF_MODEL: /nvme/qa_test_models/internlm-chat-20b - WORKDIR: /nvme/qa_test_models/triton_workspace - OFFLINE_PKGS: /nvme/qa_test_models/offline_pkg - TB_MODEL: internlm-chat-20b-fp16-tp2 - GRPC_PORT: 33337 - steps: - - name: Clone repository - uses: actions/checkout@v2 - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - with: - repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} - ref: ${{github.event.inputs.repo_ref || 'main'}} - - name: Create test container - run: | - export CONTAINER_ID=$(docker create \ - --rm \ - --gpus='"device=4,5"' \ - --shm-size 16g \ - --cap-add=SYS_PTRACE \ - --cap-add=SYS_ADMIN \ - --security-opt seccomp=unconfined \ - --name "lmdeploy-ci-triton-$GITHUB_RUN_ID" \ - --workdir /__w/lmdeploy/lmdeploy \ - --env NCCL_LAUNCH_MODE=GROUP \ - --pull never \ - -v $(pwd)/../../:/__w \ - -v ${HF_MODEL}:/root/workspace/hf_model \ - -v ${WORKDIR}:/root/workspace/workdir \ - -v ${OFFLINE_PKGS}:/root/workspace/offline_pkg \ - -v ${HOST_PIP_CACHE_DIR}:/root/.cache/pip \ - -v ${HOST_LOCALTIME}:/etc/localtime:ro \ - openmmlab/lmdeploy:latest tail -f /dev/null \ - ) - docker start $CONTAINER_ID - echo "CONTAINER_ID=$CONTAINER_ID" - echo "CONTAINER_ID=$CONTAINER_ID" >> $GITHUB_ENV - - name: Build lmdeploy from source - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - run: | - docker exec $CONTAINER_ID mkdir build - docker exec --workdir /__w/lmdeploy/lmdeploy/build \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - --env HTTP_PROXY=${{secrets.PROXY}} \ - --env HTTPS_PROXY=${{secrets.PROXY}} \ - --env no_proxy="localhost,127.0.0.1" \ - --env NO_PROXY="localhost,127.0.0.1" \ - $CONTAINER_ID cmake .. \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ - -DCMAKE_INSTALL_PREFIX=/opt/tritonserver \ - -DBUILD_PY_FFI=ON \ - -DBUILD_MULTI_GPU=ON \ - -DCMAKE_CUDA_FLAGS="-lineinfo" \ - -DUSE_NVTX=ON \ - -DSM=80 \ - -DCMAKE_CUDA_ARCHITECTURES=80 \ - -DBUILD_TEST=OFF - docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make -j$(nproc) - docker exec --workdir /__w/lmdeploy/lmdeploy/build $CONTAINER_ID make install \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - --env HTTP_PROXY=${{secrets.PROXY}} \ - --env HTTPS_PROXY=${{secrets.PROXY}} - - name: Install lmdeploy - if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} - run: | - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install tritonclient[grpc] protobuf - - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install -r requirements/test.txt - - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install . - - docker exec $CONTAINER_ID lmdeploy check_env - - name: Copy repository - offline - if: ${{inputs.offline_mode}} - run: | - docker exec --workdir /__w/lmdeploy $CONTAINER_ID \ - cp -r /root/workspace/offline_pkg/lmdeploy . - - name: Install lmdeploy - offline - if: ${{inputs.offline_mode}} - run: | - docker exec \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install tritonclient[grpc] protobuf - - docker exec --workdir /__w/lmdeploy/lmdeploy \ - --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} \ - $CONTAINER_ID python3 -m pip install -r requirements/test.txt - - docker exec --env http_proxy=${{secrets.PROXY}} \ - --env https_proxy=${{secrets.PROXY}} $CONTAINER_ID \ - python3 -m pip install /root/workspace/offline_pkg/py38/lmdeploy-latest-cp38-cp38-manylinux2014_x86_64.whl - - docker exec $CONTAINER_ID lmdeploy check_env - - name: Convert to turbomind model - run: | - docker exec $CONTAINER_ID \ - lmdeploy convert \ - internlm-chat-20b \ - /root/workspace/hf_model \ - --tp 2 \ - --trust-remote-code \ - --dst-path /root/workspace/workdir/${TB_MODEL} - - name: Start triton server service - run: | - docker exec --detach $CONTAINER_ID bash -c \ - "tritonserver \ - --model-repository=/root/workspace/workdir/${TB_MODEL}/model_repository \ - --allow-http=0 \ - --allow-grpc=1 \ - --grpc-port=${GRPC_PORT} \ - --log-verbose=0 \ - --allow-metrics=1 > run.log 2>&1 ; touch finish.txt" - # wait for triton server to fully start up - sleep 180s - # print triton server log file - cat run.log - python3 -c 'import os; assert not os.path.exists("finish.txt"), "Failed to start tritonserver"' - - name: Test triton server - run: | - docker exec \ - --env no_proxy="localhost,127.0.0.1" \ - --env NO_PROXY="localhost,127.0.0.1" \ - $CONTAINER_ID python3 .github/scripts/test_triton_server.py --port ${GRPC_PORT} - # print triton server log file - cat run.log - - name: Clear workfile - if: always() - run: | - docker exec --workdir /__w/lmdeploy $CONTAINER_ID rm -rf /root/workspace/workdir/${TB_MODEL} - docker exec --workdir /__w/lmdeploy $CONTAINER_ID chmod -R 777 lmdeploy - docker stop $CONTAINER_ID - export workdir=$(pwd) - cd .. - rm -rf $workdir - mkdir $workdir - test_restful: if: ${{!cancelled() && (github.event_name == 'schedule' || inputs.restful_regression)}} runs-on: [self-hosted, linux-a100] @@ -430,7 +273,7 @@ jobs: env: REPORT_DIR: /nvme/qa_test_models/test-reports container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip @@ -451,26 +294,28 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs || 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath'}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Check env run: | @@ -492,7 +337,7 @@ jobs: - name: Test lmdeploy - restful api timeout-minutes: 75 run: | - pytest autotest/interface/restful/test_restful_chat_func.py -n 20 --alluredir=allure-results + pytest autotest/interface/restful/test_restful_chat_func.py -n 20 -m 'not not_${{matrix.backend}}' --alluredir=allure-results - name: Kill api server if: always() run: | @@ -542,7 +387,7 @@ jobs: env: REPORT_DIR: /nvme/qa_test_models/test-reports container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip @@ -563,26 +408,28 @@ jobs: if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs || 'packaging transformers_stream_generator transformers==4.41.2 datasets matplotlib openai attrdict timm modelscope jmespath'}} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy if: ${{github.event_name == 'schedule' || !inputs.offline_mode}} run: | python3 -m pip install lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Install lmdeploy - offline if: ${{inputs.offline_mode}} run: | - python3 -m pip install /nvme/qa_test_models/offline_pkg/py38/lmdeploy-*.whl + python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl + python3 -m pip install triton==2.1.0 python3 -m pip install -r requirements/test.txt - name: Check env run: | @@ -591,7 +438,7 @@ jobs: rm -rf allure-results - name: Test lmdeploy - interface pipeline case run: | - pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --alluredir=allure-results || true + pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --rerun 0 --alluredir=allure-results || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results || true pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=allure-results @@ -615,7 +462,7 @@ jobs: notify_to_feishu: if: always() && !cancelled() && (github.ref_name == 'develop' || github.ref_name == 'main') - needs: [test_tools, test_triton, test_restful] + needs: [test_tools, test_restful, test_pipeline] timeout-minutes: 5 runs-on: [self-hosted, linux-a100] steps: @@ -624,6 +471,6 @@ jobs: run: | curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"Lmdeploy- Daily test failed!!!","content":[[{"tag":"text","text":"branch: ${{github.ref_name}}, run action: ${{github.workflow}} failed. "},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.FEISHU_USER_ID }}'"}]]}}}}' ${{ secrets.FEISHU_WEBHOOK_URL }} - name: success notify - if: needs.test_tools.result=='success' && needs.test_triton.result=='success' && needs.test_restful.result=='success' + if: needs.test_tools.result=='success' && needs.test_restful.result=='success' && needs.test_pipeline.result=='success' run: | curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"Lmdeploy- Daily test success","content":[[{"tag":"text","text":"branch: ${{github.ref_name}}, run action: ${{github.workflow}} failed. "},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"}]]}}}}' ${{ secrets.FEISHU_WEBHOOK_URL }} diff --git a/.github/workflows/evaluate.yml b/.github/workflows/evaluate.yml index 0421d25b6..711c24fd2 100644 --- a/.github/workflows/evaluate.yml +++ b/.github/workflows/evaluate.yml @@ -17,7 +17,7 @@ on: required: true description: 'Tested TurboMind models list. eg. [internlm_chat_7b,internlm_chat_7b_w8a16]' type: string - default: '[tb_internlm2_chat_7b,tb_internlm2_chat_7b_w4a16,tb_internlm2_chat_7b_kvint4,pt_internlm2_chat_7b,tb_internlm2_chat_20b,tb_internlm2_chat_20b_w4a16,tb_internlm2_chat_20b_kvint4,pt_internlm2_chat_20b,tb_llama_3_8b_instruct,tb_llama_3_8b_instruct_w4a16,tb_llama_3_8b_instruct_kvint4,tb_qwen1_5_7b_chat,tb_qwen1_5_7b_chat_kvint4,tb_qwen1_5_7b_chat_w4a16,pt_qwen1_5_7b_chat,pt_qwen1_5_moe_2_7b_chat,tb_qwen2_7b_instruct,tb_qwen2_7b_instruct_w4a16,tb_qwen2_7b_instruct_kvint4]' + default: '[tb_internlm2_chat_7b, tb_internlm2_chat_7b_4bits, tb_internlm2_chat_7b_kvint4, tb_internlm2_chat_7b_kvint8, pt_internlm2_chat_7b, tb_internlm2_5_7b_chat, tb_internlm2_5_7b_chat_4bits, tb_internlm2_5_7b_chat_kvint4, tb_internlm2_5_7b_chat_kvint8, pt_internlm2_5_7b_chat, tb_internlm2_chat_20b, tb_internlm2_chat_20b_4bits, tb_internlm2_chat_20b_kvint4, tb_internlm2_chat_20b_kvint8, pt_internlm2_chat_20b, tb_qwen1_5_7b_chat, tb_qwen1_5_7b_chat_4bits, tb_qwen1_5_7b_chat_kvint4, tb_qwen1_5_7b_chat_kvint8, pt_qwen1_5_7b_chat, tb_llama_3_8b_instruct, pt_llama_3_8b_instruct, tb_llama_3_8b_instruct_4bits, tb_llama_3_8b_instruct_kvint4, tb_llama_3_8b_instruct_kvint8, tb_qwen2_7b_instruct, tb_qwen2_7b_instruct_4bits, tb_qwen2_7b_instruct_kvint4, tb_qwen2_7b_instruct_kvint8, pt_qwen1_5_moe_2_7b_chat, pt_gemma_2_9b_it]' datasets: required: true description: 'Tested datasets list. eg. [*bbh_datasets,*ceval_datasets,*cmmlu_datasets,*GaokaoBench_datasets,*gpqa_datasets,*gsm8k_datasets,*hellaswag_datasets,*humaneval_datasets,*ifeval_datasets,*math_datasets,*sanitized_mbpp_datasets,*mmlu_datasets,*nq_datasets,*race_datasets,*TheoremQA_datasets,*triviaqa_datasets,*winogrande_datasets,*crowspairs_datasets]' @@ -41,13 +41,14 @@ on: env: ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true + dependency_pkgs: ${{inputs.dependency_pkgs || 'pynvml packaging protobuf transformers_stream_generator transformers'}} jobs: linux-build: if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}} strategy: matrix: - pyver: [py38] + pyver: [py310] runs-on: ubuntu-latest env: PYTHON_VERSION: ${{ matrix.pyver }} @@ -84,7 +85,7 @@ jobs: runs-on: [self-hosted, linux-a100] timeout-minutes: 4320 # 72hours container: - image: nvcr.io/nvidia/tritonserver:22.12-py3 + image: nvcr.io/nvidia/tritonserver:24.03-py3 options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never" volumes: - /nvme/github-actions/pip-cache:/root/.cache/pip @@ -93,6 +94,7 @@ jobs: - /nvme/github-actions/opencompass-data:/root/opencompass-data - /nvme/qa_test_models/evaluation-reports:/root/evaluation-reports - /nvme/qa_test_models:/root/models + - /mnt/shared:/mnt/shared - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Setup systems @@ -107,17 +109,18 @@ jobs: - name: Download Artifacts uses: actions/download-artifact@v4 with: - name: my-artifact-${{ github.run_id }}-py38 + name: my-artifact-${{ github.run_id }}-py310 - name: Install pytorch run: | python3 -m pip cache dir - python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118 - name: Install lmdeploy - dependency run: | - python3 -m pip install ${{inputs.dependency_pkgs }} # manually install flash attn # the install packeage from. https://github.com/Dao-AILab/flash-attention/releases - python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.5.7+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + python3 -m pip install -U 'xformers<=0.0.26' --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install ${{env.dependency_pkgs}} - name: Install lmdeploy run: | python3 -m pip install lmdeploy-*.whl @@ -126,7 +129,8 @@ jobs: run: | git clone --depth=1 https://github.com/open-compass/opencompass.git cd opencompass - python3 -m pip install opencompass==0.2.5 + python3 -m pip install -e . + python3 -m pip install triton==2.1.0 echo "OPENCOMPASS_DIR=$(pwd)" >> $GITHUB_ENV - name: Check env run: | diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 03db66a46..8b610d70b 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -66,11 +66,12 @@ jobs: - name: Install python packages run: | pip install pybind11 wheel - - uses: Jimver/cuda-toolkit@v0.2.11 + - name: Setup CUDA Toolkit id: cuda-toolkit - with: - cuda: '12.1.0' - use-github-cache: false + shell: pwsh + run: ./builder/windows/setup_cuda.ps1 + env: + INPUT_CUDA_VERSION: '12.1.0' - name: Build wheel run: | mkdir build diff --git a/.github/workflows/windows-x64-gpu.yml b/.github/workflows/windows-x64-gpu.yml index 084719461..cf43cc1e2 100644 --- a/.github/workflows/windows-x64-gpu.yml +++ b/.github/workflows/windows-x64-gpu.yml @@ -41,11 +41,12 @@ jobs: - name: Install python packages run: | pip install pybind11 wheel - - uses: Jimver/cuda-toolkit@v0.2.11 + - name: Setup CUDA Toolkit id: cuda-toolkit - with: - cuda: ${{ matrix.cudaver }} - use-github-cache: false + shell: pwsh + run: ./builder/windows/setup_cuda.ps1 + env: + INPUT_CUDA_VERSION: ${{ matrix.cudaver }} - name: Build wheel run: | $env:BUILD_TEST="ON" diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a2026c6e..173a689f4 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,20 +43,21 @@ option(BUILD_PY_FFI "Build python ffi" ON) option(BUILD_TEST "Build tests" OFF) include(FetchContent) +if (BUILD_TEST) + FetchContent_Declare( + repo-cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60 + GIT_SHALLOW ON + ) -FetchContent_Declare( - repo-cutlass - GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git - GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60 - GIT_SHALLOW ON -) - -set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") -FetchContent_MakeAvailable(repo-cutlass) + FetchContent_MakeAvailable(repo-cutlass) -set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include) -set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include) + set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include) + set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include) +endif() option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF) diff --git a/MANIFEST.in b/MANIFEST.in index ab9b0b57b..32e432f49 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,5 +4,3 @@ include lmdeploy/lib/*.so* include lmdeploy/lib/*.dll include lmdeploy/lib/*.pyd include lmdeploy/bin/* -include lmdeploy/serve/turbomind/service_docker_up.sh -recursive-include lmdeploy/serve/turbomind/triton_models * diff --git a/README.md b/README.md index dfc5718ec..abcd46356 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ ______________________________________________________________________
2024 +- \[2024/07\] 🎉🎉 Support Llama3.1 8B, 70B and its TOOLS CALLING +- \[2024/07\] Support [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e) full-series models, [InternLM-XComposer2.5](docs/en/multi_modal/xcomposer2d5.md) and [function call](docs/en/serving/api_server_tools.md) of InternLM2.5 - \[2024/06\] PyTorch engine support DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next - \[2024/05\] Balance vision model when deploying VLMs with multiple GPUs - \[2024/05\] Support 4-bits weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2 @@ -109,18 +111,20 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Llama (7B - 65B)
  • Llama2 (7B - 70B)
  • Llama3 (8B, 70B)
  • +
  • Llama3.1 (8B, 70B)
  • InternLM (7B - 20B)
  • InternLM2 (7B - 20B)
  • InternLM2.5 (7B)
  • -
  • QWen (1.8B - 72B)
  • -
  • QWen1.5 (0.5B - 110B)
  • -
  • QWen1.5 - MoE (0.5B - 72B)
  • -
  • QWen2 (0.5B - 72B)
  • +
  • Qwen (1.8B - 72B)
  • +
  • Qwen1.5 (0.5B - 110B)
  • +
  • Qwen1.5 - MoE (0.5B - 72B)
  • +
  • Qwen2 (0.5B - 72B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
  • ChatGLM2 (6B)
  • GLM4 (9B)
  • +
  • CodeGeeX4 (9B)
  • Falcon (7B - 180B)
  • YI (6B-34B)
  • Mistral (7B)
  • @@ -137,9 +141,11 @@ For detailed inference benchmarks in more devices and more settings, please refe
    • LLaVA(1.5,1.6) (7B-34B)
    • InternLM-XComposer2 (7B, 4khd-7B)
    • -
    • QWen-VL (7B)
    • +
    • InternLM-XComposer2.5 (7B)
    • +
    • Qwen-VL (7B)
    • DeepSeek-VL (7B)
    • InternVL-Chat (v1.1-v1.5)
    • +
    • InternVL2 (1B-76B)
    • MiniGeminiLlama (7B)
    • CogVLM-Chat (17B)
    • CogVLM2-Chat (19B)
    • @@ -169,7 +175,7 @@ pip install lmdeploy Since v0.3.0, The default prebuilt package is compiled on **CUDA 12**. However, if CUDA 11+ is required, you can install lmdeploy by: ```shell -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` @@ -192,7 +198,7 @@ For more information about inference pipeline, please refer to [here](./docs/en/ # Tutorials -Please overview [getting_started](./docs/en/get_started.md) section for the basic usage of LMDeploy. +Please review [getting_started](./docs/en/get_started.md) section for the basic usage of LMDeploy. For detailed user guides and advanced guides, please refer to our [tutorials](https://lmdeploy.readthedocs.io/en/latest/): diff --git a/README_zh-CN.md b/README_zh-CN.md index ef1f806a0..28f965fb1 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -26,6 +26,8 @@ ______________________________________________________________________
      2024 +- \[2024/07\] 🎉🎉 支持 Llama3.1 8B 和 70B 模型,以及工具调用功能 +- \[2024/07\] 支持 [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e) 全系列模型,[InternLM-XComposer2.5](docs/zh_cn/multi_modal/xcomposer2d5.md) 模型和 InternLM2.5 的 [function call 功能](docs/zh_cn/serving/api_server_tools.md) - \[2024/06\] PyTorch engine 支持了 DeepSeek-V2 和若干 VLM 模型推理, 比如 CogVLM2,Mini-InternVL,LlaVA-Next - \[2024/05\] 在多 GPU 上部署 VLM 模型时,支持把视觉部分的模型均分到多卡上 - \[2024/05\] 支持InternVL v1.5, LLaVa, InternLMXComposer2 等 VLMs 模型的 4bit 权重量化和推理 @@ -110,18 +112,20 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
    • Llama (7B - 65B)
    • Llama2 (7B - 70B)
    • Llama3 (8B, 70B)
    • +
    • Llama3.1 (8B, 70B)
    • InternLM (7B - 20B)
    • InternLM2 (7B - 20B)
    • InternLM2.5 (7B)
    • -
    • QWen (1.8B - 72B)
    • -
    • QWen1.5 (0.5B - 110B)
    • -
    • QWen1.5 - MoE (0.5B - 72B)
    • -
    • QWen2 (0.5B - 72B)
    • +
    • Qwen (1.8B - 72B)
    • +
    • Qwen1.5 (0.5B - 110B)
    • +
    • Qwen1.5 - MoE (0.5B - 72B)
    • +
    • Qwen2 (0.5B - 72B)
    • Baichuan (7B)
    • Baichuan2 (7B-13B)
    • Code Llama (7B - 34B)
    • ChatGLM2 (6B)
    • GLM4 (9B)
    • +
    • CodeGeeX4 (9B)
    • Falcon (7B - 180B)
    • YI (6B-34B)
    • Mistral (7B)
    • @@ -138,9 +142,11 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
      • LLaVA(1.5,1.6) (7B-34B)
      • InternLM-XComposer2 (7B, 4khd-7B)
      • -
      • QWen-VL (7B)
      • +
      • InternLM-XComposer2.5 (7B)
      • +
      • Qwen-VL (7B)
      • DeepSeek-VL (7B)
      • InternVL-Chat (v1.1-v1.5)
      • +
      • InternVL2 (1B-76B)
      • MiniGeminiLlama (7B)
      • CogVLM-Chat (17B)
      • CogVLM2-Chat (19B)
      • @@ -170,7 +176,7 @@ pip install lmdeploy 自 v0.3.0 起,LMDeploy 预编译包默认基于 CUDA 12 编译。如果需要在 CUDA 11+ 下安装 LMDeploy,请执行以下命令: ```shell -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/autotest/config.yaml b/autotest/config.yaml index 291506e36..ccb6a5ca7 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -12,21 +12,26 @@ tp_config: Qwen-VL-Chat: 2 llava-v1.5-13b: 2 Qwen2-7B-Instruct: 2 + InternVL2-40B: 2 + turbomind_chat_model: - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Llama-2-7b-chat-hf - internlm/internlm2_5-7b-chat - internlm/internlm2_5-7b-chat-1m - - internlm/internlm2-chat-1_8b - internlm/internlm2-chat-7b - internlm/internlm2-chat-20b - internlm/internlm2-chat-7b-4bits - internlm/internlm2-chat-20b-4bits - internlm/internlm-chat-20b - internlm/internlm-xcomposer2-vl-7b - - internlm/internlm-xcomposer2-7b - internlm/internlm-xcomposer2-4khd-7b + - internlm/internlm-xcomposer2d5-7b + - OpenGVLab/InternVL2-2B + - OpenGVLab/InternVL2-8B + - OpenGVLab/InternVL2-26B + - OpenGVLab/InternVL2-40B - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 - Qwen/Qwen2-7B-Instruct @@ -50,7 +55,6 @@ turbomind_chat_model: - codellama/CodeLlama-7b-Instruct-hf - THUDM/glm-4-9b-chat - pytorch_chat_model: - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Llama-2-7b-chat-hf @@ -59,6 +63,11 @@ pytorch_chat_model: - internlm/internlm2-chat-7b - internlm/internlm2-chat-20b - internlm/internlm-chat-20b + - OpenGVLab/InternVL2-2B + - OpenGVLab/InternVL2-4B + - OpenGVLab/InternVL2-8B + - OpenGVLab/InternVL2-26B + - OpenGVLab/InternVL2-40B - baichuan-inc/Baichuan2-7B-Chat - baichuan-inc/Baichuan2-13B-Chat - 01-ai/Yi-6B-Chat @@ -73,11 +82,13 @@ pytorch_chat_model: - mistralai/Mistral-7B-Instruct-v0.2 - mistralai/Mixtral-8x7B-Instruct-v0.1 - google/gemma-7b-it + - google/gemma-2-9b-it - deepseek-ai/deepseek-moe-16b-chat - deepseek-ai/deepseek-coder-1.3b-instruct - THUDM/chatglm2-6b - - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B + - THUDM/glm-4v-9b + - THUDM/glm-4-9b-chat - microsoft/Phi-3-mini-4k-instruct - microsoft/Phi-3-vision-128k-instruct - bigcode/starcoder2-7b @@ -101,11 +112,17 @@ vl_model: - deepseek-ai/deepseek-vl-1.3b-chat - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - OpenGVLab/InternVL2-2B + - OpenGVLab/InternVL2-4B + - OpenGVLab/InternVL2-8B + - OpenGVLab/InternVL2-26B + - OpenGVLab/InternVL2-40B - internlm/internlm-xcomposer2-vl-7b - - internlm/internlm-xcomposer2-7b + - internlm/internlm-xcomposer2d5-7b - internlm/internlm-xcomposer2-4khd-7b - THUDM/cogvlm-chat-hf - THUDM/cogvlm2-llama3-chinese-chat-19B + - THUDM/glm-4v-9b - microsoft/Phi-3-vision-128k-instruct - openbmb/MiniCPM-Llama3-V-2_5 @@ -120,10 +137,14 @@ quatization_case_config: - internlm/internlm2-20b - internlm/internlm-chat-20b - internlm/internlm-xcomposer2-vl-7b - - internlm/internlm-xcomposer2-7b - internlm/internlm-xcomposer2-4khd-7b + - internlm/internlm-xcomposer2d5-7b - OpenGVLab/InternVL-Chat-V1-5 - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - OpenGVLab/InternVL2-2B + - OpenGVLab/InternVL2-8B + - OpenGVLab/InternVL2-26B + - OpenGVLab/InternVL2-40B - Qwen/Qwen-7B-Chat - Qwen/Qwen1.5-7B-Chat - Qwen/Qwen2-7B-Instruct @@ -170,7 +191,6 @@ quatization_case_config: - deepseek-ai/deepseek-vl-1.3b-chat - deepseek-ai/deepseek-coder-1.3b-instruct - codellama/CodeLlama-7b-Instruct-hf - - THUDM/glm-4-9b-chat w8a8: - meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Llama-2-7b-chat-hf diff --git a/autotest/interface/pipeline/test_pipeline_func.py b/autotest/interface/pipeline/test_pipeline_func.py index 2867997b2..445fb63c7 100644 --- a/autotest/interface/pipeline/test_pipeline_func.py +++ b/autotest/interface/pipeline/test_pipeline_func.py @@ -279,6 +279,70 @@ def run_pipeline_testcase(config, model, backend, file_name): del os.environ['CUDA_VISIBLE_DEVICES'] +@pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) +@pytest.mark.parametrize('backend', [TurbomindEngineConfig]) +def test_return_check_logprobs(config, model, backend, worker_id): + + def run_pipeline_testcase(config, model, backend, file_name): + + model_path = '/'.join([config.get('model_path'), model]) + backend_config = backend(tp=2) + pipe = pipeline(model_path, backend_config=backend_config) + gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40) + response = pipe('Hi, pls intro yourself', gen_config=gen_config) + result, msg = assert_pipeline_single_return(response, logprobs_num=10) + save_pipeline_common_log(config, file_name, result, response, msg) + del pipe + torch.cuda.empty_cache() + + file_name = f'pipeline_log_{worker_id}.txt' + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=2) + p = Process(target=run_pipeline_testcase, + args=(config, model, backend, file_name)) + + p.start() + p.join() + assert_pipeline_common_log(config, file_name) + if 'gw' in worker_id: + del os.environ['CUDA_VISIBLE_DEVICES'] + + +@pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) +@pytest.mark.parametrize('backend', [TurbomindEngineConfig]) +def test_return_check_logprobs_stream(config, model, backend, worker_id): + + def run_pipeline_testcase(config, model, backend, file_name): + + model_path = '/'.join([config.get('model_path'), model]) + backend_config = backend(tp=2) + pipe = pipeline(model_path, backend_config=backend_config) + gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40) + response = [] + for item in pipe.stream_infer('Hi, pls intro yourself', + gen_config=gen_config): + response.append(item) + result, msg = assert_pipeline_single_stream_return(response, + logprobs_num=10) + save_pipeline_common_log(config, file_name, result, response, msg) + del pipe + torch.cuda.empty_cache() + + file_name = f'pipeline_log_{worker_id}.txt' + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=2) + p = Process(target=run_pipeline_testcase, + args=(config, model, backend, file_name)) + + p.start() + p.join() + assert_pipeline_common_log(config, file_name) + if 'gw' in worker_id: + del os.environ['CUDA_VISIBLE_DEVICES'] + + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) @pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig]) diff --git a/autotest/interface/pipeline/test_pipeline_longtext_func.py b/autotest/interface/pipeline/test_pipeline_longtext_func.py index 03fbfcebd..42c8b5741 100644 --- a/autotest/interface/pipeline/test_pipeline_longtext_func.py +++ b/autotest/interface/pipeline/test_pipeline_longtext_func.py @@ -18,7 +18,7 @@ @pytest.mark.gpu_num_1 @pytest.mark.parametrize('model', [ - 'internlm/internlm2-chat-7b', 'internlm/internlm2-7b', + 'internlm/internlm2-chat-7b', 'internlm/internlm2_5-7b', 'internlm/internlm2-chat-1_8b', 'internlm/internlm2-1_8b' ]) def test_history_issue_tp1(config, model, worker_id): @@ -151,8 +151,14 @@ def passkey_retrival(config, use_logn_attn=True, tp=tp_num) else: - backend_config = PytorchEngineConfig(session_len=session_len, - tp=tp_num) + if 'internlm2_5' in model and '-1m' in model: + backend_config = PytorchEngineConfig(session_len=session_len, + max_batch_size=1, + cache_max_entry_count=0.7, + tp=tp_num) + else: + backend_config = PytorchEngineConfig(session_len=session_len, + tp=tp_num) pipe = pipeline(model_path, backend_config=backend_config) diff --git a/autotest/interface/restful/test_restful_chat_func.py b/autotest/interface/restful/test_restful_chat_func.py index 5a8244161..5bd99662a 100644 --- a/autotest/interface/restful/test_restful_chat_func.py +++ b/autotest/interface/restful/test_restful_chat_func.py @@ -3,6 +3,7 @@ from random import randint import pytest +from openai import OpenAI from tqdm import tqdm from utils.restful_return_check import (assert_chat_completions_batch_return, assert_chat_completions_stream_return, @@ -693,6 +694,58 @@ def test_max_tokens_streaming(self): 'finish_reason') == 'length' assert length == 5 or length == 6 + @pytest.mark.not_pytorch + def test_logprobs(self): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + for output in api_client.chat_completions_v1( + model=model_name, + messages='Hi, pls intro yourself', + max_tokens=5, + temperature=0.01, + logprobs=True, + top_logprobs=10): + continue + assert_chat_completions_batch_return(output, + model_name, + check_logprobs=True, + logprobs_num=10) + assert output.get('choices')[0].get('finish_reason') == 'length' + assert output.get('usage').get('completion_tokens') == 6 or output.get( + 'usage').get('completion_tokens') == 5 + + @pytest.mark.not_pytorch + def test_logprobs_streaming(self): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + outputList = [] + for output in api_client.chat_completions_v1( + model=model_name, + messages='Hi, pls intro yourself', + stream=True, + max_tokens=5, + temperature=0.01, + logprobs=True, + top_logprobs=10): + outputList.append(output) + assert_chat_completions_stream_return(outputList[-1], + model_name, + True, + check_logprobs=True, + logprobs_num=10) + response = '' + for index in range(0, len(outputList) - 1): + assert_chat_completions_stream_return(outputList[index], + model_name, + check_logprobs=True, + logprobs_num=10) + response += outputList[index].get('choices')[0].get('delta').get( + 'content') + length = api_client.encode(response, add_bos=False)[1] + assert outputList[-1].get('choices')[0].get( + 'finish_reason') == 'length' + assert length == 5 or length == 6 + @pytest.mark.order(8) @pytest.mark.turbomind @@ -1242,3 +1295,148 @@ def test_input_validation_streaming(self): continue assert output.get('code') is None assert 'Input should be a valid integer' in str(output) + + +@pytest.mark.order(8) +@pytest.mark.turbomind +@pytest.mark.pytorch +@pytest.mark.flaky(reruns=2) +class TestRestfulSeverTools: + + def test_one_round_prompt(self): + tools = [{ + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'description': 'Get the current weather in a given location', + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': + 'string', + 'description': + 'The city and state, e.g. San Francisco, CA', + }, + 'unit': { + 'type': 'string', + 'enum': ['celsius', 'fahrenheit'] + }, + }, + 'required': ['location'], + }, + } + }] + messages = [{ + 'role': 'user', + 'content': "What's the weather like in Boston today?" + }] + + client = OpenAI(api_key='YOUR_API_KEY', base_url=BASE_URL + '/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create(model=model_name, + messages=messages, + temperature=0.01, + stream=False, + tools=tools) + print(response) + assert response.choices[0].finish_reason == 'tool_calls' + assert response.choices[0].message.tool_calls[ + 0].function.name == 'get_current_weather' + assert 'Boston' in response.choices[0].message.tool_calls[ + 0].function.arguments + assert response.choices[0].message.tool_calls[0].type == 'function' + + def test_multiple_round_prompt(self): + + def add(a: int, b: int): + return a + b + + def mul(a: int, b: int): + return a * b + + tools = [{ + 'type': 'function', + 'function': { + 'name': 'add', + 'description': 'Compute the sum of two numbers', + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'int', + 'description': 'A number', + }, + 'b': { + 'type': 'int', + 'description': 'A number', + }, + }, + 'required': ['a', 'b'], + }, + } + }, { + 'type': 'function', + 'function': { + 'name': 'mul', + 'description': 'Calculate the product of two numbers', + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'int', + 'description': 'A number', + }, + 'b': { + 'type': 'int', + 'description': 'A number', + }, + }, + 'required': ['a', 'b'], + }, + } + }] + messages = [{'role': 'user', 'content': 'Compute (3+5)*2'}] + + client = OpenAI(api_key='YOUR_API_KEY', base_url=BASE_URL + '/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create(model=model_name, + messages=messages, + temperature=0.01, + stream=False, + tools=tools) + func1_name = response.choices[0].message.tool_calls[0].function.name + func1_args = response.choices[0].message.tool_calls[ + 0].function.arguments + func1_out = eval(f'{func1_name}(**{func1_args})') + assert response.choices[0].finish_reason == 'tool_calls' + assert func1_name == 'add' + assert func1_args == '{"a": 3, "b": 5}' + assert func1_out == 8 + assert response.choices[0].message.tool_calls[0].type == 'function' + + messages.append({ + 'role': 'assistant', + 'content': response.choices[0].message.content + }) + messages.append({ + 'role': 'environment', + 'content': f'3+5={func1_out}', + 'name': 'plugin' + }) + response = client.chat.completions.create(model=model_name, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools) + print(response) + func2_name = response.choices[0].message.tool_calls[0].function.name + func2_args = response.choices[0].message.tool_calls[ + 0].function.arguments + func2_out = eval(f'{func2_name}(**{func2_args})') + assert response.choices[0].finish_reason == 'tool_calls' + assert func2_name == 'mul' + assert func2_args == '{"a": 8, "b": 2}' + assert func2_out == 16 + assert response.choices[0].message.tool_calls[0].type == 'function' diff --git a/autotest/interface/restful/test_restful_completions_v1.py b/autotest/interface/restful/test_restful_completions_v1.py index 6acde9b64..95fe3c784 100644 --- a/autotest/interface/restful/test_restful_completions_v1.py +++ b/autotest/interface/restful/test_restful_completions_v1.py @@ -1,3 +1,6 @@ +from utils.restful_return_check import (assert_completions_batch_return, + assert_completions_stream_return) + from lmdeploy.serve.openai.api_client import APIClient BASE_HTTP_URL = 'http://localhost' @@ -37,6 +40,36 @@ def test_encode(self): assert length5 == length2 * 100 assert input_ids5 == input_ids2 * 100 + def test_return(self): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + for item in api_client.completions_v1( + model=model_name, + prompt='Hi, pls intro yourself', + max_tokens=16, + temperature=0.01, + ): + completion_tokens = item['usage']['completion_tokens'] + assert completion_tokens > 0 + assert completion_tokens <= 17 + assert completion_tokens >= 16 + assert item.get('choices')[0].get('finish_reason') in ['length'] + assert_completions_batch_return(item, model_name) + + def test_return_streaming(self): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + outputList = [] + for item in api_client.completions_v1(model=model_name, + prompt='Hi, pls intro yourself', + max_tokens=16, + stream=True, + temperature=0.01): + outputList.append(item) + assert_completions_stream_return(outputList[-1], model_name, True) + for index in range(0, len(outputList) - 1): + assert_completions_stream_return(outputList[index], model_name) + def test_max_tokens(self): api_client = APIClient(BASE_URL) model_name = api_client.available_models[0] diff --git a/autotest/tools/chat/test_command_chat_hf_pytorch.py b/autotest/tools/chat/test_command_chat_hf_pytorch.py index e9a8079ae..dc09bdc94 100644 --- a/autotest/tools/chat/test_command_chat_hf_pytorch.py +++ b/autotest/tools/chat/test_command_chat_hf_pytorch.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize('model', get_torch_model_list(tp_num=1)) def test_hf_pytorch_chat_tp1(config, model, cli_case_config, worker_id): usercase = 'chat_testcase' - if 'deepseek-coder' in model: + if 'coder' in model: usercase = 'code_testcase' result, chat_log, msg = hf_command_line_test( config, diff --git a/autotest/tools/chat/test_command_chat_hf_turbomind.py b/autotest/tools/chat/test_command_chat_hf_turbomind.py index 4533b6964..c659b9ae6 100644 --- a/autotest/tools/chat/test_command_chat_hf_turbomind.py +++ b/autotest/tools/chat/test_command_chat_hf_turbomind.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1)) def test_hf_turbomind_chat_tp1(config, model, cli_case_config, worker_id): usercase = 'chat_testcase' - if 'deepseek-coder' in model: + if 'coder' in model: usercase = 'code_testcase' result, chat_log, msg = hf_command_line_test( config, diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind.py index 4b767b38c..a84dace34 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind.py @@ -54,6 +54,8 @@ def test_pipeline_chat_tp2(config, common_case_config, model, worker_id): @pytest.mark.parametrize('quant_policy', (4, 8)) def test_pipeline_chat_kvint_tp1(config, common_case_config, model, quant_policy, worker_id): + if quant_policy == 4 and 'Qwen2' in model: + return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) p = Process(target=run_pipeline_chat_test, @@ -76,6 +78,8 @@ def test_pipeline_chat_kvint_tp1(config, common_case_config, model, @pytest.mark.parametrize('quant_policy', (4, 8)) def test_pipeline_chat_kvint_tp2(config, common_case_config, model, quant_policy, worker_id): + if quant_policy == 4 and 'Qwen2' in model: + return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind.py b/autotest/tools/restful/test_restful_chat_hf_turbomind.py index a49db9b56..4d86c60d9 100644 --- a/autotest/tools/restful/test_restful_chat_hf_turbomind.py +++ b/autotest/tools/restful/test_restful_chat_hf_turbomind.py @@ -67,12 +67,13 @@ def getKvintModelList(tp_num): 'cuda_prefix': None, 'tp_num': tp_num, 'extra': '--quant-policy 4' - } for item in get_kvint_model_list(tp_num)] + [{ - 'model': item, - 'cuda_prefix': None, - 'tp_num': tp_num, - 'extra': '--quant-policy 8' - } for item in get_kvint_model_list(tp_num)] + } for item in get_kvint_model_list(tp_num) + if 'qwen2' not in item.lower()] + [{ + 'model': item, + 'cuda_prefix': None, + 'tp_num': tp_num, + 'extra': '--quant-policy 8' + } for item in get_kvint_model_list(tp_num)] @pytest.mark.order(7) @@ -132,7 +133,6 @@ def test_restful_chat_pr(config, common_case_config): @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api @pytest.mark.gpu_num_1 -@pytest.mark.tmp @pytest.mark.parametrize('prepare_environment', [{ 'model': 'Qwen/Qwen-7B-Chat', 'cuda_prefix': None, diff --git a/autotest/utils/get_run_config.py b/autotest/utils/get_run_config.py index d1a6c652e..57ad54693 100644 --- a/autotest/utils/get_run_config.py +++ b/autotest/utils/get_run_config.py @@ -89,13 +89,12 @@ def get_model_name(model): return 'yi-vl' if ('qwen' in model_name): return 'qwen' - if ('internvl') in model_name: return 'internvl-internlm2' if ('internlm2') in model_name: return 'internlm2' - if ('internlm-xcomposer2-4khd-7b') in model_name: - return 'internlm-xcomposer2-4khd' + if ('internlm-xcomposer2d5') in model_name: + return 'internlm-xcomposer2d5' if ('internlm-xcomposer2') in model_name: return 'internlm-xcomposer2' if ('glm-4') in model_name: diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py index a404f7e33..ec24f9521 100644 --- a/autotest/utils/pipeline_chat.py +++ b/autotest/utils/pipeline_chat.py @@ -77,7 +77,7 @@ def run_pipeline_chat_test(config, file.close for case in cases_info.keys(): - if ('deepseek-coder' in model_case + if ('coder' in model_case or 'CodeLlama' in model_case) and 'code' not in case: continue case_info = cases_info.get(case) @@ -126,7 +126,7 @@ def assert_pipeline_chat_log(config, allure.attach.file(config_log, attachment_type=allure.attachment_type.TEXT) for case in cases_info.keys(): - if ('deepseek-coder' in model_case + if ('coder' in model_case or 'CodeLlama' in model_case) and 'code' not in case: continue msg = 'result is empty, please check again' @@ -200,8 +200,10 @@ def assert_pipeline_common_log(config, log_name): assert result, msg -def assert_pipeline_single_return(output): - result = assert_pipeline_single_element(output, is_last=True) +def assert_pipeline_single_return(output, logprobs_num: int = 0): + result = assert_pipeline_single_element(output, + is_last=True, + logprobs_num=logprobs_num) if not result: return result, 'single_stream_element is wrong' return result & (len(output.token_ids) == output.generate_token_len @@ -219,13 +221,14 @@ def assert_pipeline_batch_return(output, size: int = 1): return True, '' -def assert_pipeline_single_stream_return(output): - print(output) +def assert_pipeline_single_stream_return(output, logprobs_num: int = 0): for i in range(0, len(output) - 1): - if assert_pipeline_single_element(output[i], is_stream=True) is False: + if assert_pipeline_single_element( + output[i], is_stream=True, logprobs_num=logprobs_num) is False: return False, f'single_stream_element is false, index is {i}' - if assert_pipeline_single_element(output[-1], is_stream=True, - is_last=True) is False: + if assert_pipeline_single_element( + output[-1], is_stream=True, is_last=True, + logprobs_num=logprobs_num) is False: return False, 'last single_stream_element is false' return True, '' @@ -241,7 +244,8 @@ def assert_pipeline_batch_stream_return(output, size: int = 1): def assert_pipeline_single_element(output, is_stream: bool = False, - is_last: bool = False): + is_last: bool = False, + logprobs_num: int = 0): result = True result &= output.generate_token_len > 0 result &= output.input_token_len > 0 @@ -257,7 +261,19 @@ def assert_pipeline_single_element(output, result &= len(output.text) > 0 result &= output.finish_reason is None result &= len(output.token_ids) > 0 - result &= output.logprobs is None + if logprobs_num == 0 or is_last: + result &= output.logprobs is None + else: + if is_stream: + result &= len(output.logprobs) == 1 + else: + result &= len(output.logprobs) == output.generate_token_len or len( + output.logprobs) == output.generate_token_len + 1 + if result: + for content in output.logprobs: + result &= len(content.keys()) <= logprobs_num + for key in content.keys(): + result &= type(content.get(key)) == float return result diff --git a/autotest/utils/restful_return_check.py b/autotest/utils/restful_return_check.py index a6da145cb..b243f4840 100644 --- a/autotest/utils/restful_return_check.py +++ b/autotest/utils/restful_return_check.py @@ -2,12 +2,7 @@ def assert_chat_completions_batch_return(output, model_name, check_logprobs: bool = False, logprobs_num: int = 5): - assert output.get('usage').get('prompt_tokens') > 0 - assert output.get('usage').get('total_tokens') > 0 - assert output.get('usage').get('completion_tokens') > 0 - assert output.get('usage').get('completion_tokens') + output.get( - 'usage').get('prompt_tokens') == output.get('usage').get( - 'total_tokens') + assert_usage(output.get('usage')) assert output.get('id') is not None assert output.get('object') == 'chat.completion' assert output.get('model') == model_name @@ -19,18 +14,46 @@ def assert_chat_completions_batch_return(output, assert len(message.get('message').get('content')) > 0 assert message.get('message').get('role') == 'assistant' if check_logprobs: - print(message.get('logprobs')) len(message.get('logprobs').get('content')) == output.get( 'usage').get('completion_tokens') for logprob in message.get('logprobs').get('content'): assert_logprobs(logprob, logprobs_num) +def assert_completions_batch_return(output, + model_name, + check_logprobs: bool = False, + logprobs_num: int = 5): + assert_usage(output.get('usage')) + assert output.get('id') is not None + assert output.get('object') == 'text_completion' + assert output.get('model') == model_name + output_message = output.get('choices') + assert len(output_message) == 1 + for message in output_message: + assert message.get('finish_reason') in ['stop', 'length'] + assert message.get('index') == 0 + assert len(message.get('text')) > 0 + if check_logprobs: + len(message.get('logprobs').get('content')) == output.get( + 'usage').get('completion_tokens') + for logprob in message.get('logprobs').get('content'): + assert_logprobs(logprob, logprobs_num) + + +def assert_usage(usage): + assert usage.get('prompt_tokens') > 0 + assert usage.get('total_tokens') > 0 + assert usage.get('completion_tokens') > 0 + assert usage.get('completion_tokens') + usage.get( + 'prompt_tokens') == usage.get('total_tokens') + + def assert_logprobs(logprobs, logprobs_num): assert_logprob_element(logprobs) - assert len(logprobs.get('top_logprobs')) > 0 and type( - logprobs.get('top_logprobs')) == list and len( - logprobs.get('top_logprobs')) <= logprobs_num + assert len(logprobs.get('top_logprobs')) >= 0 + assert type(logprobs.get('top_logprobs')) == list + assert len(logprobs.get('top_logprobs')) <= logprobs_num for logprob_element in logprobs.get('top_logprobs'): assert_logprob_element(logprob_element) @@ -38,8 +61,7 @@ def assert_logprobs(logprobs, logprobs_num): def assert_logprob_element(logprob): assert len(logprob.get('token')) > 0 and type(logprob.get('token')) == str assert len(logprob.get('bytes')) > 0 and type(logprob.get('bytes')) == list - assert len(logprob.get('logprob')) > 0 and type( - logprob.get('logprob')) == float + assert type(logprob.get('logprob')) == float def assert_chat_completions_stream_return(output, @@ -59,7 +81,9 @@ def assert_chat_completions_stream_return(output, if is_last is False: assert message.get('finish_reason') is None if check_logprobs: - assert_logprobs(message.get('logprobs'), logprobs_num) + assert (len(message.get('logprobs').get('content')) == 1) + assert_logprobs( + message.get('logprobs').get('content')[0], logprobs_num) if is_last is True: assert len(message.get('delta').get('content')) == 0 @@ -68,6 +92,33 @@ def assert_chat_completions_stream_return(output, assert message.get('logprobs') is None +def assert_completions_stream_return(output, + model_name, + is_last: bool = False, + check_logprobs: bool = False, + logprobs_num: int = 5): + assert output.get('id') is not None + assert output.get('object') == 'text_completion' + assert output.get('model') == model_name + output_message = output.get('choices') + assert len(output_message) == 1 + for message in output_message: + assert message.get('index') == 0 + assert len(message.get('text')) >= 0 + if is_last is False: + assert message.get('finish_reason') is None + if check_logprobs: + assert (len(message.get('logprobs').get('content')) == 1) + assert_logprobs( + message.get('logprobs').get('content')[0], logprobs_num) + + if is_last is True: + assert len(message.get('text')) == 0 + assert message.get('finish_reason') in ['stop', 'length'] + if check_logprobs is True: + assert message.get('logprobs') is None + + def assert_chat_interactive_batch_return(output): assert output.get('input_tokens') > 0 assert output.get('tokens') > 0 @@ -80,6 +131,7 @@ def assert_chat_interactive_stream_return(output, is_last: bool = False, index: int = None): assert output.get('input_tokens') > 0 + print(index, output) if index is not None: assert output.get('tokens') >= index and output.get( 'tokens') <= index + 6 diff --git a/autotest/utils/run_client_chat.py b/autotest/utils/run_client_chat.py index 79c537978..51e6c0cf0 100644 --- a/autotest/utils/run_client_chat.py +++ b/autotest/utils/run_client_chat.py @@ -96,6 +96,10 @@ def command_test(config, file.writelines('reproduce command chat: ' + ' '.join(cmd) + '\n') spliter = '\n\n' + if 'CodeLlama' in model and 'api_client' not in cmd: + if 'workspace' in ' '.join(cmd): + spliter = '\n!!\n' + # join prompt together prompt = '' for item in case_info: prompt += list(item.keys())[0] + spliter @@ -110,7 +114,7 @@ def command_test(config, shell=True, text=True, encoding='utf-8') as proc: - # file.writelines('prompt:' + prompt + '\n') + file.writelines('prompt:' + prompt + '\n') outputs, errors = proc.communicate(input=prompt) returncode = proc.returncode @@ -119,7 +123,7 @@ def command_test(config, result = False return result, chat_log, errors - outputDialogs = parse_dialogue(outputs, model) + outputDialogs = parse_dialogue(outputs, model, spliter) file.writelines('answersize:' + str(len(outputDialogs)) + '\n') # 结果判断 @@ -150,9 +154,12 @@ def command_test(config, # 从输出中解析模型输出的对话内容 -def parse_dialogue(inputs: str, model: str): +def parse_dialogue(inputs: str, model: str, spliter: str): dialogues = inputs.strip() - sep = 'double enter to end input >>>' + if '!!' in spliter: + sep = 'enter !! to end the input >>>' + else: + sep = 'double enter to end input >>>' dialogues = dialogues.strip() dialogues = dialogues.split(sep) dialogues = [d.strip() for d in dialogues] diff --git a/autotest/utils/run_restful_chat.py b/autotest/utils/run_restful_chat.py index 6e7c1ec04..c659ea772 100644 --- a/autotest/utils/run_restful_chat.py +++ b/autotest/utils/run_restful_chat.py @@ -110,8 +110,7 @@ def run_all_step(config, if model is None: assert False, 'server not start correctly' for case in cases_info.keys(): - if ('deepseek-coder' in model - or 'codellama' in model) and 'code' not in case: + if ('coder' in model or 'codellama' in model) and 'code' not in case: continue case_info = cases_info.get(case) @@ -167,7 +166,8 @@ def open_chat_test(config, case, case_info, model, url, worker_id: str = ''): for output in api_client.chat_completions_v1(model=model_name, messages=messages, - top_k=1): + top_k=1, + max_tokens=256): output_message = output.get('choices')[0].get('message') messages.append(output_message) @@ -215,7 +215,8 @@ def interactive_test(config, case, case_info, model, url, worker_id: str = ''): for output in api_client.chat_interactive_v1(prompt=prompt, interactive_mode=True, session_id=random_chars, - top_k=1): + top_k=1, + request_output_len=256): output_content = output.get('text') file.writelines('output:' + output_content + '\n') diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 4162e4c9d..b145ed9d7 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -18,6 +18,10 @@ from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter from lmdeploy.messages import (EngineGenerationConfig, PytorchEngineConfig, TurbomindEngineConfig) +from lmdeploy.utils import get_logger + +get_logger('lmdeploy').setLevel('ERROR') +os.environ['TM_LOG_LEVEL'] = 'ERROR' def infer(model, session_id: int, input_ids: List, @@ -165,15 +169,23 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int, 3) token_latency_ave = np.round(np.mean(np.sum(token_latency_stats, axis=1)), 3) - # sort token_latency without the first token's latency - sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten()) - percentiles = [ - np.round( - sorted_token_latency[int(percent * len(sorted_token_latency))], 3) - for percent in [0.5, 0.75, 0.95, 0.99] - ] - - throughput = np.round(token_latency_stats.size / elapsed_time, 2) + if output_seqlen > 1: + # sort token_latency without the first token's latency + sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten()) + percentiles = [ + np.round( + sorted_token_latency[int(percent * len(sorted_token_latency))], + 3) for percent in [0.5, 0.75, 0.95, 0.99] + ] + else: + percentiles = [ + first_token_latency_ave, + ] * 4 + + out_token_throughput = np.round(token_latency_stats.size / elapsed_time, 2) + total_token_throughput = np.round( + concurrency * test_round * (input_seqlen + output_seqlen) / + elapsed_time, 2) print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n' f'concurrency: {concurrency}, test_round: {test_round}\n' f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n' @@ -183,11 +195,13 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int, f'{token_latency_min}s, {token_latency_max}s, ' f'{token_latency_ave}s\n' f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n' - f'throughput: {throughput} token/s\n{"-" * 50}') + f'throughput(output): {out_token_throughput} token/s\n' + f'throughput(total): {total_token_throughput} token/s\n{"-" * 50}') return tm_model.model_name, \ [first_token_latency_min, first_token_latency_max, first_token_latency_ave], \ - percentiles, throughput, tm_model.gpu_count + percentiles, out_token_throughput, total_token_throughput, \ + tm_model.gpu_count class MemoryMonitor: @@ -273,11 +287,9 @@ class ProfileResult: completion_tokens: int first_token_latency: List percentiles: List - throughput_per_proc: float - throughput_per_node: float - mem_per_proc: float + output_throughput: float + total_throughput: float mem_per_gpu: float - mem_per_node: float def parse_args(): @@ -329,7 +341,6 @@ def parse_args(): ArgumentHelper.top_p(parser) ArgumentHelper.temperature(parser) ArgumentHelper.top_k(parser) - ArgumentHelper.log_level(parser) ArgumentHelper.backend(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') @@ -385,7 +396,6 @@ def main(): f'mismatched size between `prompt-tokens` and `completion-tokenes`' \ f', {len(args.prompt_tokens)} vs {len(args.completion_tokens)}' - os.environ['TM_LOG_LEVEL'] = args.log_level results: List[ProfileResult] = [] MemoryMonitor.init() @@ -434,10 +444,9 @@ def main(): ) output = _process_map(profile_target, (args.model_path, )) model_name, first_token_latency, percentiles, \ - throughput_per_proc, tp = output + output_throughput, total_throughput, tp = output time.sleep(5) # wait a while for releasing GPU mem memory = MemoryMonitor.terminate() - device_count = MemoryMonitor.device_count.value results.append( ProfileResult(model_name=model_name, batch=batch, @@ -445,12 +454,9 @@ def main(): completion_tokens=completion_tokens, first_token_latency=first_token_latency, percentiles=percentiles, - throughput_per_proc=throughput_per_proc, - throughput_per_node=throughput_per_proc / tp * - device_count, - mem_per_proc=memory, - mem_per_gpu=memory / tp, - mem_per_node=memory / tp * device_count)) + output_throughput=output_throughput, + total_throughput=total_throughput, + mem_per_gpu=memory / tp)) if args.csv: with open(args.csv, 'w') as csvfile: writer = csv.writer(csvfile) @@ -458,6 +464,7 @@ def main(): 'batch', 'prompt_tokens', 'completion_tokens', + 'throughput(total tok/s)', 'throughput(out tok/s)', 'mem(GB)', 'FTL(ave)(s)', @@ -471,7 +478,8 @@ def main(): for re in results: writer.writerow([ re.batch, re.prompt_tokens, re.completion_tokens, - f'{re.throughput_per_proc:.2f}', f'{re.mem_per_gpu:.2f}', + f'{re.total_throughput:.2f}', + f'{re.output_throughput:.2f}', f'{re.mem_per_gpu:.2f}', re.first_token_latency[2], re.first_token_latency[0], re.first_token_latency[1], re.percentiles[0], re.percentiles[1], re.percentiles[2], re.percentiles[3] diff --git a/benchmark/profile_serving.py b/benchmark/profile_serving.py deleted file mode 100644 index 60058d290..000000000 --- a/benchmark/profile_serving.py +++ /dev/null @@ -1,260 +0,0 @@ -import csv -import json -import random -import time -from queue import Queue -from threading import Thread -from typing import List, Tuple - -import fire -import numpy as np -from tqdm import tqdm - -from lmdeploy.serve.turbomind.chatbot import Chatbot -from lmdeploy.tokenizer import Tokenizer - - -def sample_requests( - dataset_path: str, - num_requests: int, - tokenizer: Tokenizer, -) -> List[Tuple[str, int, int]]: - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data['conversations']) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data['conversations'][0]['value'], - data['conversations'][1]['value']) for data in dataset] - - # pre-sample to avoid go through all the dataset - dataset = random.sample(dataset, max(int(num_requests * 1.2), 1000)) - - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) - - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests - - -class Engine: - - def __init__(self, - server_addr: str, - tokenzier_path: str, - temperature: float = 0.8, - top_k: int = 1, - top_p: float = 1.0, - csv: str = '', - log_level: str = 'ERROR', - **kwargs): - self.server_addr = server_addr - self.tokenizer = Tokenizer(tokenzier_path) - self.temperature = temperature - self.top_k = top_k - self.top_p = top_p - self.csv = csv - self.log_level = log_level - self.pbar = None - - def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, - stream_output: bool): - - chatbot = Chatbot(self.server_addr, - ignore_eos=True, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - capability='completion', - log_level=self.log_level) - stats = [] - for prompt, input_seqlen, output_seqlen in iter( - req_queue.get, [None, None, None]): - timestamps = [] - tokens = [] - timestamps.append(time.perf_counter()) - for _, _, n_token in chatbot.stream_infer( - session_id, - prompt, - request_output_len=output_seqlen, - sequence_start=True, - sequence_end=True): - timestamps.append(time.perf_counter()) - tokens.append(n_token) - first_token_latency = np.round(timestamps[1] - timestamps[0], 3) - token_latency = np.round(timestamps[-1] - timestamps[0], 3) - completion_tokens = tokens[-1] - assert output_seqlen <= completion_tokens <= output_seqlen + 1, \ - f'Error. session_id({session_id}) request {output_seqlen} ' \ - f'tokens, but generate {completion_tokens} tokens.\n' \ - f'prompt: {prompt}' - total_tokens = tokens[-1] + input_seqlen - stats.append([ - first_token_latency, completion_tokens, output_seqlen, - total_tokens, token_latency - ]) - self.pbar.update(1) - res_queue.put((session_id, stats)) - - def process_request(self, - requests, - concurrency: int = 1, - stream_output: bool = True): - res_queue = Queue() - req_queue = Queue() - threads = [] - - self.pbar = tqdm(total=len(requests)) - - # feed request to q - for req in requests: - req_queue.put(req) - for i in range(concurrency): - req_queue.put([None, None, None]) - - start = time.time() - - # start threads - for i in range(concurrency): - t = Thread(target=self._inference, - args=(req_queue, res_queue, i, stream_output)) - t.start() - threads.append(t) - - # wait for finish - for t in threads: - t.join() - - elapsed_time = time.time() - start - - stats = [] - while not res_queue.empty(): - session_id, _stats = res_queue.get() - # print(f'\n{"-" * 50}\n' - # f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n') - if len(_stats) != 0: - stats.append(np.array(_stats)) - - stats = np.concatenate(stats).reshape(-1, 5) - - first_token_latency_min = np.min(stats[:, 0], axis=0) - first_token_latency_max = np.max(stats[:, 0], axis=0) - first_token_latency_ave = np.mean(stats[:, 0], axis=0) - completion_tokens = np.sum(stats[:, 1], axis=0) - request_output_tokens = np.sum(stats[:, 2], axis=0) - total_tokens = np.sum(stats[:, 3], axis=0) - prompt_tokens = total_tokens - completion_tokens - completion_token_throughput = completion_tokens / elapsed_time - total_token_throughput = total_tokens / elapsed_time - rps = len(requests) / elapsed_time - rpm = rps * 60 - - if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False: - print(f'Did not generate requested number of tokens. ' - f'Request {request_output_tokens:.0f}, ' - f'but got {completion_tokens:.0f}') - - print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n' - f'elapsed_time: {elapsed_time:.3f}s\n') - if stream_output: - print(f'first_token latency(min, max, ave): ' - f'{first_token_latency_min:.3f}s, ' - f'{first_token_latency_max:.3f}s, ' - f'{first_token_latency_ave:.3f}s\n') - print( - f'number of prompt tokens: {prompt_tokens:.0f}\n' - f'number of completion tokens: {completion_tokens:.0f}\n' - f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa - f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa - f'RPS (request per second): {rps:.3f} req/s\n' - f'RPM (request per minute): {rpm:.3f} req/min\n' - f'{"-" * 50}\n') - - if self.csv: - with open(self.csv, 'w') as csvfile: - writer = csv.writer(csvfile) - writer.writerow([ - 'batch', 'num_prompts', 'RPS', 'RPM', 'FTL(ave)(s)', - 'FTL(min)(s)', 'FTL(max)(s)', 'throughput(out tok/s)', - 'throughput(total tok/s)' - ]) - writer.writerow([ - concurrency, - len(requests), f'{rps:.3f}', f'{rpm:.3f}', - f'{first_token_latency_ave:.3f}' if stream_output else '-', - f'{first_token_latency_min:.3f}' if stream_output else '-', - f'{first_token_latency_max:.3f}' if stream_output else '-', - f'{completion_token_throughput:.3f}', - f'{total_token_throughput:.3f}' - ]) - - -def main(server_addr: str, - tokenizer_path: str, - dataset: str, - concurrency: int = 32, - num_prompts: int = 1000, - top_k: int = 1, - top_p: float = 1.0, - temperature: float = 1.0, - stream_output: bool = True, - csv: str = './profile_tis.csv', - seed: int = 0): - """Benchmark the request througput of the triton inference server. - - Args: - server_addr (str): Address of the triton inference server with format 0.0.0.0:0 - tokenizer_path (str): Path to the tokenizer model in localhost - dataset (str): Path to the dataset - concurrency (int, optional): Number of working threads to process the sampled prompts. - Defaults to 32. - num_prompts (int, optional): Number of prompts to process. Defaults to 1000. - top_k (int, optional): The number of highest probability vocabulary tokens - to keep for top-k-filtering. Defaults to 1. - top_p (float, optional): the set of most probable tokens with - probabilities that add up to top_p or higher - are kept for generation. Defaults to 1.0. - temperature (float, optional): The value used to modulate the next token probabilities. - Defaults to 1.0. - stream_output (bool, optional): Indicator for streaming output. Defaults to True. - seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0. - """ # noqa - - random.seed(seed) - - engine = Engine(server_addr, - tokenizer_path, - top_k=top_k, - top_p=top_p, - temperature=temperature, - log_level='ERROR', - csv=csv) - - requests = sample_requests(dataset, num_prompts, engine.tokenizer) - - engine.process_request(requests, concurrency, stream_output) - - -if __name__ == '__main__': - fire.Fire(main) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index d1e898b39..be7c1035e 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -17,6 +17,10 @@ TurbomindEngineConfig) from lmdeploy.pytorch.engine import EngineInstance from lmdeploy.tokenizer import DetokenizeState, Tokenizer +from lmdeploy.utils import get_logger + +get_logger('lmdeploy').setLevel('ERROR') +os.environ['TM_LOG_LEVEL'] = 'ERROR' def sample_requests( @@ -274,7 +278,6 @@ def parse_args(): ArgumentHelper.top_p(parser) ArgumentHelper.temperature(parser) ArgumentHelper.top_k(parser) - ArgumentHelper.log_level(parser) ArgumentHelper.backend(parser) # pytorch engine args @@ -304,7 +307,6 @@ def parse_args(): def main(): args = parse_args() random.seed(args.seed) - os.environ['TM_LOG_LEVEL'] = args.log_level if args.backend == 'turbomind': engine_config = TurbomindEngineConfig( session_len=args.session_len, diff --git a/builder/windows/setup_cuda.ps1 b/builder/windows/setup_cuda.ps1 new file mode 100644 index 000000000..b573198ce --- /dev/null +++ b/builder/windows/setup_cuda.ps1 @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/thewh1teagle/vibe/blob/5d7b75568ca65ab635bdf0ce912bbc975a043066/scripts/setup_cuda.ps1 + +$CUDA_VERSION_FULL = $env:INPUT_CUDA_VERSION # v12.1.0 or v11.8.0 + +# Make sure CUDA_VERSION_FULL is set and valid, otherwise error. +# Validate CUDA version, extracting components via regex +$cuda_ver_matched = $CUDA_VERSION_FULL -match "^(?[1-9][0-9]*)\.(?[0-9]+)\.(?[0-9]+)$" +if(-not $cuda_ver_matched){ + Write-Output "Invalid CUDA version specified, .. required. '$CUDA_VERSION_FULL'." + exit 1 +} +$CUDA_MAJOR=$Matches.major +$CUDA_MINOR=$Matches.minor +$CUDA_PATCH=$Matches.patch + +Write-Output "Selected CUDA version: $CUDA_VERSION_FULL" + +$src = "cuda" +$dst = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$($CUDA_MAJOR).$($CUDA_MINOR)" +$installer = "cuda.exe" + +if ($CUDA_VERSION_FULL -eq "12.1.0") { + $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_531.14_windows.exe" +} elseif ($CUDA_VERSION_FULL -eq "11.8.0") { + $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe" +} else { + Write-Output "Unsupported CUDA version specified" + exit 1 +} + +# Download cuda +Write-Output "Downloading CUDA from: $downloadUrl" +if (-not (Test-Path -Path $installer)) { + Write-Output "Downloading CUDA installer..." + # If the file does not exist, download it + & "C:\msys64\usr\bin\wget" $downloadUrl -O $installer -q +} + +# Extract cuda +if (-not (Test-Path -Path $src -Type Container)) { + # Extract CUDA using 7-Zip + Write-Output "Extracting CUDA using 7-Zip..." + mkdir "$src" + & 'C:\Program Files\7-Zip\7z' x $installer -o"$src" +} + +# Create destination directory if it doesn't exist +if (-Not (Test-Path -Path $dst)) { + Write-Output "Creating destination directory: $dst" + New-Item -Path $dst -ItemType Directory +} + +# Get directories to process from the source path +$directories = Get-ChildItem -Directory -Path $src +$whitelist = @("CUDA_Toolkit_Release_Notes.txt", "DOCS", "EULA.txt", "LICENSE", "README", "version.json") + +foreach ($dir in $directories) { + # Get all subdirectories and files in the current directory + $items = Get-ChildItem -Path (Join-Path $src $dir.Name) + + foreach ($item in $items) { + if ($item.PSIsContainer) { + # If the item is a directory, copy its contents + Write-Output "Copying contents of directory $($item.FullName) to $dst" + Copy-Item -Path "$($item.FullName)\*" -Destination $dst -Recurse -Force + } else { + if ($whitelist -contains $item.Name) { + Write-Output "Copying file $($item.FullName) to $dst" + Copy-Item -Path $item.FullName -Destination $dst -Force + } + } + } +} + +# Add msbuild cuda extensions +$msBuildExtensions = (Get-ChildItem "$src\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions").fullname +(Get-ChildItem 'C:\Program Files\Microsoft Visual Studio\2022\*\MSBuild\Microsoft\VC\*\BuildCustomizations').FullName | ForEach-Object { + $destination = $_ + $msBuildExtensions | ForEach-Object { + $extension = $_ + Copy-Item $extension -Destination $destination -Force + Write-Output "Copied $extension to $destination" + } +} + +# Add to Github env +Write-Output "Setting environment variables for GitHub Actions..." + +Write-Output "CUDA_PATH=$dst" +Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst" +Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" +Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL" + +Write-Output "CUDA_PATH=$dst" >> $env:GITHUB_ENV +Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst" >> $env:GITHUB_ENV +Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" >> $env:GITHUB_ENV +Write-Output "CudaToolkitDir=$dst" >> $env:GITHUB_ENV +Write-Output "CMAKE_CUDA_COMPILER=$dst\bin\nvcc.exe" >> $env:GITHUB_ENV +Write-Output "NVCC_APPEND_FLAGS=-allow-unsupported-compiler" >> $env:GITHUB_ENV + +Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL" >> $env:GITHUB_ENV +Write-Output "Setup completed." diff --git a/docs/en/benchmark/profile_triton_server.md b/docs/en/benchmark/profile_triton_server.md deleted file mode 100644 index 1af1f5a07..000000000 --- a/docs/en/benchmark/profile_triton_server.md +++ /dev/null @@ -1,58 +0,0 @@ -# Profile Triton Inference Server - -Triton Inference Server (TIS) is another serving method supported by LMDeploy besides `api_server`. Its performance testing methods and metrics are similar to those of [api_server](./profile_api_server.md). - -The profiling script is `profile_serving.py`. Before running it, please install the lmdeploy precompiled package, download the profiling script and the test dataset: - -```shell -pip install 'lmdeploy[serve]' -git clone --depth=1 https://github.com/InternLM/lmdeploy -cd lmdeploy/benchmark -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -``` - -## Metrics - -LMDeploy records the performance metrics like first token latency, token throughput (tokens/s) and request throughput (RPM) - -`first_token_latency` is only reported in the case of streaming inference. - -The formula for calculating `token throughput` is: - -$$ -TokenThroughput=Number\\ of\\ generated\\ tokens/TotalTime -$$ - -And the formula for calculating `request throughput` is: - -$$ -RPM(request\\ per\\ minute)=Number\\ of\\ prompts/TotalTime * 60 -$$ - -Total time includes prefill time. - -## Profile - -In this section, we take [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) as an example to show the benchmark procedure. - -### Launch triton inference server - -Before launching the server, the LLM model must be converted to the turbomind format in advance. - -```shell -lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b --trust-remote-code -``` - -Then, the triton inference server can be launched by: - -```shell -bash ./internlm-7b/service_docker_up.sh -``` - -### Profile - -```shell -python3 profile_serving.py 0.0.0.0:33337 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json -``` - -For detailed argument specification of `profile_serving.py`, such as request concurrency, sampling parameters an so on, please run the help command `python3 profile_serving.py -h`. diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 9f8639734..f390cc547 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -13,7 +13,7 @@ pip install lmdeploy The default prebuilt package is compiled on **CUDA 12**. However, if CUDA 11+ is required, you can install lmdeploy by: ```shell -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/en/index.rst b/docs/en/index.rst index 821fd05e1..d79b4b60a 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -58,7 +58,6 @@ Documentation benchmark/profile_generation.md benchmark/profile_throughput.md benchmark/profile_api_server.md - benchmark/profile_triton_server.md benchmark/evaluate_with_opencompass.md .. _supported_models: diff --git a/docs/en/inference/turbomind.md b/docs/en/inference/turbomind.md index fc197f85d..2b438723d 100644 --- a/docs/en/inference/turbomind.md +++ b/docs/en/inference/turbomind.md @@ -57,8 +57,6 @@ Our implementation of the LLaMa family models is modified from Gpt-NeoX model in TurboMind supports a Python API that enables streaming output and tensor parallel mode. -The ability to use [tritonserver](https://github.com/triton-inference-server/server) for serving is also inherited from FasterTransformer. However, to support submitting concurrent requests into our persistent batch model, we no longer use sequence batching or dynamic batching as FasterTransformer does. The bookkeeping of request and sequence states are managed by TurboMind instead. - ## Difference between FasterTransformer and TurboMind Apart of the features described above, there are still many minor differences that we don't cover in this document. Notably, many capabilities of FT are dropped in TurboMind because of the difference in objectives (e.g. prefix prompt, beam search, context embedding, sparse GEMM, GPT/T5/other model families, etc) diff --git a/docs/en/multi_modal/cogvlm.md b/docs/en/multi_modal/cogvlm.md index 07d266001..ad0045da2 100644 --- a/docs/en/multi_modal/cogvlm.md +++ b/docs/en/multi_modal/cogvlm.md @@ -22,7 +22,7 @@ Install LMDeploy with pip (Python 3.8+). Refer to [Installation](https://lmdeplo ```shell # cuda 11.8 # to get the latest version, run: pip index versions lmdeploy -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 # cuda 12.1 diff --git a/docs/en/multi_modal/xcomposer2d5.md b/docs/en/multi_modal/xcomposer2d5.md new file mode 100644 index 000000000..81903c128 --- /dev/null +++ b/docs/en/multi_modal/xcomposer2d5.md @@ -0,0 +1,160 @@ +# InternLM-XComposer-2.5 + +## Introduction + +[InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) excels in various text-image comprehension and composition applications, achieving GPT-4V level capabilities with merely 7B LLM backend. IXC-2.5 is trained with 24K interleaved image-text contexts, it can seamlessly extend to 96K long contexts via RoPE extrapolation. This long-context capability allows IXC-2.5 to perform exceptionally well in tasks requiring extensive input and output contexts. LMDeploy supports model [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) in TurboMind engine. + +## Quick Start + +### Installation + +Install LMDeploy with pip (Python 3.8+). Refer to [Installation](https://lmdeploy.readthedocs.io/en/latest/get_started.html#installation) for more. + +```shell +pip install lmdeploy + +# install other packages that InternLM-XComposer-2.5 needs +pip install decord +``` + +### Offline inference pipeline + +The following sample code shows the basic usage of VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#vlm-offline-inference-pipeline) + +```python +from lmdeploy import pipeline +from lmdeploy.vl import load_image +from lmdeploy.vl.constants import IMAGE_TOKEN + +pipe = pipeline('internlm/internlm-xcomposer2d5-7b') + +image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') +response = pipe((f'describe this image', image)) +print(response) +``` + +## Lora Model + +InternLM-XComposer-2.5 trained the LoRA weights for webpage creation and article writing. As TurboMind backend doesn't support slora, only one LoRA model can be deployed at a time, and the LoRA weights need to be merged when deploying the model. LMDeploy provides the corresponding conversion script, which is used as follows: + +``` +export HF_MODEL=internlm/internlm-xcomposer2d5-7b +export WORK_DIR=internlm/internlm-xcomposer2d5-7b-web +export TASK=web +python -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK +``` + +## Quantization + +The following takes the base model as an example to show the quantization method. If you want to use the LoRA model, please merge the LoRA model according to the previous section. + +```shell + +export HF_MODEL=internlm/internlm-xcomposer2d5-7b +export WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit + +lmdeploy lite auto_awq \ + $HF_MODEL \ + --work-dir $WORK_DIR +``` + +## More examples + +
        + + Video Understanding + + +The following uses the `pipeline.chat` interface api as an example to demonstrate its usage. Other interfaces apis also support inference but require manually splicing of conversation content. + +```python +from lmdeploy import pipeline, GenerationConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +HF_MODEL = 'internlm/internlm-xcomposer2d5-7b' +load_video = get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL) +frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL) +Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL) +get_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL) + +video = load_video('liuxiang.mp4') # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4 +img = frame2img(video, get_font()) +img = Video_transform(img) + +pipe = pipeline(HF_MODEL) +gen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0) +query = 'Here are some frames of a video. Describe this video in detail' +sess = pipe.chat((query, img), gen_config=gen_config) +print(sess.response.text) + +query = 'tell me the athlete code of Liu Xiang' +sess = pipe.chat(query, session=sess, gen_config=gen_config) +print(sess.response.text) +``` + +
        + +
        + + Multi-Image + + +```python +from lmdeploy import pipeline, GenerationConfig +from lmdeploy.vl.constants import IMAGE_TOKEN +from lmdeploy.vl import load_image + +query = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one' + +urls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg', + 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg', + 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg'] +images = [load_image(url) for url in urls] + +pipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO') +output = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939)) +``` + +Since LMDeploy does not support beam search, the generated results will be quite different from those using beam search with transformers. It is recommended to turn off top_k or use a larger top_k sampling to increase diversity. + +
        + +
        + + Instruction to Webpage + + +Please first convert the web model using the instructions above. + +```python +from lmdeploy import pipeline, GenerationConfig + +pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO') +pipe.chat_template.meta_instruction = None + +query = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.' +output = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048)) +``` + +When using transformers for testing, it is found that if repetition_penalty is set, there is a high probability that the decode phase will not stop if `num_beams` is set to 1. As LMDeploy does not support beam search, it is recommended to turn off repetition_penalty when using LMDeploy for inference. + +
        + +
        + + Write Article + + +Please first convert the write model using the instructions above. + +```python +from lmdeploy import pipeline, GenerationConfig + +pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO') +pipe.chat_template.meta_instruction = None + +query = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence' +output = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192)) +``` + +
        diff --git a/docs/en/serving/api_server_tools.md b/docs/en/serving/api_server_tools.md index 162964093..379d4942b 100644 --- a/docs/en/serving/api_server_tools.md +++ b/docs/en/serving/api_server_tools.md @@ -1,8 +1,10 @@ # Tools +LMDeploy supports tools for InternLM2, InternLM2.5 and llama3.1 models. + ## Single Round Invocation -Currently, LMDeploy supports tools only for InternLM2 or InternLM2.5 models. Please start the service of models before running the following example. +Please start the service of models before running the following example. ```python from openai import OpenAI @@ -43,6 +45,8 @@ print(response) ## Multiple Round Invocation +### InternLM + A complete toolchain invocation process can be demonstrated through the following example. ```python @@ -146,3 +150,97 @@ ChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logp ChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='2', function=Function(arguments={'a': 8, 'b': 2}, name='mul'), type='function')]))], created=1719369987, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=282, total_tokens=307)) 16 ``` + +### Llama 3.1 + +Meta announces in [Llama3's official user guide](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) that, + +```{text} +There are three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt: + +1. Brave Search: Tool call to perform web searches. +2. Wolfram Alpha: Tool call to perform complex mathematical calculations. +3. Code Interpreter: Enables the model to output python code. +``` + +Additionally, it cautions: "**Note:** We recommend using Llama 70B-instruct or Llama 405B-instruct for applications that combine conversation and tool calling. Llama 8B-Instruct can not reliably maintain a conversation alongside tool calling definitions. It can be used for zero-shot tool calling, but tool instructions should be removed for regular conversations between the model and the user." + +Therefore, we utilize [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) to show how to invoke the tool calling by LMDeploy `api_server`. + +On a A100-SXM-80G node, you can start the service as follows: + +```shell +lmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4 +``` + +For an in-depth understanding of the api_server, please refer to the detailed documentation available [here](./api_server.md). + +The following code snippet demonstrates how to utilize the 'Wolfram Alpha' tool. It is assumed that you have already registered on the [Wolfram Alpha](https://www.wolframalpha.com) website and obtained an API key. Please ensure that you have a valid API key to access the services provided by Wolfram Alpha + +```python +from openai import OpenAI +import requests + + +def request_llama3_1_service(messages): + client = OpenAI(api_key='YOUR_API_KEY', + base_url='http://0.0.0.0:23333/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False) + return response.choices[0].message.content + + +# The role of "system" MUST be specified, including the required tools +messages = [ + { + "role": "system", + "content": "Environment: ipython\nTools: wolfram_alpha\n\n Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\nYou are a helpful Assistant." # noqa + }, + { + "role": "user", + "content": "Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0" # noqa + } +] + +# send request to the api_server of llama3.1-70b and get the response +# the "assistant_response" is supposed to be: +# <|python_tag|>wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0") +assistant_response = request_llama3_1_service(messages) +print(assistant_response) + +# Call the API of Wolfram Alpha with the query generated by the model +app_id = 'YOUR-Wolfram-Alpha-API-KEY' +params = { + "input": assistant_response, + "appid": app_id, + "format": "plaintext", + "output": "json", +} + +wolframalpha_response = requests.get( + "https://api.wolframalpha.com/v2/query", + params=params +) +wolframalpha_response = wolframalpha_response.json() + +# Append the contents obtained by the model and the wolframalpha's API +# to "messages", and send it again to the api_server +messages += [ + { + "role": "assistant", + "content": assistant_response + }, + { + "role": "ipython", + "content": wolframalpha_response + } +] + +assistant_response = request_llama3_1_service(messages) +print(assistant_response) +``` diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index b30776339..e7c8f03a8 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -2,68 +2,76 @@ ## Models supported by TurboMind -| Model | Size | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | -| :-----------------: | :----------: | :-------: | :-----: | :-----: | :---: | -| Llama | 7B - 65B | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | Yes | Yes | Yes | Yes | -| InternLM | 7B - 20B | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | Yes | Yes | Yes | Yes | -| InternLM-XComposer | 7B | Yes | Yes | Yes | Yes | -| InternLM-XComposer2 | 7B, 4khd-7B | Yes | Yes | Yes | Yes | -| QWen | 1.8B - 72B | Yes | Yes | Yes | Yes | -| QWen1.5 | 1.8B - 110B | Yes | Yes | Yes | Yes | -| QWen2 | 1.5B - 72B | Yes | Yes | Yes | Yes | -| Mistral | 7B | Yes | Yes | Yes | No | -| QWen-VL | 7B | Yes | Yes | Yes | Yes | -| DeepSeek-VL | 7B | Yes | Yes | Yes | Yes | -| Baichuan | 7B | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | Yes | Yes | Yes | Yes | -| Code Llama | 7B - 34B | Yes | Yes | Yes | No | -| YI | 6B - 34B | Yes | Yes | Yes | No | -| LLaVA(1.5,1.6) | 7B - 34B | Yes | Yes | Yes | Yes | -| InternVL-Chat | v1.1- v1.5 | Yes | Yes | Yes | Yes | -| MiniCPM | Llama3-V-2_5 | Yes | Yes | Yes | Yes | -| MiniGeminiLlama | 7B | Yes | No | No | Yes | -| GLM4 | 9B | Yes | Yes | Yes | No | +| Model | Size | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | +| :-------------------: | :----------: | :-------: | :-----: | :-----: | :---: | +| Llama | 7B - 65B | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | Yes | Yes | Yes | Yes | +| InternLM | 7B - 20B | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer | 7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer2 | 7B, 4khd-7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer2.5 | 7B | Yes | Yes | Yes | Yes | +| Qwen | 1.8B - 72B | Yes | Yes | Yes | Yes | +| Qwen1.5 | 1.8B - 110B | Yes | Yes | Yes | Yes | +| Qwen2 | 1.5B - 72B | Yes | Yes | Yes | Yes | +| Mistral | 7B | Yes | Yes | Yes | No | +| Qwen-VL | 7B | Yes | Yes | Yes | Yes | +| DeepSeek-VL | 7B | Yes | Yes | Yes | Yes | +| Baichuan | 7B | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | Yes | Yes | Yes | Yes | +| Code Llama | 7B - 34B | Yes | Yes | Yes | No | +| YI | 6B - 34B | Yes | Yes | Yes | No | +| LLaVA(1.5,1.6) | 7B - 34B | Yes | Yes | Yes | Yes | +| InternVL-Chat | v1.1- v1.5 | Yes | Yes | Yes | Yes | +| InternVL2 | 2B-76B | Yes | Yes | Yes | Yes | +| MiniCPM | Llama3-V-2_5 | Yes | Yes | Yes | Yes | +| MiniGeminiLlama | 7B | Yes | No | No | Yes | +| GLM4 | 9B | Yes | Yes | Yes | No | +| CodeGeeX4 | 9B | Yes | Yes | Yes | No | "-" means not verified yet. ```{note} -The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, QWen1.5 and etc., please choose the PyTorch engine for inference. +The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. ``` ## Models supported by PyTorch -| Model | Size | FP16/BF16 | KV INT8 | W8A8 | -| :-----------------: | :---------: | :-------: | :-----: | :--: | -| Llama | 7B - 65B | Yes | No | Yes | -| Llama2 | 7B - 70B | Yes | No | Yes | -| Llama3 | 8B, 70B | Yes | No | Yes | -| InternLM | 7B - 20B | Yes | No | Yes | -| InternLM2 | 7B - 20B | Yes | No | - | -| InternLM2.5 | 7B | Yes | No | - | -| Baichuan2 | 7B - 13B | Yes | No | Yes | -| ChatGLM2 | 6B | Yes | No | No | -| Falcon | 7B - 180B | Yes | No | No | -| YI | 6B - 34B | Yes | No | No | -| Mistral | 7B | Yes | No | No | -| Mixtral | 8x7B | Yes | No | No | -| QWen | 1.8B - 72B | Yes | No | No | -| QWen1.5 | 0.5B - 110B | Yes | No | No | -| QWen1.5-MoE | A2.7B | Yes | No | No | -| QWen2 | 0.5B - 72B | Yes | No | No | -| DeepSeek-MoE | 16B | Yes | No | No | -| DeepSeek-V2 | 16B, 236B | Yes | No | No | -| Gemma | 2B-7B | Yes | No | No | -| Dbrx | 132B | Yes | No | No | -| StarCoder2 | 3B-15B | Yes | No | No | -| Phi-3-mini | 3.8B | Yes | No | No | -| Phi-3-vision | 4.2B | Yes | No | No | -| CogVLM-Chat | 17B | Yes | No | No | -| CogVLM2-Chat | 19B | Yes | No | No | -| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No | -| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No | -| Gemma2 | 9B-27B | Yes | No | No | -| GLM4 | 9B | Yes | No | No | +| Model | Size | FP16/BF16 | KV INT8 | W8A8 | W4A16 | +| :-----------------: | :---------: | :-------: | :-----: | :--: | :---: | +| Llama | 7B - 65B | Yes | No | Yes | Yes | +| Llama2 | 7B - 70B | Yes | No | Yes | Yes | +| Llama3 | 8B, 70B | Yes | No | Yes | Yes | +| Llama3.1 | 8B, 70B | Yes | No | No | - | +| InternLM | 7B - 20B | Yes | No | Yes | - | +| InternLM2 | 7B - 20B | Yes | No | Yes | Yes | +| InternLM2.5 | 7B | Yes | No | Yes | Yes | +| Baichuan2 | 7B | Yes | No | Yes | Yes | +| Baichuan2 | 13B | Yes | No | No | No | +| ChatGLM2 | 6B | Yes | No | No | No | +| Falcon | 7B - 180B | Yes | No | No | No | +| YI | 6B - 34B | Yes | No | No | Yes | +| Mistral | 7B | Yes | No | No | No | +| Mixtral | 8x7B | Yes | No | No | No | +| QWen | 1.8B - 72B | Yes | No | No | Yes | +| QWen1.5 | 0.5B - 110B | Yes | No | No | Yes | +| QWen1.5-MoE | A2.7B | Yes | No | No | No | +| QWen2 | 0.5B - 72B | Yes | No | No | Yes | +| DeepSeek-MoE | 16B | Yes | No | No | No | +| DeepSeek-V2 | 16B, 236B | Yes | No | No | No | +| Gemma | 2B-7B | Yes | No | No | No | +| Dbrx | 132B | Yes | No | No | No | +| StarCoder2 | 3B-15B | Yes | No | No | No | +| Phi-3-mini | 3.8B | Yes | No | No | No | +| Phi-3-vision | 4.2B | Yes | No | No | - | +| CogVLM-Chat | 17B | Yes | No | No | - | +| CogVLM2-Chat | 19B | Yes | No | No | - | +| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No | - | +| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No | - | +| InternVL2 | 1B-40B | Yes | No | No | - | +| Gemma2 | 9B-27B | Yes | No | No | - | +| GLM4 | 9B | Yes | No | No | No | +| CodeGeeX4 | 9B | Yes | No | No | - | diff --git a/docs/zh_cn/benchmark/profile_triton_server.md b/docs/zh_cn/benchmark/profile_triton_server.md deleted file mode 100644 index f17a2964b..000000000 --- a/docs/zh_cn/benchmark/profile_triton_server.md +++ /dev/null @@ -1,61 +0,0 @@ -# Triton Inference Server 性能测试 - -Triton Inference Server(TIS) 是 LMDeploy 支持的除了 api_server 之外的另一种 serving 方式。它的性能测试方式和测试指标和 [api_server](./profile_api_server.md) 的测试方式类似。 - -```{note} -LMDeploy 尚未实现 Triton Inference Server 的 ensemble 推理模式,所以推理性能要比 api_server 弱。对于追求性能的用户,我们推荐使用 api_server 部署服务。 -``` - -TIS 性能测试脚本是 `profile_serving.py`。测试之前,请安装 lmdeploy 预编译包,并下载评测脚本和测试数据集。 - -```shell -pip install 'lmdeploy[serve]' -git clone --depth=1 https://github.com/InternLM/lmdeploy -wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -``` - -## 测量指标 - -LMDeploy 统计首token延时(first_token_latency)、token吞吐量(tokens/s)和请求吞吐量(RPM)。 - -`first_token_latency` 只有在流式推理的情况下才会输出。 - -token吞吐量的计算公式为: - -$$ -吞吐量 = 生成的token数量 / 总时间 -$$ - -请求吞吐量的计算公式为: - -$$ -吞吐量 = 请求数量 / 总时间 -$$ - -总时间包括 prefill 时间 - -## 测量方法 - -我们以 [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) 为例,展示 triton inference server 的性能测试流程 - -### 启动服务 - -启动服务之前,必须先把模型转换为 turbomind 模型格式: - -```shell -lmdeploy convert internlm internlm/internlm-7b --dst-path ./internlm-7b --trust-remote-code -``` - -然后,执行如下命令,启动服务: - -```shell -bash ./internlm-7b/service_docker_up.sh -``` - -### 测速 - -```shell -python3 profile_serving.py 0.0.0.0:33337 ./internlm-7b/triton_models/tokenizer ./ShareGPT_V3_unfiltered_cleaned_split.json -``` - -关于 `profile_serving.py` 脚本中的参数,比如请求并发数、采样参数等等,可以通过运行命令 `python3 profile_serving.py -h` 查阅。 diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index a2e3370e4..8577f4834 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -13,7 +13,7 @@ pip install lmdeploy LMDeploy的预编译包默认是基于 CUDA 12 编译的。如果需要在 CUDA 11+ 下安装 LMDeploy,请执行以下命令: ```shell -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 5541082f8..5a6df1fe4 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -58,7 +58,6 @@ LMDeploy 工具箱提供以下核心功能: benchmark/profile_generation.md benchmark/profile_throughput.md benchmark/profile_api_server.md - benchmark/profile_triton_server.md benchmark/evaluate_with_opencompass.md .. _支持的模型: diff --git a/docs/zh_cn/inference/turbomind.md b/docs/zh_cn/inference/turbomind.md index c179eaa1a..78dba19ab 100644 --- a/docs/zh_cn/inference/turbomind.md +++ b/docs/zh_cn/inference/turbomind.md @@ -57,8 +57,6 @@ TurboMind 的 [KV 缓存管理器](https://github.com/InternLM/lmdeploy/blob/mai TurboMind 的 Python API 支持流式结果返回和张量并行模式。 -同时 TurboMind 也继承了 FasterTransformer 能够注册为 [Triton Inference Server](https://github.com/triton-inference-server/server) 推理后端的能力。但是为了支持 persistent batch 中的并发请求,我们不再像 FasterTransformer 那样使用 sequence batching 或者 dynamic batching 。相反,TurboMind 负责记录和管理请求序列的状态。 - ## TurboMind 和 FasterTransformer 的区别 除了上文中提到的功能外,TurboMind 相较于 FasterTransformer 还有不少差别。譬如不少 FasterTransformer 的功能在 TurboMind 中都被去掉了,这其中包括前缀提示词、 beam search 、上下文 embedding、稀疏化 GEMM 操作和对应 GPT 或 T5 等结构的模型的支持等等。 diff --git a/docs/zh_cn/multi_modal/cogvlm.md b/docs/zh_cn/multi_modal/cogvlm.md index fc05cba74..d932a0831 100644 --- a/docs/zh_cn/multi_modal/cogvlm.md +++ b/docs/zh_cn/multi_modal/cogvlm.md @@ -21,7 +21,7 @@ pip install torch==2.2.2 torchvision==0.17.2 xformers==0.0.26 --index-url https: ```shell # cuda 11.8 -export LMDEPLOY_VERSION=0.5.0 +export LMDEPLOY_VERSION=0.5.2 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 # cuda 12.1 diff --git a/docs/zh_cn/multi_modal/xcomposer2d5.md b/docs/zh_cn/multi_modal/xcomposer2d5.md new file mode 100644 index 000000000..551ac5f02 --- /dev/null +++ b/docs/zh_cn/multi_modal/xcomposer2d5.md @@ -0,0 +1,160 @@ +# InternLM-XComposer-2.5 + +## 简介 + +[InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) 是基于书生·浦语2大语言模型研发的突破性的图文多模态大模型,仅使用 7B LLM 后端就达到了 GPT-4V 级别的能力。浦语·灵笔2.5使用24K交错的图像-文本上下文进行训练,通过RoPE外推可以无缝扩展到96K长的上下文。这种长上下文能力使浦语·灵笔2.5在需要广泛输入和输出上下文的任务中表现出色。 LMDeploy 支持了 [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) 模型,通过 TurboMind 引擎推理。 + +## 快速开始 + +### 安装 + +使用 pip(Python 3.8+) 安装 LMDeploy,更多安装方式参考 [安装](https://lmdeploy.readthedocs.io/zh-cn/latest/get_started.html#id2)。 + +```shell +pip install lmdeploy + +# install other packages that InternLM-XComposer-2.5 needs +pip install decord +``` + +### 离线推理 pipeline + +以下是使用pipeline进行离线推理的示例,更多用法参考[VLM离线推理 pipeline](https://lmdeploy.readthedocs.io/zh-cn/latest/inference/vl_pipeline.html#vlm-pipeline) + +```python +from lmdeploy import pipeline +from lmdeploy.vl import load_image +from lmdeploy.vl.constants import IMAGE_TOKEN + +pipe = pipeline('internlm/internlm-xcomposer2d5-7b') + +image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg') +response = pipe((f'describe this image', image)) +print(response) +``` + +## Lora 模型 + +InternLM-XComposer-2.5 针对网页制作和文章创作训练了 LoRA 模型,由于 TurboMind 不支持 slora 特性,所以需要同时只能部署一个 LoRA 模型,需要先对权重进行合并。LMDeploy 提供相关的转换脚本,使用方式为: + +``` +export HF_MODEL=internlm/internlm-xcomposer2d5-7b +export WORK_DIR=internlm/internlm-xcomposer2d5-7b-web +export TASK=web +python -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK +``` + +## 量化 + +下面以 base 模型为例,展示量化的方式,若要使用 LoRA 模型,请先按照上一章节提取 LoRA 模型。 + +```shell + +export HF_MODEL=internlm/internlm-xcomposer2d5-7b +export WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit + +lmdeploy lite auto_awq \ + $HF_MODEL \ + --work-dir $WORK_DIR +``` + +## 更多使用例子 + +
        + + Video Understanding + + +下面以 `pipeline.chat` 为例展示用法,其它接口同样支持推理,需要手动拼接对话内容。 + +```python +from lmdeploy import pipeline, GenerationConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +HF_MODEL = 'internlm/internlm-xcomposer2d5-7b' +load_video = get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL) +frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL) +Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL) +get_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL) + +video = load_video('liuxiang.mp4') # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4 +img = frame2img(video, get_font()) +img = Video_transform(img) + +pipe = pipeline(HF_MODEL) +gen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0) +query = 'Here are some frames of a video. Describe this video in detail' +sess = pipe.chat((query, img), gen_config=gen_config) +print(sess.response.text) + +query = 'tell me the athlete code of Liu Xiang' +sess = pipe.chat(query, session=sess, gen_config=gen_config) +print(sess.response.text) +``` + +
        + +
        + + Multi-Image + + +```python +from lmdeploy import pipeline, GenerationConfig +from lmdeploy.vl.constants import IMAGE_TOKEN +from lmdeploy.vl import load_image + +query = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one' + +urls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg', + 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg', + 'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg'] +images = [load_image(url) for url in urls] + +pipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO') +output = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939)) +``` + +由于 LMDeploy 不支持 beam search,生成的结果与使用 transformers 的 beam search 相比,会有较大的差异,建议关闭 top_k 或者使用较大的 top_k 采样来增加多样性。 + +
        + +
        + + Instruction to Webpage + + +请先使用使用上述说明,转化 web 模型。 + +```python +from lmdeploy import pipeline, GenerationConfig + +pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO') +pipe.chat_template.meta_instruction = None + +query = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.' +output = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048)) +``` + +使用 transformers 测试时,发现如果设置了 repetition_penalty,beam search 为1时有较大概率停不下来,因为 LMDeploy 不支持 beam search,建议使用 LMDeploy 推理时关闭 repetition_penalty。 + +
        + +
        + + Write Article + + +请先使用使用上述说明,转化 write 模型。 + +```python +from lmdeploy import pipeline, GenerationConfig + +pipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO') +pipe.chat_template.meta_instruction = None + +query = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence' +output = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192)) +``` + +
        diff --git a/docs/zh_cn/serving/api_server_tools.md b/docs/zh_cn/serving/api_server_tools.md index 8a4d88f4b..8c44c0e43 100644 --- a/docs/zh_cn/serving/api_server_tools.md +++ b/docs/zh_cn/serving/api_server_tools.md @@ -1,8 +1,10 @@ # Tools +LMDeploy 支持 InternLM2, InternLM2.5 和 Llama3.1 模型的工具调用。 + ## 单轮调用 -目前的 LMDeploy 只支持 InternLM2 或者 InternLM2.5 模型的工具调用。启动好模型的服务后,运行下面 demo 即可。 +启动好模型的服务后,运行下面 demo 即可。 ```python from openai import OpenAI @@ -43,6 +45,8 @@ print(response) ## 多轮调用 +### InternLM + 一个完整的工具链调用过程可以通过下面的例子展示。 ```python @@ -146,3 +150,97 @@ ChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logp ChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='2', function=Function(arguments={'a': 8, 'b': 2}, name='mul'), type='function')]))], created=1719369987, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=282, total_tokens=307)) 16 ``` + +### Llama3.1 + +Meta 在 [Llama3 的官方用户指南](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1)中宣布(注:下文为原文的中文翻译): + +```{text} +有三个内置工具(brave_search、wolfram_alpha 和 code interpreter)可以使用系统提示词打开: + +1. Brave Search:执行网络搜索的工具调用。 +2. Wolfram Alpha:执行复杂数学计算的工具调用。 +3. Code Interpreter:使模型能够输出 Python 代码的功能。 +``` + +此外,它还警告说:“注意: 我们建议使用 Llama 70B-instruct 或 Llama 405B-instruct 用于结合对话和工具调用的应用。Llama 8B-Instruct 无法可靠地在工具调用定义的同时维持对话。它可以用于零样本工具调用,但在模型和用户之间的常规对话中,应移除工具指令。”(注:引号中内容为原文的中文翻译) + +因此,我们使用 [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) 来展示如何通过 LMDeploy的`api_server`调用模型的工具能力. + +在 A100-SXM-80G 节点上,可以按照以下方式启动服务: + +```shell +lmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4 +``` + +有关 api_server 的详细介绍,请参考[此处](./api_server.md)的详细文档。 + +以下代码示例展示了如何使用 "Wolfram Alpha" 工具。假设你已经在[Wolfram Alpha](https://www.wolframalpha.com) 网站上注册并获取了 API 密钥。请确保拥有一个有效的 API 密钥,以便访问 Wolfram Alpha 提供的服务。 + +```python +from openai import OpenAI +import requests + + +def request_llama3_1_service(messages): + client = OpenAI(api_key='YOUR_API_KEY', + base_url='http://0.0.0.0:23333/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False) + return response.choices[0].message.content + + +# The role of "system" MUST be specified, including the required tools +messages = [ + { + "role": "system", + "content": "Environment: ipython\nTools: wolfram_alpha\n\n Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\nYou are a helpful Assistant." # noqa + }, + { + "role": "user", + "content": "Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0" # noqa + } +] + +# send request to the api_server of llama3.1-70b and get the response +# the "assistant_response" is supposed to be: +# <|python_tag|>wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0") +assistant_response = request_llama3_1_service(messages) +print(assistant_response) + +# Call the API of Wolfram Alpha with the query generated by the model +app_id = 'YOUR-Wolfram-Alpha-API-KEY' +params = { + "input": assistant_response, + "appid": app_id, + "format": "plaintext", + "output": "json", +} + +wolframalpha_response = requests.get( + "https://api.wolframalpha.com/v2/query", + params=params +) +wolframalpha_response = wolframalpha_response.json() + +# Append the contents obtained by the model and the wolframalpha's API +# to "messages", and send it again to the api_server +messages += [ + { + "role": "assistant", + "content": assistant_response + }, + { + "role": "ipython", + "content": wolframalpha_response + } +] + +assistant_response = request_llama3_1_service(messages) +print(assistant_response) +``` diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index cca1ce8f3..b8c48b101 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -2,68 +2,76 @@ ## TurboMind 支持的模型 -| 模型 | 模型规模 | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | -| :-----------------: | :----------: | :-------: | :-----: | :-----: | :---: | -| Llama | 7B - 65B | Yes | Yes | Yes | Yes | -| Llama2 | 7B - 70B | Yes | Yes | Yes | Yes | -| Llama3 | 8B, 70B | Yes | Yes | Yes | Yes | -| InternLM | 7B - 20B | Yes | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | Yes | Yes | Yes | Yes | -| InternLM2.5 | 7B | Yes | Yes | Yes | Yes | -| InternLM-XComposer | 7B | Yes | Yes | Yes | Yes | -| InternLM-XComposer2 | 7B, 4khd-7B | Yes | Yes | Yes | Yes | -| QWen | 1.8B - 72B | Yes | Yes | Yes | Yes | -| QWen1.5 | 1.8B - 110B | Yes | Yes | Yes | Yes | -| QWen2 | 1.5B - 72B | Yes | Yes | Yes | Yes | -| Mistral | 7B | Yes | Yes | Yes | No | -| QWen-VL | 7B | Yes | Yes | Yes | Yes | -| DeepSeek-VL | 7B | Yes | Yes | Yes | Yes | -| Baichuan | 7B | Yes | Yes | Yes | Yes | -| Baichuan2 | 7B | Yes | Yes | Yes | Yes | -| Code Llama | 7B - 34B | Yes | Yes | Yes | No | -| YI | 6B - 34B | Yes | Yes | Yes | No | -| LLaVA(1.5,1.6) | 7B - 34B | Yes | Yes | Yes | Yes | -| InternVL-Chat | v1.1- v1.5 | Yes | Yes | Yes | Yes | -| MiniCPM | Llama3-V-2_5 | Yes | Yes | Yes | Yes | -| MiniGeminiLlama | 7B | Yes | No | No | Yes | -| GLM4 | 9B | Yes | Yes | Yes | No | +| 模型 | 模型规模 | FP16/BF16 | KV INT8 | KV INT4 | W4A16 | +| :-------------------: | :----------: | :-------: | :-----: | :-----: | :---: | +| Llama | 7B - 65B | Yes | Yes | Yes | Yes | +| Llama2 | 7B - 70B | Yes | Yes | Yes | Yes | +| Llama3 | 8B, 70B | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B, 70B | Yes | Yes | Yes | Yes | +| InternLM | 7B - 20B | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer | 7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer2 | 7B, 4khd-7B | Yes | Yes | Yes | Yes | +| InternLM-XComposer2.5 | 7B | Yes | Yes | Yes | Yes | +| Qwen | 1.8B - 72B | Yes | Yes | Yes | Yes | +| Qwen1.5 | 1.8B - 110B | Yes | Yes | Yes | Yes | +| Qwen2 | 1.5B - 72B | Yes | Yes | Yes | Yes | +| Mistral | 7B | Yes | Yes | Yes | No | +| Qwen-VL | 7B | Yes | Yes | Yes | Yes | +| DeepSeek-VL | 7B | Yes | Yes | Yes | Yes | +| Baichuan | 7B | Yes | Yes | Yes | Yes | +| Baichuan2 | 7B | Yes | Yes | Yes | Yes | +| Code Llama | 7B - 34B | Yes | Yes | Yes | No | +| YI | 6B - 34B | Yes | Yes | Yes | No | +| LLaVA(1.5,1.6) | 7B - 34B | Yes | Yes | Yes | Yes | +| InternVL-Chat | v1.1- v1.5 | Yes | Yes | Yes | Yes | +| InternVL2 | 2B-76B | Yes | Yes | Yes | Yes | +| MiniCPM | Llama3-V-2_5 | Yes | Yes | Yes | Yes | +| MiniGeminiLlama | 7B | Yes | No | No | Yes | +| GLM4 | 9B | Yes | Yes | Yes | No | +| CodeGeeX4 | 9B | Yes | Yes | Yes | No | “-” 表示还没有验证。 ```{note} -turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、QWen1.5 等,在推理时,请选择 pytorch engine +turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine ``` ### PyTorch 支持的模型 -| 模型 | 模型规模 | FP16/BF16 | KV INT8 | W8A8 | -| :-----------------: | :---------: | :-------: | :-----: | :--: | -| Llama | 7B - 65B | Yes | No | Yes | -| Llama2 | 7B - 70B | Yes | No | Yes | -| Llama3 | 8B, 70B | Yes | No | Yes | -| InternLM | 7B - 20B | Yes | No | Yes | -| InternLM2 | 7B - 20B | Yes | No | - | -| InternLM2.5 | 7B | Yes | No | - | -| Baichuan2 | 7B - 13B | Yes | No | Yes | -| ChatGLM2 | 6B | Yes | No | No | -| Falcon | 7B - 180B | Yes | No | No | -| YI | 6B - 34B | Yes | No | No | -| Mistral | 7B | Yes | No | No | -| Mixtral | 8x7B | Yes | No | No | -| QWen | 1.8B - 72B | Yes | No | No | -| QWen1.5 | 0.5B - 110B | Yes | No | No | -| QWen2 | 0.5B - 72B | Yes | No | No | -| QWen1.5-MoE | A2.7B | Yes | No | No | -| DeepSeek-MoE | 16B | Yes | No | No | -| DeepSeek-V2 | 16B, 236B | Yes | No | No | -| Gemma | 2B-7B | Yes | No | No | -| Dbrx | 132B | Yes | No | No | -| StarCoder2 | 3B-15B | Yes | No | No | -| Phi-3-mini | 3.8B | Yes | No | No | -| Phi-3-vision | 4.2B | Yes | No | No | -| CogVLM-Chat | 17B | Yes | No | No | -| CogVLM2-Chat | 19B | Yes | No | No | -| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No | -| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No | -| Gemma2 | 9B-27B | Yes | No | No | -| GLM4 | 9B | Yes | No | No | +| Model | Size | FP16/BF16 | KV INT8 | W8A8 | W4A16 | +| :-----------------: | :---------: | :-------: | :-----: | :--: | :---: | +| Llama | 7B - 65B | Yes | No | Yes | Yes | +| Llama2 | 7B - 70B | Yes | No | Yes | Yes | +| Llama3 | 8B, 70B | Yes | No | Yes | Yes | +| Llama3.1 | 8B, 70B | Yes | No | No | - | +| InternLM | 7B - 20B | Yes | No | Yes | - | +| InternLM2 | 7B - 20B | Yes | No | Yes | Yes | +| InternLM2.5 | 7B | Yes | No | Yes | Yes | +| Baichuan2 | 7B | Yes | No | Yes | Yes | +| Baichuan2 | 13B | Yes | No | No | No | +| ChatGLM2 | 6B | Yes | No | No | No | +| Falcon | 7B - 180B | Yes | No | No | No | +| YI | 6B - 34B | Yes | No | No | Yes | +| Mistral | 7B | Yes | No | No | No | +| Mixtral | 8x7B | Yes | No | No | No | +| QWen | 1.8B - 72B | Yes | No | No | Yes | +| QWen1.5 | 0.5B - 110B | Yes | No | No | Yes | +| QWen1.5-MoE | A2.7B | Yes | No | No | No | +| QWen2 | 0.5B - 72B | Yes | No | No | Yes | +| DeepSeek-MoE | 16B | Yes | No | No | No | +| DeepSeek-V2 | 16B, 236B | Yes | No | No | No | +| Gemma | 2B-7B | Yes | No | No | No | +| Dbrx | 132B | Yes | No | No | No | +| StarCoder2 | 3B-15B | Yes | No | No | No | +| Phi-3-mini | 3.8B | Yes | No | No | No | +| Phi-3-vision | 4.2B | Yes | No | No | - | +| CogVLM-Chat | 17B | Yes | No | No | - | +| CogVLM2-Chat | 19B | Yes | No | No | - | +| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No | - | +| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No | - | +| InternVL2 | 1B-40B | Yes | No | No | - | +| Gemma2 | 9B-27B | Yes | No | No | - | +| GLM4 | 9B | Yes | No | No | No | +| CodeGeeX4 | 9B | Yes | No | No | - | diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 0c62b7abf..b5e2f4d39 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -200,23 +200,6 @@ def add_parser_api_client(): 'api key will be used') ArgumentHelper.session_id(parser) - @staticmethod - def add_parser_triton_client(): - """Add parser for triton_client command.""" - parser = SubCliServe.subparsers.add_parser( - 'triton_client', - formatter_class=DefaultsAndTypesHelpFormatter, - description=SubCliServe.triton_client.__doc__, - help=SubCliServe.triton_client.__doc__) - parser.set_defaults(run=SubCliServe.triton_client) - parser.add_argument( - 'tritonserver_addr', - type=str, - help='The address in format "ip:port" of triton inference server') - ArgumentHelper.session_id(parser) - ArgumentHelper.cap(parser) - ArgumentHelper.stream_output(parser) - @staticmethod def gradio(args): """Serve LLMs with web UI using gradio.""" @@ -336,16 +319,8 @@ def api_client(args): kwargs = convert_args(args) run_api_client(**kwargs) - @staticmethod - def triton_client(args): - """Interact with Triton Server using gRPC protocol.""" - from lmdeploy.serve.client import main as run_triton_client - kwargs = convert_args(args) - run_triton_client(**kwargs) - @staticmethod def add_parsers(): SubCliServe.add_parser_gradio() SubCliServe.add_parser_api_server() SubCliServe.add_parser_api_client() - SubCliServe.add_parser_triton_client() diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py index 33f621e83..0da7c53dc 100644 --- a/lmdeploy/lite/apis/auto_awq.py +++ b/lmdeploy/lite/apis/auto_awq.py @@ -25,6 +25,7 @@ 'LlavaLlamaForCausalLM': 'LlamaDecoderLayer', 'MGMLlamaForCausalLM': 'LlamaDecoderLayer', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2DecoderLayer', + 'ChatGLMForConditionalGeneration': 'GLMBlock', } NORM_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMRMSNorm', @@ -37,6 +38,7 @@ 'LlavaLlamaForCausalLM': 'LlamaRMSNorm', 'MGMLlamaForCausalLM': 'LlamaRMSNorm', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2RMSNorm', + 'ChatGLMForConditionalGeneration': 'RMSNorm', } @@ -136,7 +138,7 @@ def auto_awq(model: str, w_sym, w_group_size, device, - skip_if_contains='Plora') # TODO quant lora weight + skip_if_contains='lora') # TODO quant lora weight quantization_config = dict(quant_method='awq', version='gemm', bits=w_bits, diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index 692d5cae1..2f2c26cd5 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -23,6 +23,7 @@ 'LlavaLlamaForCausalLM': 'LlamaDecoderLayer', 'MGMLlamaForCausalLM': 'LlamaDecoderLayer', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2DecoderLayer', + 'ChatGLMForConditionalGeneration': 'GLMBlock' } NORM_TYPE_MAP = { @@ -36,6 +37,7 @@ 'LlavaLlamaForCausalLM': 'LlamaRMSNorm', 'MGMLlamaForCausalLM': 'LlamaRMSNorm', # mini gemini 'InternLMXComposer2ForCausalLM': 'InternLM2RMSNorm', + 'ChatGLMForConditionalGeneration': 'RMSNorm' } HEAD_NAME_MAP = { @@ -49,6 +51,7 @@ 'LlavaLlamaForCausalLM': 'lm_head', 'MGMLlamaForCausalLM': 'lm_head', # mini gemini 'InternLMXComposer2ForCausalLM': 'output', + 'ChatGLMForConditionalGeneration': 'output_layer' } @@ -183,6 +186,7 @@ def calibrate(model: str, from lmdeploy.vl.model.builder import vl_model_with_tokenizer vl_model, model, tokenizer = vl_model_with_tokenizer(model_path=model) + model.config.use_cache = False model_type = type(model).__name__ if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP: raise RuntimeError( diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index 7b432d558..f805fc6ea 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -20,6 +20,7 @@ 'QWenLMHeadModel': 'QWenBlock', 'BaiChuanForCausalLM': 'DecoderLayer', 'LlamaForCausalLM': 'LlamaDecoderLayer', + 'ChatGLMForConditionalGeneration': 'GLMBlock', } NORM_TYPE_MAP = { 'InternLMForCausalLM': 'InternLMRMSNorm', @@ -27,6 +28,7 @@ 'QWenLMHeadModel': 'RMSNorm', 'BaiChuanForCausalLM': 'RMSNorm', 'LlamaForCausalLM': 'LlamaRMSNorm', + 'ChatGLMForConditionalGeneration': 'RMSNorm', } LMDEPLOY_ROOT = lmdeploy.__path__[0] diff --git a/lmdeploy/lite/quantization/awq.py b/lmdeploy/lite/quantization/awq.py index 508a26ce1..a37e73c68 100644 --- a/lmdeploy/lite/quantization/awq.py +++ b/lmdeploy/lite/quantization/awq.py @@ -31,6 +31,10 @@ 'DecoderLayer': { 'input_layernorm': ['self_attn.W_pack'], 'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj'] + }, + 'GLMBlock': { + 'input_layernorm': ['self_attention.query_key_value'], + 'post_attention_layernorm': ['mlp.dense_h_to_4h'] } } @@ -57,6 +61,10 @@ 'DecoderLayer': { 'self_attn.W_pack': ['self_attn.o_proj'], 'mlp.up_proj': ['mlp.down_proj'] + }, + 'GLMBlock': { + # 'self_attention.query_key_value': ['self_attention.dense'] + # 'mlp.dense_h_to_4h': ['mlp.dense_4h_to_h'] } } @@ -163,6 +171,7 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module, if getattr(pre_fc, 'bias', None) is not None: pre_fc.bias[-size_a:].div_(scales) else: + pre_fc.weight.div_(scales.view(-1, 1)) if getattr(pre_fc, 'bias', None) is not None: diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py index 08ea4c5f9..4ae21e3f4 100644 --- a/lmdeploy/lite/quantization/calibration.py +++ b/lmdeploy/lite/quantization/calibration.py @@ -3,8 +3,6 @@ from typing import Union import torch -import transformers -from mmengine import digit_version from torch import nn from transformers import PreTrainedTokenizer @@ -161,60 +159,10 @@ def _forward(mod, *args, **kwargs): samples = len(batch_args) m_name = self.mod2name[mod] - k_obs = KVCacheObserver.find(m_name, group=self.key_obs_group) - v_obs = KVCacheObserver.find(m_name, group=self.value_obs_group) for i in range(len(batch_args)): - - if k_obs and v_obs: - batch_kwargs[i]['use_cache'] = True - version = digit_version(transformers.__version__) - use_new_cache = type(mod).__name__ in ('LlamaDecoderLayer', - 'Qwen2DecoderLayer') - if type(mod).__name__ == 'InternLM2DecoderLayer': - use_new_cache = hasattr(mod.attention, 'layer_idx') - if version > digit_version('4.36.0') and use_new_cache: - from transformers.cache_utils import DynamicCache - batch_kwargs[i]['past_key_value'] = DynamicCache() - - if hasattr(mod, 'self_attn'): - self_attn = mod.self_attn - elif hasattr(mod, 'attention'): - self_attn = mod.attention - else: - raise RuntimeError('Attention layer not found') - - ori_idx = self_attn.layer_idx - self_attn.layer_idx = 0 - - out = self._ori_forwards[mod](*batch_args[i], - **batch_kwargs[i]) - self_attn.layer_idx = ori_idx - - out = list(out) - cache = out.pop(-1) - - key = cache.key_cache.pop(-1) - value = cache.value_cache.pop(-1) - - k_obs.observe(key) - v_obs.observe(value) - - else: - out = self._ori_forwards[mod](*batch_args[i], - **batch_kwargs[i]) - out = list(out) - key, value = out.pop(-1) - - k_obs.observe(key) - v_obs.observe(value) - - del key, value - torch.cuda.empty_cache() - batch_outputs.append(tuple(out)) - else: - batch_outputs.append(self._ori_forwards[mod]( - *batch_args[i], **batch_kwargs[i])) + batch_outputs.append(self._ori_forwards[mod]( + *batch_args[i], **batch_kwargs[i])) outputs = concat_decoder_layer_outputs(batch_outputs) @@ -309,14 +257,11 @@ def export(self, out_dir): out_stats = self.collect_outputs_stats() torch.save(out_stats, out_dir / 'outputs_stats.pth') - key_stats, value_stats = self.collect_kv_stats() - torch.save(key_stats, out_dir / 'key_stats.pth') - torch.save(value_stats, out_dir / 'value_stats.pth') - def calibrate(self, data): """Forward pass through the model in inference mode with given data.""" - if type(self.model).__name__ == 'QWenLMHeadModel': + if type(self.model).__name__ in ('QWenLMHeadModel', + 'ChatGLMForConditionalGeneration'): model = self.model.transformer else: model = self.model.model @@ -520,7 +465,7 @@ def _forward(mod, *args, **kwargs): self.w_group_size, obs_group, mod_name) ActivationObserver.enable() for key, item in obs_group.items(): - if key.startswith(f'{mod_name}.'): + if key.startswith(f'{mod_name}.') and item.value is not None: item.value.cpu() del item.value diff --git a/lmdeploy/lite/utils/batch_split.py b/lmdeploy/lite/utils/batch_split.py index 0aba1f929..3bd208f60 100644 --- a/lmdeploy/lite/utils/batch_split.py +++ b/lmdeploy/lite/utils/batch_split.py @@ -96,6 +96,8 @@ def is_past_key_value(data: Any) -> bool: key = torch.cat([out[i][0] for out in batch_outputs]) value = torch.cat([out[i][1] for out in batch_outputs]) out_i = (key, value) + elif batch_outputs[0][i] is None: # glm4 + out_i = None else: # If it's not a past key-value pair, concatenate directly. out_i = torch.cat([out[i] for out in batch_outputs]) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index ec79e7906..a4054449d 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -14,7 +14,7 @@ class GenerationConfig: Args: n (int): Define how many chat completion choices to generate for each - input message + input message. **Only 1** is supported now. max_new_tokens (int): The maximum number of tokens that can be generated in the chat completion top_p (float): An alternative to sampling with temperature, called diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 6755a03c6..d2e0271dc 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -104,12 +104,7 @@ def from_json(cls, file_or_string): class BaseModel: """Base model.""" - def __init__(self, - session_len=2048, - capability='chat', - stop_words=None, - **kwargs): - self.session_len = session_len + def __init__(self, capability='chat', stop_words=None, **kwargs): self.stop_words = stop_words self.capability = capability @@ -361,8 +356,8 @@ def match(cls, model_path: str) -> Optional[str]: class MiniGemini(Vicuna): """Chat template of vicuna model.""" - def __init__(self, session_len=4096, **kwargs): - super().__init__(session_len=session_len, **kwargs) + def __init__(self, **kwargs): + super().__init__(**kwargs) def get_prompt(self, prompt, sequence_start=True): return super().get_prompt(prompt, sequence_start)[:-1] @@ -384,8 +379,6 @@ def match(cls, model_path: str) -> Optional[str]: return 'mini-gemini-vicuna' -@MODELS.register_module(name='internlm-chat') -@MODELS.register_module(name='internlm-chat-7b') @MODELS.register_module(name='internlm') class InternLMChat7B(BaseChatTemplate): """Chat template of InternLM model.""" @@ -429,48 +422,11 @@ def match(cls, model_path: str) -> Optional[str]: return 'internlm' -@MODELS.register_module(name='internlm-chat-20b') -@MODELS.register_module(name='internlm-chat-7b-8k') -class InternLMChat7B8K(InternLMChat7B): - """Chat template and generation parameters of InternLM-Chat-7B-8K and - InternLM-Chat-20B models.""" - - def __init__(self, session_len=8192, **kwargs): - super(InternLMChat7B8K, self).__init__(**kwargs) - self.session_len = session_len - - -@MODELS.register_module(name='internlm-20b') -class InternLMBaseModel20B(BaseChatTemplate): - """Generation parameters of InternLM-20B-Base model.""" - - def __init__(self, session_len=4096, capability='completion', **kwargs): - super().__init__(session_len=session_len, - capability=capability, - **kwargs) - - -@MODELS.register_module( - name=['internlm2-1_8b', 'internlm2-7b', 'internlm2-20b']) -class InternLM2BaseModel7B(BaseChatTemplate): - """Generation parameters of InternLM2-7B-Base model.""" - - def __init__(self, session_len=32768, capability='completion', **kwargs): - super().__init__(session_len=session_len, - capability=capability, - **kwargs) - - -@MODELS.register_module(name=[ - 'internlm2-chat', 'internlm2-chat-1_8b', 'internlm2-chat-7b', - 'internlm2-chat-20b' -]) @MODELS.register_module(name='internlm2') class InternLM2Chat7B(InternLMChat7B): """Chat template and generation parameters of InternLM2-Chat-7B.""" def __init__(self, - session_len=32768, system='<|im_start|>system\n', user='<|im_start|>user\n', assistant='<|im_start|>assistant\n', @@ -488,8 +444,7 @@ def __init__(self, self.interpreter = interpreter self.environment = environment self.eoenv = eoenv - super(InternLM2Chat7B, self).__init__(session_len=session_len, - system=system, + super(InternLM2Chat7B, self).__init__(system=system, user=user, assistant=assistant, eosys=eosys, @@ -607,13 +562,12 @@ def match(cls, model_path: str) -> Optional[str]: return 'internvl2-internlm2' -@MODELS.register_module(name='internlm-xcomposer2') +@MODELS.register_module(name=['internlm-xcomposer2', 'internlm-xcomposer2d5']) class InternLMXComposer2Chat7B(InternLMChat7B): """Chat template and generation parameters of InternLM-XComposer2-7b.""" def __init__( self, - session_len=4096, system='[UNUSED_TOKEN_146]system\n', meta_instruction="""You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔). - InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless. @@ -627,8 +581,7 @@ def __init__( separator='\n', stop_words=['[UNUSED_TOKEN_145]'], **kwargs): - super().__init__(session_len=session_len, - system=system, + super().__init__(system=system, meta_instruction=meta_instruction, user=user, assistant=assistant, @@ -647,42 +600,14 @@ def match(cls, model_path: str) -> Optional[str]: model_path (str): the model path used for matching. """ path = model_path.lower() - if 'internlm' in path and 'xcomposer2' in path and '4khd' not in path: + if 'internlm' in path and 'xcomposer2' in path: + if '2d5' in path: + return 'internlm-xcomposer2d5' return 'internlm-xcomposer2' -@MODELS.register_module(name='internlm-xcomposer2-4khd') -class InternLMXComposer24khdChat7B(InternLMXComposer2Chat7B): - """Chat template and generation parameters of InternLM- - XComposer2-4khd-7b.""" - - def __init__(self, session_len=16384, **kwargs): - super().__init__(session_len=session_len, **kwargs) - - @classmethod - def match(cls, model_path: str) -> Optional[str]: - """Return the model_name that was registered to MODELS. - - Args: - model_path (str): the model path used for matching. - """ - path = model_path.lower() - if 'internlm' in path and 'xcomposer2' in path and '4khd' in path: - return 'internlm-xcomposer2-4khd' - - -@MODELS.register_module(name='baichuan-7b') -@MODELS.register_module(name='baichuan-base') -class Baichuan7B(BaseChatTemplate): - """Generation parameters of Baichuan-7B base model.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - -@MODELS.register_module(name='baichuan2-7b') @MODELS.register_module(name='baichuan2') -class Baichuan2_7B(BaseChatTemplate): +class Baichuan2(BaseChatTemplate): """Chat template and generation parameters of Baichuan2-7B-Base and Baichuan2-7B-Chat models.""" @@ -740,7 +665,7 @@ def match(cls, model_path: str) -> Optional[str]: return 'puyu' -@MODELS.register_module(name=['llama2', 'llama-2', 'llama-2-chat']) +@MODELS.register_module(name='llama2') class Llama2(BaseChatTemplate): """Chat template of LLaMA2 model.""" @@ -790,7 +715,6 @@ def __init__(self, user='<|start_header_id|>user<|end_header_id|>\n\n', eoh='<|eot_id|>', stop_words=['<|eot_id|>', '<|end_of_text|>'], - session_len=8192, **kwargs): super().__init__(system=system, meta_instruction=meta_instruction, @@ -800,7 +724,6 @@ def __init__(self, user=user, eoh=eoh, stop_words=stop_words, - session_len=session_len, **kwargs) def get_prompt(self, prompt, sequence_start=True): @@ -826,14 +749,123 @@ def match(cls, model_path: str) -> Optional[str]: return 'llama3' -@MODELS.register_module(name='qwen-14b') -@MODELS.register_module(name='qwen-7b') +@MODELS.register_module(name='llama3_1') +class Llama3_1(Llama3): + """Chat template of LLaMA3.1 model.""" + + def __init__( + self, + tools="""# Tool Instructions +- Always execute python code in messages that you share. +- When looking for real time information use relevant functions if available else fallback to brave_search + + + +You have access to the following functions: + +""", # noqa + eotools=""" + +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line" +- Always add your sources when using search results to answer the user query\n\n""", # noqa + knowledge='Cutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n', + meta_instruction='You are a helpful assistant.', + ipython='<|start_header_id|>ipython<|end_header_id|>\n\n', + eoi='<|eot_id|>', + stop_words=['<|eot_id|>', '<|end_of_text|>', '<|eom_id|>'], + **kwargs): + super().__init__(meta_instruction=meta_instruction, + stop_words=stop_words, + **kwargs) + self.ipython = ipython + self.eoi = eoi + self.tools = tools + self.eotools = eotools + self.knowledge = knowledge + + def messages2prompt(self, + messages, + sequence_start=True, + tools=None, + **kwargs): + """Return the prompt that is concatenated with other elements in the + chat template. + + Args: + messages (str | List): user's input prompt + Returns: + str: the concatenated prompt + """ + if isinstance(messages, str): + return self.get_prompt(messages, sequence_start) + box_map = dict(user=self.user, + ipython=self.ipython, + assistant=self.assistant, + system=self.system) + eox_map = dict(user=self.eoh, + ipython=self.eoi, + assistant=self.eoa + self.separator, + system=self.eosys) + ret = '' + tool_prompt = '' + if tools is not None: + for tool in tools: + tool_prompt += "Use the function '{}' to: {}\n{}\n".format( + tool['name'], tool['description'], + json.dumps(tool, ensure_ascii=False)) + if self.meta_instruction is not None and sequence_start: + if len(messages) and messages[0]['role'] != 'system': + if tools is None: + ret += f'{self.system}{self.knowledge}{self.meta_instruction}{self.eosys}' + else: + ret += f'{self.system}{self.knowledge}{self.tools}{tool_prompt}{self.eotools}{self.meta_instruction}{self.eosys}' + for message in messages: + role = message['role'] + content = message['content'] + if role == 'assistant' and ('<|python_tag|>' in content + or '' in content): + ret += f'{box_map[role]}{content}<|eom_id|>' + elif role == 'system' and tools is not None: + ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}' + else: + ret += f'{box_map[role]}{content}{eox_map[role]}' + ret += f'{self.assistant}' + if sequence_start and not isinstance(messages, str): + ret = '<|begin_of_text|>' + ret + return ret + + @classmethod + def match(cls, model_path: str) -> Optional[str]: + """Return the model_name that was registered to MODELS. + + Args: + model_path (str): the model path used for matching. + """ + if 'llama-3.1-' in model_path.lower( + ) or 'llama3.1-' in model_path.lower(): + return 'llama3_1' + + @MODELS.register_module(name='qwen') class Qwen7BChat(BaseChatTemplate): """Chat template for Qwen-7B-Chat.""" def __init__(self, - session_len=8192, system='<|im_start|>system\n', meta_instruction='You are a helpful assistant.', eosys='<|im_end|>\n', @@ -853,7 +885,6 @@ def __init__(self, eoa=eoa, separator=separator, stop_words=stop_words, - session_len=session_len, **kwargs) @classmethod @@ -872,12 +903,10 @@ class CodeLlama(Llama2): def __init__(self, meta_instruction='', - session_len=4096, suffix_first=False, stop_words=None, **kwargs): super().__init__(meta_instruction=meta_instruction, - session_len=session_len, stop_words=stop_words, **kwargs) caps = ['completion', 'infilling', 'chat', 'python'] @@ -885,7 +914,6 @@ def __init__(self, f'{self.capability} is not supported. ' \ f'The supported capabilities are: {caps}' self.meta_instruction = meta_instruction - self.session_len = session_len self.suffix_first = suffix_first self.stop_words = stop_words if self.capability == 'infilling': @@ -938,7 +966,6 @@ def match(cls, model_path: str) -> Optional[str]: return 'falcon' -@MODELS.register_module(name='chatglm2-6b') @MODELS.register_module(name='chatglm') class ChatGLM2(BaseModel): @@ -996,7 +1023,7 @@ def match(cls, model_path: str) -> Optional[str]: return 'chatglm' -@MODELS.register_module(name=['solar', 'solar-70b']) +@MODELS.register_module(name='solar') class SOLAR(BaseChatTemplate): """Chat template of SOLAR model. @@ -1010,7 +1037,6 @@ def __init__(self, eoh='\n\n', assistant='### Assistant:\n', meta_instruction='', - session_len=2048, **kwargs): super().__init__(**kwargs) self.system = system @@ -1019,7 +1045,6 @@ def __init__(self, self.eoh = eoh self.assistant = assistant self.meta_instruction = meta_instruction - self.session_len = session_len @classmethod def match(cls, model_path: str) -> Optional[str]: @@ -1032,8 +1057,7 @@ def match(cls, model_path: str) -> Optional[str]: return 'solar' -@MODELS.register_module(name='ultracm') -@MODELS.register_module(name='ultralm') +@MODELS.register_module(name=['ultracm', 'ultralm']) class UltraChat(BaseChatTemplate): """Template of UltraCM and UltraLM models. @@ -1052,7 +1076,6 @@ def __init__( eoa='', separator='\n', stop_words=[''], - session_len=2048, **kwargs): super().__init__(system=system, meta_instruction=meta_instruction, @@ -1063,7 +1086,6 @@ def __init__( eoa=eoa, separator=separator, stop_words=stop_words, - session_len=session_len, **kwargs) @classmethod @@ -1079,7 +1101,7 @@ def match(cls, model_path: str) -> Optional[str]: return 'ultralm' -@MODELS.register_module(name=['yi', 'yi-chat', 'yi-200k', 'yi-34b']) +@MODELS.register_module(name=['yi']) class Yi(BaseChatTemplate): """Chat template of Yi model.""" @@ -1118,7 +1140,6 @@ def match(cls, model_path: str) -> Optional[str]: @MODELS.register_module(name=['mistral', 'mixtral']) -@MODELS.register_module(name=['Mistral-7B-Instruct', 'Mixtral-8x7B-Instruct']) class MistralChat(BaseChatTemplate): """Template of Mistral and Mixtral Instruct models. @@ -1126,17 +1147,8 @@ class MistralChat(BaseChatTemplate): `https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1` """ - def __init__(self, - user='[INST] ', - eoh=' [/INST]', - eoa='', - session_len=2048, - **kwargs): - super().__init__(user=user, - eoh=eoh, - eoa=eoa, - session_len=session_len, - **kwargs) + def __init__(self, user='[INST] ', eoh=' [/INST]', eoa='', **kwargs): + super().__init__(user=user, eoh=eoh, eoa=eoa, **kwargs) @classmethod def match(cls, model_path: str) -> Optional[str]: @@ -1185,7 +1197,6 @@ def match(cls, model_path: str) -> Optional[str]: return 'gemma' -@MODELS.register_module(name=['deepseek-chat']) @MODELS.register_module(name=['deepseek']) class Deepseek(BaseChatTemplate): @@ -1227,13 +1238,11 @@ def __init__(self, eoh=' ', assistant=': ', eoa='', - session_len=4096, **kwargs): super().__init__(user=user, eoh=eoh, assistant=assistant, eoa=eoa, - session_len=session_len, **kwargs) def get_prompt(self, prompt, sequence_start=True): @@ -1265,7 +1274,6 @@ def __init__( eoh='\n\n', assistant='Assistant: ', eoa='<|end▁of▁sentence|>', - session_len=16384, **kwargs): super().__init__(meta_instruction=meta_instruction, eosys=eosys, @@ -1273,7 +1281,6 @@ def __init__( eoh=eoh, assistant=assistant, eoa=eoa, - session_len=session_len, **kwargs) def get_prompt(self, prompt, sequence_start=True): @@ -1300,7 +1307,6 @@ class DeepSeek(BaseChatTemplate): def __init__( self, - session_len=4096, system='', meta_instruction="""You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n""", # noqa: E501 eosys='', @@ -1311,8 +1317,7 @@ def __init__( separator='\n', stop_words=['<|EOT|>'], **kwargs): - super().__init__(session_len=session_len, - system=system, + super().__init__(system=system, meta_instruction=meta_instruction, eosys=eosys, user=user, @@ -1424,8 +1429,7 @@ def match(cls, model_path: str) -> Optional[str]: return 'dbrx' -@MODELS.register_module(name=['internvl-zh-hermes2']) -@MODELS.register_module(name=['llava-chatml']) +@MODELS.register_module(name=['llava-chatml', 'internvl-zh-hermes2']) class ChatmlDirect(BaseChatTemplate): def __init__(self, @@ -1437,7 +1441,6 @@ def __init__(self, assistant='<|im_start|>assistant\n', eoa='<|im_end|>', separator='', - session_len=4096, **kwargs): super().__init__(system, meta_instruction=meta_instruction, @@ -1447,7 +1450,6 @@ def __init__(self, assistant=assistant, eoa=eoa, separator=separator, - session_len=session_len, **kwargs) @classmethod @@ -1600,10 +1602,12 @@ def __init__(self, system='<|system|>\n', user='<|user|>\n', assistant='<|assistant|>\n', + stop_words=['<|user|>', '<|endoftext|>', '<|observation|>'], **kwargs): super().__init__(system=system, user=user, assistant=assistant, + stop_words=stop_words, **kwargs) self.start = '[gMASK]' @@ -1619,6 +1623,45 @@ def match(cls, model_path: str) -> Optional[str]: return 'glm4' +@MODELS.register_module(name='codegeex4') +class CodeGeeX4Chat(BaseChatTemplate): + """Chat template of THUDM/codegeex4-all-9b model.""" + + def __init__( + self, + system='<|system|>\n', + meta_instruction='你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。', + eosys='', + user='<|user|>\n', + eoh='', + assistant='<|assistant|>\n', + eoa='', + separator='', + stop_words=['<|endoftext|>', '<|user|>', '<|observation|>'], + **kwargs): + super().__init__(system=system, + meta_instruction=meta_instruction, + eosys=eosys, + user=user, + eoh=eoh, + assistant=assistant, + eoa=eoa, + separator=separator, + stop_words=stop_words, + **kwargs) + + @classmethod + def match(cls, model_path: str) -> Optional[str]: + """Return the model_name that was registered to MODELS. + + Args: + model_path (str): the model path used for matching. + """ + path = model_path.lower() + if 'codegeex4' in path: + return 'codegeex4' + + @MODELS.register_module(name='internvl-phi3') class InternVLPhi3(Phi3Instruct): """Chat template of InternVL Chat 4B model.""" diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index 367622a6e..f46dce6e9 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -20,7 +20,9 @@ def wq_gemm_forward( ): """wq gemm forward.""" from awq.modules.linear.gemm import awq_ext - from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_dequantize_weights + + from lmdeploy.pytorch.kernels.cuda.awq_kernels import \ + awq_dequantize_weights out_shape = x.shape[:-1] + (out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index c59d09a8a..8da9a83cb 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -215,7 +215,8 @@ def check_enable_graph(self): # TODO: should we enable cudagraph awq now? # hf_config = self.model_config.hf_config - # quantization_config = getattr(hf_config, 'quantization_config', dict()) + # quantization_config = getattr(hf_config, + # 'quantization_config', dict()) # quant_method = quantization_config.get('quant_method', None) # if quant_method == 'awq': # logger.warning('AWQ model does not support cuda graph yet.') diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py index b1a0ba14d..d313d6cbd 100644 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -213,7 +213,6 @@ def __gather_xa(xa): lora_out = lora_out.reshape(sliced_base.shape) sliced_base.add_(lora_out) output = base_output - return output def forward(self, diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 14a471590..67796a2e0 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -16,6 +16,10 @@ def __call__(self, **kwargs): """call graph runner forward.""" return self.model(**kwargs) + def get_model(self): + """get model.""" + return self.model + def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py index eae278b6f..3b3f85b8d 100644 --- a/lmdeploy/pytorch/chat.py +++ b/lmdeploy/pytorch/chat.py @@ -4,9 +4,11 @@ import random from typing import List, Optional +from lmdeploy.archs import get_model_arch from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model from lmdeploy.tokenizer import DetokenizeState, Tokenizer +from lmdeploy.utils import _get_and_verify_max_len os.environ['TM_LOG_LEVEL'] = 'ERROR' @@ -93,6 +95,9 @@ def run_chat(model_path: str, model = MODELS.get(model_name)() stop_words = _stop_words(model.stop_words, tokenizer) + _, model_config = get_model_arch(model_path) + session_len = _get_and_verify_max_len(model_config, None) + while True: prompt = input_prompt(model_name) if prompt == 'exit': @@ -105,9 +110,6 @@ def run_chat(model_path: str, else: prompt = model.get_prompt(prompt, nth_round == 1) input_ids = tokenizer.encode(prompt, nth_round == 1) - session_len = model.session_len - if session_len is None: - session_len = tm_model.session_len if step >= session_len: print('WARNING: exceed session max length.' ' Please end the session.') @@ -143,7 +145,6 @@ def main(model_path: str, temperature: float = 0.8, repetition_penalty: float = 1.0, tp: int = 1, - stream_output: bool = True, adapter: str = None, trust_remote_code: bool = True, chat_template: str = None): @@ -159,7 +160,6 @@ def main(model_path: str, temperature (float): sampling temperature. repetition_penalty (float): parameter to penalize repetition tp (int): GPU number used in tensor parallelism - stream_output (bool): indicator for streaming output or not adapter (str): path to lora adapter. trust_remote_code (bool): Trust remote code. chat_template (str): A JSON file or string that specifies the diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 1c49fecdf..7cfe82b16 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -87,6 +87,27 @@ def check_env(device_type: str): MAX_TRANSFORMERS_VERSION = '4.41.2' +def check_awq(hf_config): + """check awq support.""" + logger = get_logger('lmdeploy') + quantization_config = getattr(hf_config, 'quantization_config', dict()) + quant_method = quantization_config.get('quant_method', None) + if quant_method != 'awq': + return + try: + import awq # noqa + except Exception as e: + _handle_exception(e, 'autoawq', logger) + + try: + import awq_ext # noqa + except Exception: + logger.debug('Exception:', exc_info=1) + logger.warning('Failed to import `awq_ext`. ' + 'Try reinstall it from source: ' + 'https://github.com/casper-hansen/AutoAWQ_kernels') + + def check_transformers_version(model_path: str, trust_remote_code: bool = True): """check transformers version.""" @@ -172,6 +193,7 @@ def __check_model_dtype_support(config): config = __check_config(trans_version) __check_model_transformers_version(config, trans_version) __check_model_dtype_support(config) + check_awq(config) def check_model(model_path: str, trust_remote_code: bool = True): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index de437b249..d0d7ea368 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -16,6 +16,15 @@ def _update_torch_dtype(config: 'ModelConfig', default: str = 'float16'): from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') + quantization_config = getattr(config.hf_config, 'quantization_config', + dict()) + quant_method = quantization_config.get('quant_method', None) + if quant_method == 'awq': + logger.debug('set torch_dtype to float16 for awq.') + config.hf_config.torch_dtype = 'float16' + config.dtype = torch.float16 + return config + torch_dtype = getattr(config.hf_config, 'torch_dtype', None) if torch_dtype is None: logger.warning('Model config does not have `torch_dtype`,' diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index bd95dc4b1..cc9d29e93 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -157,6 +157,20 @@ def model_forward( return dict(logits=output) +def _get_indexed_lora_linears(model): + """get indexed lora linears.""" + if hasattr(model, 'get_model'): + model = model.get_model() + return get_indexed_lora_linears(model) + + +def _get_loralinear_info(model): + """get lora linear info.""" + if hasattr(model, 'get_model'): + model = model.get_model() + return get_loralinear_info(model) + + def _load_adapters(hf_model: torch.nn.Module, adapters: Dict[str, str], device_map: str = 'cpu'): @@ -357,7 +371,7 @@ def _build_model(self, def get_loralinear_info(self): """get lora linear info.""" - return get_loralinear_info(self.patched_model) + return _get_loralinear_info(self.patched_model) def get_block_numel(self): """get block nelement.""" @@ -367,7 +381,7 @@ def get_block_numel(self): def paging_adapters(self, weight_maps: List[AdapterWeightMap]): """paging adapter.""" logger.info('paging adapters.') - lora_linears = get_indexed_lora_linears(self.patched_model) + lora_linears = _get_indexed_lora_linears(self.patched_model) cpu_caches = self.cache_engine.cpu_cache num_blocks = self.cache_engine.num_cpu_blocks cpu_caches = [(kcache.view(num_blocks, @@ -493,13 +507,12 @@ def _broadcast_config(cache_config): return config_list[0] try: - config = model_config.hf_config torch_dtype = model_config.dtype device_map = None with init_empty_weights(), warnings.catch_warnings(): warnings.simplefilter('ignore') - model = model_config.auto_model_cls.from_config( - config, + model = model_config.auto_model_cls.from_pretrained( + model_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, **model_config.init_kwargs) @@ -507,10 +520,10 @@ def _broadcast_config(cache_config): _remove_unused_modules(model, model_config) if rank == 0: device_map = _create_device_map(model, world_size) - _add_adapters(model, adapters) - if rank == 0: - # adapter would remove weight of linear. - device_map = _create_device_map(model, world_size, device_map) + _add_adapters(model, adapters) + if rank == 0: + # adapter would remove weight of linear. + device_map = _create_device_map(model, world_size, device_map) model.eval() model.config.use_cache = True @@ -582,7 +595,7 @@ def __get_weight_map(): def __paging(weight_maps): """paging.""" - lora_linears = get_indexed_lora_linears(patched_model) + lora_linears = _get_indexed_lora_linears(patched_model) cpu_caches = cache_engine.cpu_cache num_blocks = cache_engine.num_cpu_blocks cpu_caches = [(kcache.view(num_blocks, @@ -862,7 +875,7 @@ def _build_model( def get_loralinear_info(self): """get lora linear info.""" - return get_loralinear_info(self.patched_model) + return _get_loralinear_info(self.patched_model) def get_block_numel(self): """get block nelement.""" diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py index 5ecd93746..9475e93c6 100644 --- a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py @@ -1,7 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. import triton import triton.language as tl from torch import Tensor -import torch + from .triton_utils import get_kernel_meta, wrap_jit_func @@ -12,6 +13,7 @@ def _get_unpacked_order(offs_n, elem_per_int): unpacked_order = (origin_order % 2) * 4 + origin_order // 2 return unpacked_order + @triton.jit def _unpack_weight(weight, order): """unpack weight.""" @@ -23,7 +25,10 @@ def _unpack_weight(weight, order): @wrap_jit_func @triton.jit def _dequantize_weights_kernel( - QWeight, Scales, Qzeros, Out, + QWeight, + Scales, + Qzeros, + Out, in_size: tl.constexpr, out_size: tl.constexpr, group_size: tl.constexpr, @@ -38,7 +43,7 @@ def _dequantize_weights_kernel( stride_on: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr, - ): +): """dequantize weight kernel.""" k_block_id = tl.program_id(0) n_block_id = tl.program_id(1) @@ -56,10 +61,13 @@ def _dequantize_weights_kernel( mask = mask_k[:, None] and mask_n[None, :] unpacked_order = _get_unpacked_order(offs_n, elem_per_int) - qw_ptr = QWeight + offs_wk[:, None] * stride_wk + offs_wn[None, :] * stride_wn - s_ptr = Scales + offs_sk[:, None] * stride_sk + offs_sn[None, :] * stride_sn - qz_ptr = Qzeros + offs_zk[:, None] * stride_zk + offs_zn[None, :] * stride_zn - + qw_ptr = QWeight + offs_wk[:, + None] * stride_wk + offs_wn[None, :] * stride_wn + s_ptr = Scales + offs_sk[:, + None] * stride_sk + offs_sn[None, :] * stride_sn + qz_ptr = Qzeros + offs_zk[:, + None] * stride_zk + offs_zn[None, :] * stride_zn + out_dtype = Out.dtype.element_ty qw = tl.load(qw_ptr, mask=mask) s = tl.load(s_ptr, mask=mask).to(out_dtype) @@ -88,8 +96,14 @@ def awq_dequantize_weights(qweight: Tensor, scales: Tensor, qzeros: Tensor): kernel_meta = get_kernel_meta(qweight) grid = (triton.cdiv(in_size, BLOCK_K), triton.cdiv(out_size, BLOCK_N)) _dequantize_weights_kernel[grid]( - qweight, scales, qzeros, output, - in_size, out_size, group_size, elem_per_int, + qweight, + scales, + qzeros, + output, + in_size, + out_size, + group_size, + elem_per_int, qweight.stride(0), qweight.stride(1), scales.stride(0), diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 5adecdeb4..722eae6fc 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -91,6 +91,7 @@ def _fwd_split_kernel( head_size: tl.constexpr, head_size_v: tl.constexpr, shared_kv: tl.constexpr, + logit_softcapping: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, @@ -172,6 +173,10 @@ def _fwd_split_kernel( qk = tl.sum(q[None, :] * k, 1) qk *= sm_scale + if logit_softcapping > 0.0: + qk = qk / logit_softcapping + qk = tl.math.tanh(qk) + qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error qk_mask = history_len >= (start_n + offs_n) if window_size > 0: @@ -280,6 +285,7 @@ def _fwd_grouped_split_kernel( head_size_v: tl.constexpr, num_heads_q: tl.constexpr, shared_kv: tl.constexpr, + logit_softcapping: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, @@ -386,6 +392,10 @@ def _fwd_grouped_split_kernel( if BLOCK_DMODEL1 != 0: qk += tl.dot(q1, k1) qk *= sm_scale + if logit_softcapping > 0.0: + qk = qk / logit_softcapping + qk = tl.math.tanh(qk) + qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error qk_mask = history_len >= (start_n + offs_n) if window_size > 0: @@ -545,6 +555,7 @@ def _fwd_kernel( head_size: tl.constexpr, head_size_v: tl.constexpr, shared_kv: tl.constexpr, + logit_softcapping: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, @@ -645,6 +656,10 @@ def _fwd_kernel( if BLOCK_DMODEL1 != 0: qk += tl.dot(q1, k1) qk *= sm_scale + if logit_softcapping > 0.0: + qk = qk / logit_softcapping + qk = tl.math.tanh(qk) + qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error qk_mask = (history_len + offs_m[:, None]) >= (start_n + offs_n[None, :]) @@ -695,6 +710,7 @@ def paged_attention_fwd( max_seqlen: int, window_size: int = None, sm_scale: float = None, + logit_softcapping: float = None, shared_kv: bool = False, ): """Paged Attention forward. @@ -719,6 +735,9 @@ def paged_attention_fwd( if window_size is None: window_size = -1 + if logit_softcapping is None: + logit_softcapping = -1.0 + def _get_block_d(Lk): """get block d.""" BLOCK_DMODEL = triton.next_power_of_2(Lk) @@ -783,6 +802,7 @@ def _get_block_d(Lk): head_size=Lk, head_size_v=Lv, shared_kv=shared_kv, + logit_softcapping=logit_softcapping, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DV=BLOCK_DV, @@ -830,6 +850,7 @@ def _get_block_d(Lk): head_size=Lk, head_size_v=Lv, shared_kv=shared_kv, + logit_softcapping=logit_softcapping, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, @@ -869,6 +890,7 @@ def _get_block_d(Lk): head_size_v=Lv, num_heads_q=head, shared_kv=shared_kv, + logit_softcapping=logit_softcapping, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index 64006580e..c7a35c4a4 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -157,15 +157,15 @@ def __init__(self, def forward(self, x): """forward of linear layer.""" + is_tp = False if self.colwise else self.is_tp if self.lora_adapters is None: - is_tp = False if self.colwise else self.is_tp return self.impl.forward(x, is_tp) out = self.impl.forward(x, False) if self.lora_adapters is not None: for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) - if self.is_tp: + if is_tp: dist.all_reduce(out) return out diff --git a/lmdeploy/pytorch/models/awq_modules.py b/lmdeploy/pytorch/models/awq_modules.py new file mode 100644 index 000000000..19f800bc9 --- /dev/null +++ b/lmdeploy/pytorch/models/awq_modules.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + + +def wq_gemm_forward( + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, +): + """wq gemm forward.""" + from awq.modules.linear.gemm import AWQ_INSTALLED, dequantize_gemm + out_shape = x.shape[:-1] + (out_features, ) + input_dtype = x.dtype + if input_dtype != torch.float16: + x = x.half() + + if AWQ_INSTALLED: + from awq.modules.linear.gemm import awq_ext + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + # TODO: remove event wait if awq kernel set stream + default_stream = torch.cuda.default_stream() + event_def = torch.cuda.Event() + event_def.record() + event_def.wait(default_stream) + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, + 0, 0, False) + event_def = torch.cuda.Event() + event_def.record(default_stream) + event_def.wait() + out = torch.matmul(x, out) + else: + x = x.flatten(0, -2) + if not x.is_contiguous(): + x = x.contiguous() + out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8) + else: + out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) + out = torch.matmul(x, out) + + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + + # always want 3D tensor if tensor is 2D + if len(out.shape) == 2: + out = out.unsqueeze(0) + + if input_dtype != torch.float16: + out = out.to(dtype=input_dtype) + return out + + +class PatchedWQLinear_GEMM(nn.Module): + + def forward(self, x): + """forward.""" + out_features = self.scales.size(1) + return wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, + self.w_bit, self.group_size, self.bias, + out_features) diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index 9bf78b87a..c7bbbf1af 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -124,6 +124,8 @@ def __rotary_emb_fn(query_states, key_states, value_states): block_offsets=block_offsets, ) + logit_softcapping = getattr(self.config, 'attn_logit_softcapping', + None) window_size = getattr(self, 'sliding_window', None) sm_scale = getattr(self, 'scaling', None) attn_output = query_states @@ -139,6 +141,7 @@ def __rotary_emb_fn(query_states, key_states, value_states): max_seqlen=max_q_seq_length, window_size=window_size, sm_scale=sm_scale, + logit_softcapping=logit_softcapping, ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index bc9d32541..91ea7071e 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -217,6 +217,8 @@ # peft MODULE_MAP.update({ 'peft.tuners.lora.layer.Linear': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear', + 'peft.tuners.lora.awq.AwqLoraLinear': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear' }) @@ -338,6 +340,12 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.PatchedInternVLChatModel' }) +# awq +MODULE_MAP.update({ + 'awq.modules.linear.gemm.WQLinear_GEMM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.awq_modules.PatchedWQLinear_GEMM' +}) + # phi3 vision MODULE_MAP.update({ 'modeling_phi3_v.Phi3Attention': diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py index cf283ed2d..ed414c179 100644 --- a/lmdeploy/pytorch/models/peft.py +++ b/lmdeploy/pytorch/models/peft.py @@ -60,7 +60,10 @@ def _lora_forward_local(self, x): lora_input = self._make_packed_lora_input(x) - out_size = self.base_layer.weight.size(0) + if hasattr(self.base_layer, 'weight'): + out_size = self.base_layer.weight.size(0) + else: + out_size = self.base_layer.scales.size(1) if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, lora_input.a_cache, @@ -110,7 +113,11 @@ def _lora_forward_tp_rowwise(self, x): lora_input = self._make_packed_lora_input(x) rank = dist.get_rank() world_size = dist.get_world_size() - out_size = self.base_layer.weight.size(0) // world_size + if hasattr(self.base_layer, 'weight'): + out_size = self.base_layer.weight.size(0) + else: + out_size = self.base_layer.scales.size(1) + out_size = out_size // world_size if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, lora_input.a_cache, @@ -172,7 +179,10 @@ def __gather_xa(xa): lora_input = self._make_packed_lora_input(x) world_size = dist.get_world_size() - out_size = self.base_layer.weight.size(0) + if hasattr(self.base_layer, 'weight'): + out_size = self.base_layer.weight.size(0) + else: + out_size = self.base_layer.scales.size(1) if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, lora_input.a_cache, diff --git a/lmdeploy/pytorch/weight_loader/dist_utils.py b/lmdeploy/pytorch/weight_loader/dist_utils.py index 27177ff23..bad054e00 100644 --- a/lmdeploy/pytorch/weight_loader/dist_utils.py +++ b/lmdeploy/pytorch/weight_loader/dist_utils.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math from typing import List import torch @@ -8,15 +9,43 @@ from .model_weight_loader import ModelWeightLoader +logger = get_logger('lmdeploy') + try: from peft.tuners.lora import Linear as LoRALinear except ImportError: + logger.debug('load peft.tuner.lora.Linear failed.') class LoRALinear: pass -logger = get_logger('lmdeploy') +try: + from peft.tuners.lora.awq import AwqLoraLinear +except ImportError: + logger.debug('load peft.tuners.lora.awq.AwqLoraLinear failed.') + + class AwqLoraLinear: + pass + + +try: + from awq.modules.linear.gemm import WQLinear_GEMM +except ImportError: + logger.debug('load awq.modules.linear.gemm.WQLinearGEMM failed.') + + class WQLinear_GEMM: + pass + + +def _div_up(a, b): + """div up.""" + return (a + b - 1) // b + + +def _math_lcm(*args): + """lcm.""" + return int(math.prod(args) / math.gcd(*args)) def get_prefixed_name(name: str, prefix: str): @@ -64,11 +93,15 @@ def colwise_parallelize_loralinear(module: torch.nn.Module, world_size: int, prefix: str = ''): """colwise parallelize loralinear.""" - colwise_parallelize_linear_naive(module.base_layer, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) + if isinstance(module.base_layer, WQLinear_GEMM): + parallel_base_func = colwise_parallelize_wqlinear + else: + parallel_base_func = colwise_parallelize_linear_naive + parallel_base_func(module.base_layer, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) for key, mod in module.lora_A.items(): ada_loader = loader.adapter(key) colwise_parallelize_linear_naive(mod, @@ -88,6 +121,57 @@ def colwise_parallelize_loralinear(module: torch.nn.Module, module._tp_mode = 'colwise' +def _get_split_size_with_align(size: int, align: int, num_chunk: int): + """get split size with align.""" + assert size % align == 0 + num_aligned = size // align + split_size = _div_up(num_aligned, num_chunk) * align + return split_size + + +def colwise_parallelize_wqlinear(mod: torch.nn.Module, + loader: ModelWeightLoader, + rank: int, + world_size: int, + prefix: str = ''): + """colwise parallelize wqlinear.""" + elem_per_word = 32 // mod.w_bit + group_size = mod.group_size + lcm = _math_lcm(elem_per_word, group_size) + num_out = mod.scales.size(1) + + split_size = _get_split_size_with_align(num_out, lcm, world_size) + qsplit_size = split_size // elem_per_word + + def __update_param(name, param): + """update_param.""" + dtype = param.dtype + prefixed_name = get_prefixed_name(name, prefix) + if name == 'bias': + ssize = split_size + dim = 0 + elif name == 'scales': + ssize = split_size + dim = 1 + else: + ssize = qsplit_size + dim = 1 + param = loader.pop(prefixed_name) + param = param.split(ssize, dim)[rank] + param = cast_dtype(param, dtype) + return param + + for name, param in mod.named_parameters(): + param = __update_param(name, param) + param = torch.nn.Parameter(param, requires_grad=False) + mod.register_parameter(name, param) + for name, param in mod.named_buffers(): + param = __update_param(name, param) + mod.register_buffer(name, param) + mod.in_features = mod.qweight.size(0) + mod.out_features = mod.scales.size(1) + + def colwise_parallelize_linear(module: torch.nn.Module, loader: ModelWeightLoader, rank: int, @@ -100,12 +184,18 @@ def colwise_parallelize_linear(module: torch.nn.Module, rank=rank, world_size=world_size, prefix=prefix) - elif isinstance(module, LoRALinear): + elif isinstance(module, (LoRALinear, AwqLoraLinear)): return colwise_parallelize_loralinear(module, loader, rank=rank, world_size=world_size, prefix=prefix) + elif isinstance(module, WQLinear_GEMM): + return colwise_parallelize_wqlinear(module, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) else: raise TypeError(f'Unsupported module: {type(module)}') @@ -144,11 +234,15 @@ def rowwise_parallelize_loralinear(module: LoRALinear, world_size: int, prefix: str = ''): """colwise parallelize loralinear.""" - rowwise_parallelize_linear_naive(module.base_layer, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) + if isinstance(module.base_layer, WQLinear_GEMM): + parallel_base_func = rowwise_parallelize_wqlinear + else: + parallel_base_func = rowwise_parallelize_linear_naive + parallel_base_func(module.base_layer, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) for key, mod in module.lora_A.items(): ada_loader = loader.adapter(key) rowwise_parallelize_linear_naive(mod, @@ -168,6 +262,45 @@ def rowwise_parallelize_loralinear(module: LoRALinear, module._tp_mode = 'colwise' +def rowwise_parallelize_wqlinear(mod: torch.nn.Module, + loader: ModelWeightLoader, + rank: int, + world_size: int, + prefix: str = ''): + """rowwise parallelize linear.""" + elem_per_word = 32 // mod.w_bit + group_size = mod.group_size + lcm = _math_lcm(elem_per_word, group_size) + num_in = mod.qweight.size(0) + + split_size = _get_split_size_with_align(num_in, lcm, world_size) + qsplit_size = split_size // group_size + + def __update_param(name: str, param: torch.Tensor): + """update_param.""" + dtype = param.dtype + prefixed_name = get_prefixed_name(name, prefix) + param = loader.pop(prefixed_name) + if name == 'bias': + param /= world_size + elif name == 'qweight': + param = param.split(split_size)[rank] + else: + param = param.split(qsplit_size)[rank] + param = cast_dtype(param, dtype) + return param + + for name, param in mod.named_parameters(): + param = __update_param(name, param) + param = torch.nn.Parameter(param, requires_grad=False) + mod.register_parameter(name, param) + for name, param in mod.named_buffers(): + param = __update_param(name, param) + mod.register_buffer(name, param) + mod.in_features = mod.qweight.size(0) + mod.out_features = mod.scales.size(1) + + def rowwise_parallelize_linear(module: torch.nn.Module, loader: ModelWeightLoader, rank: int, @@ -180,12 +313,18 @@ def rowwise_parallelize_linear(module: torch.nn.Module, rank=rank, world_size=world_size, prefix=prefix) - elif isinstance(module, LoRALinear): + elif isinstance(module, (LoRALinear, AwqLoraLinear)): return rowwise_parallelize_loralinear(module, loader, rank=rank, world_size=world_size, prefix=prefix) + elif isinstance(module, WQLinear_GEMM): + return rowwise_parallelize_wqlinear(module, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) else: raise TypeError(f'Unsupported module: {type(module)}') @@ -227,12 +366,16 @@ def colwise_split_parallelize_loralinear(module: LoRALinear, world_size: int, prefix: str = ''): """colwise split loralinear.""" - colwise_split_parallelize_linear_naive(module.base_layer, - sections, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) + if isinstance(module.base_layer, WQLinear_GEMM): + parallel_base_func = colwise_split_parallelize_wqlinear + else: + parallel_base_func = colwise_split_parallelize_linear_naive + parallel_base_func(module.base_layer, + sections, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) for key, mod in module.lora_A.items(): ada_loader = loader.adapter(key) colwise_parallelize_linear_naive(mod, @@ -253,6 +396,59 @@ def colwise_split_parallelize_loralinear(module: LoRALinear, module._tp_mode = 'colwise' +def colwise_split_parallelize_wqlinear(module: torch.nn.Module, + sections: List[int], + loader: ModelWeightLoader, + rank: int, + world_size: int, + prefix: str = ''): + """colwise split wqlinear.""" + elem_per_word = 32 // module.w_bit + group_size = module.group_size + lcm = _math_lcm(elem_per_word, group_size) + + for s in sections: + assert s % lcm == 0 + + def __update_param(name: str, param: torch.Tensor): + dtype = param.dtype + prefixed_name = get_prefixed_name(name, prefix) + param = loader.pop(prefixed_name) + if name == 'bias': + dim = 0 + sec = sections + elif name == 'scales': + dim = 1 + sec = sections + else: + dim = 1 + sec = [s // elem_per_word for s in sections] + splited_param = param.split(sec, dim=dim) + updated_param = [] + for p in splited_param: + if name == 'bias': + p = p.chunk(world_size)[rank] + else: + p = p.chunk(world_size, 1)[rank] + p = cast_dtype(p, dtype) + updated_param.append(p) + if name == 'bias': + param = torch.cat(updated_param) + else: + param = torch.cat(updated_param, 1) + return param + + for name, param in module.named_parameters(): + param = __update_param(name, param) + param = torch.nn.Parameter(param, requires_grad=False) + module.register_parameter(name, param) + for name, param in module.named_buffers(): + param = __update_param(name, param) + module.register_buffer(name, param) + module.in_features = module.qweight.size(0) + module.out_features = module.scales.size(1) + + def colwise_split_parallelize_linear(module: torch.nn.Module, sections: List[int], loader: ModelWeightLoader, @@ -267,13 +463,20 @@ def colwise_split_parallelize_linear(module: torch.nn.Module, rank=rank, world_size=world_size, prefix=prefix) - elif isinstance(module, LoRALinear): + elif isinstance(module, (LoRALinear, AwqLoraLinear)): return colwise_split_parallelize_loralinear(module, sections, loader, rank=rank, world_size=world_size, prefix=prefix) + elif isinstance(module, WQLinear_GEMM): + return colwise_split_parallelize_wqlinear(module, + sections, + loader, + rank=rank, + world_size=world_size, + prefix=prefix) else: raise TypeError(f'Unsupported module: {type(module)}') @@ -315,9 +518,9 @@ def default_load_linear(module: torch.nn.Module, rank: int = 0, prefix: str = ''): """default load linear.""" - if isinstance(module, (torch.nn.Linear, QLinear)): + if isinstance(module, (torch.nn.Linear, QLinear, WQLinear_GEMM)): load_no_recursive(module, loader, rank=rank, prefix=prefix) - elif isinstance(module, LoRALinear): + elif isinstance(module, (LoRALinear, AwqLoraLinear)): raise NotImplementedError('Not implemented, please contact us.') else: raise TypeError(f'Unsupported module: {type(module)}') diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 831ea29f7..0dea3d17c 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio import dataclasses +import json import os import random from contextlib import asynccontextmanager @@ -566,6 +567,10 @@ async def generate( # set random if it is not set and sequence_start is True if gen_config.random_seed is None and sequence_start: gen_config.random_seed = random.getrandbits(64) + if gen_config.n > 1: + logger.warning(f"n({gen_config.n}) > 1 hasn't been supported yet. " + f'Fallback to 1') + gen_config.n = 1 prompt = messages prompt_input = await self._get_prompt_input(prompt, @@ -661,6 +666,28 @@ async def generate( if self.backend == 'pytorch' and sequence_end: await self.end_session(session_id) + def parse_tool_response(self, text, tools, **kwargs): + """Parse model response containing tool information. + + Args: + text(str): model response in string format + tools(List): tools from user request + """ + if '<|plugin|>' in text: # internlm2 + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action = json.loads(action) + name, parameters = action['name'], json.dumps(action['parameters']) + elif '') + parameters = action[action.find('{'):] + name = action.split('{')[0] + else: + raise RuntimeError(f'Unexpected model response: {text}') + action_id = [tool.function.name for tool in tools].index(name) + return text, action_id, name, parameters + def chat(self, prompt: str, session=None, diff --git a/lmdeploy/serve/client.py b/lmdeploy/serve/client.py deleted file mode 100644 index 424e83143..000000000 --- a/lmdeploy/serve/client.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os - -from lmdeploy.serve.turbomind.chatbot import Chatbot - - -def input_prompt(model_name): - """Input a prompt in the consolo interface.""" - if model_name == 'codellama': - print('\nenter !! to end the input >>>\n', end='') - sentinel = '!!' - else: - print('\ndouble enter to end input >>> ', end='') - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) - - -def main(tritonserver_addr: str, - session_id: int = 1, - cap: str = 'chat', - stream_output: bool = True, - **kwargs): - """An example to communicate with inference server through the command line - interface. - - Args: - tritonserver_addr (str): the address in format "ip:port" of - triton inference server - session_id (int): the identical id of a session - cap (str): the capability of a model. For example, codellama has - the ability among ['completion', 'infill', 'instruct', 'python'] - stream_output (bool): indicator for streaming output or not - **kwargs (dict): other arguments for initializing model's chat template - """ - log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING') - kwargs.update(capability=cap) - chatbot = Chatbot(tritonserver_addr, - log_level=log_level, - display=stream_output, - **kwargs) - nth_round = 1 - while True: - prompt = input_prompt(chatbot.model_name) - if prompt == 'exit': - exit(0) - elif prompt == 'end': - chatbot.end(session_id) - else: - request_id = f'{session_id}-{nth_round}' - if stream_output: - for status, res, n_token in chatbot.stream_infer( - session_id, - prompt, - request_id=request_id, - request_output_len=512): - continue - else: - status, res, n_token = chatbot.infer(session_id, - prompt, - request_id=request_id, - request_output_len=512) - print(res) - nth_round += 1 - - -if __name__ == '__main__': - import fire - - fire.Fire(main) diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py index a3dff287c..a3cd00351 100644 --- a/lmdeploy/serve/gradio/app.py +++ b/lmdeploy/serve/gradio/app.py @@ -21,10 +21,9 @@ def run(model_path_or_server: str, """chat with AI assistant through web ui. Args: - model_path_or_server (str): the path of the deployed model or the - tritonserver URL or restful api URL. For example: - - ./workspace - - 0.0.0.0:23333 + model_path_or_server (str): the path of the deployed model or + restful api URL. For example: + - huggingface hub repo_id - http://0.0.0.0:23333 server_name (str): the ip address of gradio server server_port (int): the port of gradio server @@ -39,17 +38,12 @@ def run(model_path_or_server: str, share (bool): whether to create a publicly shareable link for the app """ if ':' in model_path_or_server: - if 'http:' in model_path_or_server: - from lmdeploy.serve.gradio.api_server_backend import run_api_server - run_api_server(model_path_or_server, - server_name, - server_port, - batch_size, - share=share) - else: - from lmdeploy.serve.gradio.triton_server_backend import \ - run_triton_server - run_triton_server(model_path_or_server, server_name, server_port) + from lmdeploy.serve.gradio.api_server_backend import run_api_server + run_api_server(model_path_or_server, + server_name, + server_port, + batch_size, + share=share) else: pipeline_type, _ = get_task(model_path_or_server) if pipeline_type == 'vlm': diff --git a/lmdeploy/serve/gradio/triton_server_backend.py b/lmdeploy/serve/gradio/triton_server_backend.py deleted file mode 100644 index ef7ed8c3f..000000000 --- a/lmdeploy/serve/gradio/triton_server_backend.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -from functools import partial -from threading import Lock -from typing import Sequence - -import gradio as gr - -from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn -from lmdeploy.serve.turbomind.chatbot import Chatbot - - -class InterFace: - global_session_id: int = 0 - lock = Lock() - - -def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, - cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int, - top_p: float, temperature: float, request_output_len: int): - """Chat with AI assistant. - - Args: - instruction (str): user's prompt - state_chatbot (Sequence): the chatting history - llama_chatbot (Chatbot): the instance of a chatbot - cancel_btn (bool): enable the cancel button or not - reset_btn (bool): enable the reset button or not - session_id (int): the session id - """ - instruction = state_chatbot[-1][0] - - bot_response = llama_chatbot.stream_infer( - session_id, - instruction, - f'{session_id}-{len(state_chatbot)}', - request_output_len=request_output_len, - top_p=top_p, - temperature=temperature) - - for status, tokens, _ in bot_response: - state_chatbot[-1] = (state_chatbot[-1][0], tokens) - yield (state_chatbot, state_chatbot, enable_btn, disable_btn) - - yield (state_chatbot, state_chatbot, disable_btn, enable_btn) - - -def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, - llama_chatbot: gr.State, triton_server_addr: str, - model_name: str): - """reset the session.""" - state_chatbot = [] - log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') - llama_chatbot = Chatbot(triton_server_addr, - model_name, - log_level=log_level, - display=True) - - return ( - llama_chatbot, - state_chatbot, - state_chatbot, - gr.Textbox.update(value=''), - ) - - -def cancel_func( - state_chatbot: gr.State, - llama_chatbot: gr.State, - cancel_btn: gr.Button, - reset_btn: gr.Button, -): - """cancel the session.""" - yield (llama_chatbot, state_chatbot, disable_btn, disable_btn) - session_id = llama_chatbot._session.session_id - llama_chatbot.cancel(session_id) - - yield (llama_chatbot, state_chatbot, disable_btn, enable_btn) - - -def add_instruction(instruction, state_chatbot): - state_chatbot = state_chatbot + [(instruction, None)] - return ('', state_chatbot) - - -def run_triton_server(triton_server_addr: str, - server_name: str = 'localhost', - server_port: int = 6006): - """chat with AI assistant through web ui. - - Args: - triton_server_addr (str): the communication address of inference server - server_name (str): the ip address of gradio server - server_port (int): the port of gradio server - """ - with gr.Blocks(css=CSS, theme=THEME) as demo: - log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO') - llama_chatbot = gr.State( - Chatbot(triton_server_addr, log_level=log_level, display=True)) - state_chatbot = gr.State([]) - state_session_id = gr.State(0) - model_name = llama_chatbot.value.model_name - reset_all = partial(reset_all_func, - model_name=model_name, - triton_server_addr=triton_server_addr) - - with gr.Column(elem_id='container'): - gr.Markdown('## LMDeploy Playground') - - chatbot = gr.Chatbot(elem_id='chatbot', label=model_name) - instruction_txtbox = gr.Textbox( - placeholder='Please input the instruction', - label='Instruction') - with gr.Row(): - cancel_btn = gr.Button(value='Cancel', interactive=False) - reset_btn = gr.Button(value='Reset') - with gr.Row(): - request_output_len = gr.Slider(1, - 2048, - value=512, - step=1, - label='Maximum new tokens') - top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p') - temperature = gr.Slider(0.01, - 1.5, - value=0.7, - step=0.01, - label='Temperature') - - send_event = instruction_txtbox.submit( - add_instruction, [instruction_txtbox, state_chatbot], - [instruction_txtbox, state_chatbot]).then(chat_stream, [ - state_chatbot, llama_chatbot, cancel_btn, reset_btn, - state_session_id, top_p, temperature, request_output_len - ], [state_chatbot, chatbot, cancel_btn, reset_btn]) - - cancel_btn.click(cancel_func, - [state_chatbot, llama_chatbot, cancel_btn, reset_btn], - [llama_chatbot, chatbot, cancel_btn, reset_btn], - cancels=[send_event]) - - reset_btn.click( - reset_all, [instruction_txtbox, state_chatbot, llama_chatbot], - [llama_chatbot, state_chatbot, chatbot, instruction_txtbox], - cancels=[send_event]) - - def init(): - with InterFace.lock: - InterFace.global_session_id += 1 - new_session_id = InterFace.global_session_id - return new_session_id - - demo.load(init, inputs=None, outputs=[state_session_id]) - - print(f'server is gonna mount on: http://{server_name}:{server_port}') - demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( - max_threads=10, - share=True, - server_port=server_port, - server_name=server_name, - ) diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py index d5b25f0f2..08942f32f 100644 --- a/lmdeploy/serve/openai/api_client.py +++ b/lmdeploy/serve/openai/api_client.py @@ -96,6 +96,8 @@ def chat_completions_v1(self, messages: Union[str, List[Dict[str, str]]], temperature: Optional[float] = 0.7, top_p: Optional[float] = 1.0, + logprobs: Optional[bool] = False, + top_logprobs: Optional[int] = 0, n: Optional[int] = 1, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 59068afae..08f4de164 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio import copy -import json import os import time from http import HTTPStatus @@ -256,7 +255,7 @@ async def chat_completions_v1_qos(request: ChatCompletionRequestQos, probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - max_tokens (int): output token nums - repetition_penalty (float): The parameter for repetition penalty. @@ -386,7 +385,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - max_tokens (int | None): output token nums. Default to None. - repetition_penalty (float): The parameter for repetition penalty. @@ -416,11 +415,17 @@ async def chat_completions_v1(request: ChatCompletionRequest, - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id + if request.session_id == -1: + VariableInterface.session_id += 1 + request.session_id = VariableInterface.session_id error_check_ret = await check_request(request) if error_check_ret is not None: return error_check_ret + if VariableInterface.async_engine.id2step.get(str(request.session_id), + 0) != 0: + return create_error_response( + HTTPStatus.BAD_REQUEST, + f'The session_id `{request.session_id}` is occupied.') model_name = request.model adapter_name = None @@ -535,22 +540,17 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_logprobs.extend(res.logprobs) tool_calls = None - if request.tool_choice != 'none' and '<|plugin|>' in text: + if request.tool_choice != 'none' and ('<|plugin|>' in text + or '<|plugin|>') - action = action.split('<|action_end|>'.strip())[0] - action = action[action.find('{'):] try: # TODO add json_schema guidance to turbomind - action = json.loads(action) - action_id = [tool.function.name - for tool in request.tools].index(action['name']) + text, action_id, name, parameters = VariableInterface.async_engine.parse_tool_response( # noqa + text, request.tools) tool_calls = [ ToolCall(id=str(action_id), - function=FunctionResponse(name=action['name'], - arguments=json.dumps( - action['parameters']))) + function=FunctionResponse(name=name, + arguments=parameters)) ] except Exception as e: logger.error(f'Exception: {e}') @@ -614,7 +614,7 @@ async def completions_v1_qos(request: CompletionRequestQos, probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty @@ -765,7 +765,7 @@ async def completions_v1(request: CompletionRequest, probable tokens with probabilities that add up to top_p or higher are kept for generation. - n (int): How many chat completion choices to generate for each input - message. Only support one here. + message. **Only support one here**. - stream: whether to stream the results or not. Default to false. - repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty @@ -785,11 +785,17 @@ async def completions_v1(request: CompletionRequest, - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id + if request.session_id == -1: + VariableInterface.session_id += 1 + request.session_id = VariableInterface.session_id error_check_ret = await check_request(request) if error_check_ret is not None: return error_check_ret + if VariableInterface.async_engine.id2step.get(str(request.session_id), + 0) != 0: + return create_error_response( + HTTPStatus.BAD_REQUEST, + f'The session_id `{request.session_id}` is occupied.') model_name = request.model adapter_name = None diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py deleted file mode 100644 index 68326a1ea..000000000 --- a/lmdeploy/serve/turbomind/chatbot.py +++ /dev/null @@ -1,687 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import queue -import random -import threading -from dataclasses import dataclass -from enum import Enum -from functools import partial -from typing import List, Union - -import mmengine -import numpy as np -import tritonclient.grpc as grpcclient - -from lmdeploy.model import MODELS -from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor, - prepare_tensor) -from lmdeploy.utils import filter_suffix, get_logger - - -@dataclass -class Session: - session_id: Union[int, str] - request_id: str = '' - histories: str = '' # history conversations of the session - sequence_length: int = 0 # the total generated token number in the session - prompt: str = '' - response: str = '' - status: int = None # status of the session - - -class StatusCode(Enum): - TRITON_STREAM_END = 0 # end of streaming - TRITON_STREAM_ING = 1 # response is in streaming - TRITON_SESSION_READY = 2 # session is ready for inference - TRITON_SERVER_ERR = -1 # triton server's error - TRITON_SESSION_CLOSED = -2 # session has been closed - TRITON_SESSION_OUT_OF_LIMIT = -3 # request length out of limit - TRITON_SESSION_INVALID_ARG = -4 # invalid argument - - -def stream_callback(que, result, error): - """callback function invoked by triton client.""" - que.put((result, error)) - - -class Chatbot: - """Chatbot for LLaMA series models with turbomind as inference engine. - - Args: - tritonserver_addr (str): communicating address ':' of - triton inference server - model_name (str): name of the to-be-deployed mode - log_level (int): the level of the log - display (bool): display the generated text on consolo or not - """ - - def __init__(self, - tritonserver_addr: str, - model_name: str = '', - ignore_eos: bool = False, - log_level: int = logging.INFO, - display: bool = False, - top_p: float = 1.0, - top_k: int = 1, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - **model_kwargs): - self.tritonserver_addr = tritonserver_addr - self.model_name = model_name - if self.model_name == '': - self.model_name = self._get_model_name() - assert self.model_name in MODELS.module_dict.keys(), \ - f"'{self.model_name}' is not supported. " \ - f'The supported models are: {MODELS.module_dict.keys()}' - self.model = MODELS.get(self.model_name)(**model_kwargs) - self._session = None - self.preprocess = Preprocessor(tritonserver_addr) - self.postprocess = Postprocessor(tritonserver_addr) - self.bos_id = self._get_bos() - self.eos_id = self._get_eos() - stop_words = self._stop_words(self.model.stop_words) - bad_words = None - if ignore_eos: - stop_words = None - bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32) - self.eos_id = -1 - self.cfg = mmengine.Config( - dict(session_len=self.model.session_len, - top_p=top_p, - top_k=top_k, - temperature=temperature, - repetition_penalty=repetition_penalty, - stop_words=stop_words, - bad_words=bad_words)) - self.log_level = log_level - self.display = display - - def stream_infer(self, - session_id: int, - prompt: str, - request_id: str = '', - request_output_len: int = None, - sequence_start: bool = False, - sequence_end: bool = False, - skip_special_tokens: bool = True, - *args, - **kwargs): - """Start a new round conversion of a session. - - Args: - session_id (int): the identical id of a session - prompt (str): user's prompt in this round conversation - request_id (str): the identical id of this round conversation - request_output_len (int): the expected generated token numbers - sequence_start (bool): start flag of a session - sequence_end (bool): end flag of a session - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - Returns: - iterator: The generated content by chatbot - """ - assert isinstance(session_id, int), \ - f'INT session id is required, but got {type(session_id)}' - - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'session {session_id}, request_id {request_id}, ' - f'request_output_len {request_output_len}') - - if self._session is None: - sequence_start = True - self._session = Session(session_id=session_id) - elif self._session.status == 0: - logger.error(f'session {session_id} has been ended. Please set ' - f'`sequence_start` be True if you want to restart it') - yield StatusCode.TRITON_SESSION_CLOSED, '', 0 - return - - self._session.status = 1 - self._session.request_id = request_id - self._session.response = '' - self.cfg.update(**kwargs) - - self._session.prompt = self._get_prompt(prompt, sequence_start) - for status, res, tokens in self._stream_infer( - self._session, - self._session.prompt, - request_output_len, - sequence_start, - sequence_end, - skip_special_tokens=skip_special_tokens): - if status == StatusCode.TRITON_STREAM_END: # remove stop_words - res = filter_suffix(res, self.model.stop_words) - if status.value < 0: - break - else: - yield status, res, tokens - if status.value == 0: - self._session.histories = \ - self._session.histories + self._session.prompt + \ - self._session.response - else: - yield status, res, tokens - - def end(self, session_id: int, *args, **kwargs): - """end a session. Triton inference server will release the session's - occupied resource when it is ended. - - Args: - session_id (int): the identical id of a session - - Returns: - int: 0: success, -1: session not found - """ - assert isinstance(session_id, int), \ - f'INT session id is required, but got {type(session_id)}' - - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'end session: {session_id}') - - if self._session is None: - logger.error( - f"session {session_id} doesn't exist. It cannot be ended") - return StatusCode.TRITON_SESSION_INVALID_ARG - if self._session.session_id != session_id: - logger.error(f'you cannot end session {session_id}, because this ' - f'session is {self._session.session_id}') - return StatusCode.TRITON_SESSION_INVALID_ARG - if self._session.status == 0: - logger.warning(f'session {session_id} has already been ended') - return StatusCode.TRITON_SESSION_CLOSED - - self._session.status = 0 - for status, _, _ in self._stream_infer(self._session, - prompt='', - request_output_len=0, - sequence_start=False, - sequence_end=True): - if status.value < 0: - break - - self.reset_session() - return status - - def cancel(self, session_id: int, *args, **kwargs): - """Cancel the session during generating tokens. - - Args: - session_id (int): the identical id of a session - - Returns: - int: 0: success, -1: session not found - """ - assert isinstance(session_id, int), \ - f'INT session id is required, but got {type(session_id)}' - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'cancel session: {session_id}') - - if self._session is None: - logger.error( - f"session {session_id} doesn't exist. It cannot be cancelled") - return StatusCode.TRITON_SESSION_INVALID_ARG - if self._session.session_id != session_id: - logger.error( - f'you cannot cancel session {session_id}, because this ' - f'session is {self._session.session_id}') - return StatusCode.TRITON_SESSION_INVALID_ARG - if self._session.status == 0: - logger.error(f'session {session_id} has already been ended. ' - f'It cannot be cancelled') - return StatusCode.TRITON_SESSION_CLOSED - - prev_session = self._session - status, res = None, None - for status, res, _ in self._stream_infer(self._session, - prompt='', - request_output_len=0, - sequence_start=False, - sequence_end=False, - cancel=True): - if status.value < 0: - break - if status == StatusCode.TRITON_STREAM_END: - logger.info(f'cancel session {session_id} successfully') - if prev_session.histories: - logger.warning(f'TODO: start to recover session {session_id}') - else: - logger.info(f'cancel session {session_id} failed: {res}') - return status - - def resume(self, session_id: int, *args, **kwargs): - """Resume a session by sending the history conversations to triton - inference server. After resuming, users can continue chatting with - chatbot. - - Args: - session_id (int): the identical id of a session - - Returns: - int: 0: success, -1: session not found - """ - assert isinstance(session_id, int), \ - f'INT session id is required, but got {type(session_id)}' - - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'resume session: {session_id}') - - if self._session is None: - logger.error( - f"session {session_id} doesn't exist. It cannot be recovered") - return StatusCode.TRITON_SESSION_INVALID_ARG - if self._session.session_id != session_id: - logger.error( - f'you cannot resume session {session_id}, because this ' - f'session is {self._session.session_id}') - return StatusCode.TRITON_SESSION_INVALID_ARG - - self._session.status = 1 - self._session.sequence_length = 0 - histories = self._session.histories - for status, _, _ in self._stream_infer(self._session, - prompt=histories, - request_output_len=0, - sequence_start=True, - sequence_end=False): - if status.value < 0: - break - - self._session.histories = histories - return status - - def infer(self, - session_id: int, - prompt: str, - request_id: str = '', - request_output_len: int = None, - sequence_start: bool = False, - sequence_end: bool = False, - skip_special_tokens: bool = True, - *args, - **kwargs): - """Start a new round conversion of a session. Return the chat - completions in non-stream mode. - - Args: - session_id (int): the identical id of a session - prompt (str): user's prompt in this round conversation - request_id (str): the identical id of this round conversation - request_output_len (int): the expected generated token numbers - sequence_start (bool): start flag of a session - sequence_end (bool): end flag of a session - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - Returns: - tuple(Status, str, int): status, text/chat completion, - generated token number - """ - assert isinstance(session_id, int), \ - f'INT session id is required, but got {type(session_id)}' - - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'session {session_id}, request_id {request_id}, ' - f'request_output_len {request_output_len}') - - if self._session is None: - sequence_start = True - self._session = Session(session_id=session_id) - elif self._session.status == 0: - logger.error(f'session {session_id} has been ended. Please set ' - f'`sequence_start` be True if you want to restart it') - return StatusCode.TRITON_SESSION_CLOSED, '', 0 - - self._session.status = 1 - self._session.request_id = request_id - self._session.response = '' - - self._session.prompt = self._get_prompt(prompt, sequence_start) - status, res, tokens = None, '', 0 - for status, res, tokens in self._stream_infer( - self._session, - self._session.prompt, - request_output_len, - sequence_start, - sequence_end, - skip_special_tokens=skip_special_tokens): - if status.value < 0: - break - if status == StatusCode.TRITON_STREAM_END: # remove stop_words - res = filter_suffix(res, self.model.stop_words) - if status.value == 0: - self._session.histories = \ - self._session.histories + self._session.prompt + \ - self._session.response - return status, res, tokens - else: - return status, res, tokens - - def reset_session(self): - """reset session.""" - self._session = None - - @property - def session(self): - """get session.""" - return self._session - - @session.setter - def session(self, value): - """set session.""" - self._session = value - - def _get_model_name(self): - with grpcclient.InferenceServerClient( - self.tritonserver_addr) as client: - model_config = client.get_model_config(model_name='turbomind', - as_json=True) - return model_config['config']['parameters']['model_name'][ - 'string_value'] - - def _get_bos(self): - """return bos token id.""" - token_ids, _ = self.preprocess('') - return token_ids[0][0] - - def _get_eos(self): - """return eos token id.""" - token_ids, _ = self.preprocess('') - return token_ids[0][0] - - def _stop_words(self, stop_words: List[str]): - """return stop-words' token ids.""" - if stop_words is None: - return None - assert isinstance(stop_words, List) and \ - all(isinstance(elem, str) for elem in stop_words), \ - f'stop_words must be a list but got {type(stop_words)}' - # each id in stop_words represents a stop word - # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for - # detailed explanation about turbomind's stop_words - stop_words = [ - int(self.preprocess(stop_word)[0][0][-1]) - for stop_word in stop_words - ] - assert isinstance(stop_words, List) and \ - all(isinstance(elem, int) for elem in stop_words), \ - 'invalid stop_words' - stop_word_offsets = range(1, len(stop_words) + 1) - stop_words = np.array([[stop_words, - stop_word_offsets]]).astype(np.int32) - return stop_words - - def _get_prompt(self, prompt: str, sequence_start: bool): - """return the concatenated prompt according to the model's chat - template.""" - return self.model.get_prompt(prompt, sequence_start) - - def _stream_infer(self, - session: Session, - prompt: str, - request_output_len: int = 512, - sequence_start: bool = True, - sequence_end: bool = False, - skip_special_tokens: bool = True, - cancel: bool = False): - """communicate with inference server to chat, or cancel a session, or - end a session. - - Args: - session (Session): an instance of a session - prompt (str): the concatenated prompt - request_output_len (int): the max number of tokens to be generated - sequence_start (bool): indicator for starting a sequence - sequence_end (bool): indicator for ending a sequence - cancel (bool): indicator for cancelling the session - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - Yields: - tuple: status, text, generated token number - """ - logger = get_logger('service.ft', log_level=self.log_level) - logger.info(f'session {session.session_id}, ' - f'request id {session.request_id}, ' - f'request_output_len {request_output_len}, ' - f'start {sequence_start}, ' - f'end {sequence_end}, cancel {cancel}') - - assert request_output_len is None or \ - isinstance(request_output_len, int), \ - f'request_output_len is supposed to be None or int, ' \ - f'but got {type(request_output_len)}' - - if sequence_start: - logger.info(f'session {session.session_id}, clear history since ' - f'sequence_start is True') - session.histories = '' - session.sequence_length = 0 - - input_ids, input_lengths = self.preprocess(prompt) - # got input_ids with default add_bos == True - if not sequence_start and input_ids[0][0] == self.bos_id: - input_ids = input_ids[:, 1:] - input_lengths = input_lengths - 1 - # will crash if last_token_id == eos_id and send empty input_ids - if sequence_end and request_output_len == 0: - input_ids = np.array([[1]], dtype=np.uint32) - input_lengths = np.array([[1]], dtype=np.uint32) - input_tokens = input_lengths.squeeze() - - if request_output_len is None: - request_output_len = max( - 128, - self.cfg.session_len - session.sequence_length - input_tokens) - - if input_tokens + request_output_len + \ - session.sequence_length > self.cfg.session_len: - errmsg = f'session {session.session_id}, ' \ - f'out of max sequence length {self.cfg.session_len}, ' \ - f'#input tokens {input_tokens}, ' \ - f'history tokens {session.sequence_length}, ' \ - f'request length {request_output_len}' - logger.warning(errmsg) - yield StatusCode.TRITON_SESSION_OUT_OF_LIMIT, errmsg, 0 - return - - logger.info(f'session {session.session_id}, ' - f'max length: {self.cfg.session_len}, ' - f'input tokens: {input_tokens}, ' - f'request tokens: {request_output_len}, ' - f'history tokens: {session.sequence_length}') - - preseq_length = session.sequence_length - session.response = '' - session.status = StatusCode.TRITON_SESSION_READY - - que = queue.Queue() - producer = threading.Thread(target=self._stream_producer, - args=(self.tritonserver_addr, session, que, - self.cfg, input_ids, input_lengths, - request_output_len, sequence_start, - sequence_end, preseq_length, cancel)) - producer.start() - for status, res, n_token in self.stream_consumer( - self.postprocess, que, session, input_tokens, preseq_length, - cancel, logger, self.display, self.eos_id, - skip_special_tokens): - yield status, res, n_token - - producer.join() - self._session = que.get() - curseq_length = self._session.sequence_length - logger.info(f'session {session.session_id}, pre seq_len ' - f'{preseq_length}, cur seq_len {curseq_length}, ' - f'diff {curseq_length - preseq_length}') - - @staticmethod - def _stream_producer(tritonserver_addr, session, que, cfg, input_ids, - input_lengths, request_output_len, sequence_start, - sequence_end, preseq_length, cancel): - """Send a request to the triton inference server. - - Args: - tritonserver_addr (str): the communication address of the inference - server - session (Session): an instance of a session - que (multiprocessing.Queue): response queue - cfg (dict): parameters for sampling - input_ids (numpy.ndarray): token ids of input prompt - input_lengths (numpy.ndarray): length of input_ids - request_output_len (int): the max number of tokens to be generated - sequence_start (bool): indicator for starting a sequence - sequence_end (bool): indicator for ending a sequence - preseq_length (int): the history sequence length - cancel (bool): indicator for cancelling the session - """ - request_output_len = np.full(input_lengths.shape, - request_output_len).astype(np.uint32) - - callback = partial(stream_callback, que) - with grpcclient.InferenceServerClient(tritonserver_addr) as client: - inputs = [ - prepare_tensor('input_ids', input_ids), - prepare_tensor('input_lengths', input_lengths), - prepare_tensor('request_output_len', request_output_len), - prepare_tensor('runtime_top_p', - cfg.top_p * np.ones((1, 1), dtype=np.float32)), - prepare_tensor( - 'temperature', - cfg.temperature * np.ones((1, 1), dtype=np.float32)), - prepare_tensor( - 'repetition_penalty', - cfg.repetition_penalty * np.ones( - (1, 1), dtype=np.float32)), - prepare_tensor('step', - preseq_length * np.ones((1, 1), dtype=np.int32)) - ] - if cfg.top_k is not None: - inputs += prepare_tensor( - 'runtime_top_k', - cfg.top_k * np.ones((1, 1), dtype=np.uint32)), - if cfg.stop_words is not None: - inputs += [prepare_tensor('stop_words_list', cfg.stop_words)] - if cfg.bad_words is not None: - inputs += [prepare_tensor('bad_words_list', cfg.bad_words)] - - inputs += [ - prepare_tensor( - 'session_len', - cfg.session_len * - np.ones([input_ids.shape[0], 1], dtype=np.uint32)), - prepare_tensor('START', (1 if sequence_start else 0) * np.ones( - (1, 1), dtype=np.int32)), - prepare_tensor('END', (1 if sequence_end else 0) * np.ones( - (1, 1), dtype=np.int32)), - prepare_tensor( - 'CORRID', - session.session_id * np.ones((1, 1), dtype=np.uint64)), - prepare_tensor('STOP', (1 if cancel else 0) * np.ones( - (1, 1), dtype=np.int32)) - ] - if sequence_start: - random_seed = random.getrandbits(64) - inputs += [ - prepare_tensor( - 'random_seed', - random_seed * np.ones((1, 1), dtype=np.uint64)) - ] - client.start_stream(callback) - client.async_stream_infer('turbomind', - inputs, - sequence_id=session.session_id, - request_id=session.request_id, - sequence_start=sequence_start, - sequence_end=sequence_end) - que.put(None) - - @staticmethod - def stream_consumer(postprocess, res_queue, session, n_input_token, - preseq_length, cancel, logger, display, eos_id, - skip_special_tokens): - """Consume the response from the triton inference server. - - Args: - postprocess (callable): postprocess function for - the generated tokens - res_queue (multiprocessing.Queue): response queue - session (Session): an instance of a session - n_input_token (int): token number of input prompt - preseq_length (int): the history sequence length - cancel (bool): indicator for cancelling the session - logger (util.Logger): - display (bool): display the text in the consolo interface or not - eos_id (int): eos token id - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - - Yields: - tuple: status, text, generated token number - """ - status, res, n_token = None, '', 0 - output_ids = np.zeros((1, 1, 0), dtype=np.uint32) - text = '' - while True: - result_pack = res_queue.get() - if result_pack is None: - status = StatusCode.TRITON_STREAM_END - res = session.response - session.status = StatusCode.TRITON_STREAM_END - break - result, error = result_pack - if error is not None: - logger.error(f'got error from turbomind, code ' - f'{StatusCode.TRITON_SERVER_ERR}, {error}, ' - f'token {session.sequence_length}') - session.sequence_length = preseq_length - session.response = '' - status = StatusCode.TRITON_SERVER_ERR - res = f'{status}, {error}' - n_token = 0 - break - if cancel: - continue - try: - sequence_length = result.as_numpy('sequence_length') - output_ids = result.as_numpy('output_ids') - - session.sequence_length = sequence_length.squeeze() - output_ids = output_ids.reshape((1, 1, output_ids.shape[-1])) - output_ids = output_ids[:, :, n_input_token + - preseq_length:sequence_length.squeeze( - )] - last_token_id = None if output_ids.shape[ - -1] == 0 else output_ids[-1, -1, -1] - if last_token_id == eos_id: - session.sequence_length = session.sequence_length - 1 - output_ids = output_ids[:, :, :-1] - - output_str = postprocess( - output_ids, np.array([[n_token]], dtype=np.uint32), - np.array([[int(skip_special_tokens)]], dtype=np.int32)) - text = output_str[0].decode() - # utf-8 char at the end means it's a potential unfinished - # byte sequence, continue to concate it with the next - # sequence and decode them together - if text.endswith('�'): - continue - n_token = output_ids.shape[-1] - if display: - print(text, end='', flush=True) - session.response += text - yield (StatusCode.TRITON_STREAM_ING, session.response, - output_ids.shape[-1]) - except Exception as e: - logger.error(f'catch exception: {e}') - logger.error( - f'session {session.session_id}: prompt: {session.prompt}') - # `n_token` might be not updated since `if text.endswith('�')` - if n_token != output_ids.shape[-1]: - n_token = output_ids.shape[-1] - session.response += text - # put session back to queue so that `_stream_infer` can update it in - # `self.sessions` - while not res_queue.empty(): - res_queue.get() - res_queue.put(session) - if display: - print('\n') - yield status, res, n_token diff --git a/lmdeploy/serve/turbomind/service_docker_up.sh b/lmdeploy/serve/turbomind/service_docker_up.sh deleted file mode 100644 index d45345e61..000000000 --- a/lmdeploy/serve/turbomind/service_docker_up.sh +++ /dev/null @@ -1,87 +0,0 @@ -#!/bin/sh - -show_help() { - echo "Usage: $0 [-h] [--help] [-l] [--lib-dir]" - echo - echo "Options:" - echo " -h, --help Show this help message and exit" - echo " --lib-dir Specify the directory of turbomind libraries" -} - -# check if '-h' or '--help' in the arguments -for arg in "$@" -do - if [ "$arg" == "-h" ] || [ "$arg" == "--help" ]; then - show_help - exit 0 - fi -done - - -TP=1 -DEVICES="0" -for ((i = 1; i < ${TP}; ++i)); do - DEVICES="${DEVICES},$i" -done -DEVICES="\"device=${DEVICES}\"" - - -SCRIPT_DIR="$(dirname "$0")" -SCRIPT_ABS_DIR="$(realpath "$SCRIPT_DIR")" - - -if [ -z "$1" ]; then - docker run \ - --gpus $DEVICES \ - --rm \ - -v "${SCRIPT_ABS_DIR}":/workspace/models \ - --shm-size 16g \ - -p 33336:22 \ - -p 33337-33400:33337-33400 \ - --cap-add=SYS_PTRACE \ - --cap-add=SYS_ADMIN \ - --security-opt seccomp=unconfined \ - --name lmdeploy \ - -it --env NCCL_LAUNCH_MODE=GROUP openmmlab/lmdeploy:latest \ - tritonserver \ - --model-repository=/workspace/models/model_repository \ - --allow-http=0 \ - --allow-grpc=1 \ - --grpc-port=33337 \ - --log-verbose=0 \ - --allow-metrics=1 -fi - -for ((i = 1; i <= $#; i++)); do - arg=${!i} - case "$arg" in - --lib-dir) - if [ "$i" -eq "$#" ]; then - show_help - exit -1 - fi - LIB_PATH=${@:i+1:1} - docker run \ - --gpus $DEVICES \ - --rm \ - -v "${LIB_PATH}":/opt/tritonserver/backends/turbomind \ - -v ""${SCRIPT_ABS_DIR}"":/workspace/models \ - --shm-size 16g \ - -p 33336:22 \ - -p 33337-33400:33337-33400 \ - --cap-add=SYS_PTRACE \ - --cap-add=SYS_ADMIN \ - --security-opt seccomp=unconfined \ - --name lmdeploy \ - -it --env NCCL_LAUNCH_MODE=GROUP openmmlab/lmdeploy:latest \ - tritonserver \ - --model-repository=/workspace/models/model_repository \ - --allow-http=0 \ - --allow-grpc=1 \ - --grpc-port=33337 \ - --log-verbose=0 \ - --allow-metrics=1 - break - ;; - esac -done diff --git a/lmdeploy/serve/turbomind/triton_models/interactive/1/placeholder b/lmdeploy/serve/turbomind/triton_models/interactive/1/placeholder deleted file mode 100644 index e69de29bb..000000000 diff --git a/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt deleted file mode 100644 index 0b1e431ea..000000000 --- a/lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -name: "turbomind" -backend: "turbomind" -default_model_filename: "weights" -max_batch_size: 1 - -model_transaction_policy { - decoupled: True -} - -instance_group [ - { - # max concurrent instances - count: 48 - kind: KIND_CPU - } -] - -input [ - { - name: "input_ids" - data_type: TYPE_UINT32 - dims: [ -1 ] - # allow_ragged_batch: true - }, - { - name: "input_lengths" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - }, - { - name: "request_output_len" - data_type: TYPE_UINT32 - dims: [ -1 ] - }, - { - name: "input_embeddings" - data_type: TYPE_INT8 - dims: [ -1 ] - optional: true - }, - { - name: "input_embedding_ranges" - data_type: TYPE_UINT32 - dims: [ -1, 2 ] - optional: true - }, - { - name: "step" - data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "session_len" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "runtime_top_k" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "runtime_top_p" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "beam_search_diversity_rate" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "temperature" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "len_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "repetition_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "random_seed" - data_type: TYPE_UINT64 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "is_return_log_probs" - data_type: TYPE_BOOL - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "beam_width" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "start_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "end_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "bad_words_list" - data_type: TYPE_INT32 - dims: [ 2, -1 ] - optional: true - }, - { - name: "stop_words_list" - data_type: TYPE_INT32 - dims: [ 2, -1 ] - optional: true - }, - { - name: "prompt_learning_task_name_ids" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "top_p_decay" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "top_p_min" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "top_p_reset_ids" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "START" - data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "END" - data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "STOP" - data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "CORRID" - data_type: TYPE_UINT64 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - } -] -output [ - { - name: "output_ids" - data_type: TYPE_UINT32 - dims: [ -1, -1 ] - }, - { - name: "sequence_length" - data_type: TYPE_UINT32 - dims: [ -1 ] - }, - { - name: "cum_log_probs" - data_type: TYPE_FP32 - dims: [ -1 ] - }, - { - name: "output_log_probs" - data_type: TYPE_FP32 - dims: [ -1, -1 ] - } -] - -parameters { - key: "pipeline_para_size" - value: { - string_value: "1" - } -} -parameters { - key: "data_type" - value: { - string_value: "fp16" - } -} -parameters { - key: "model_type" - value: { - string_value: "Llama" - } -} - -parameters { - key: "enable_custom_all_reduce" - value: { - string_value: "0" - } -} diff --git a/lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py b/lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py deleted file mode 100644 index 0aa6805a3..000000000 --- a/lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import json -import os.path as osp -from pathlib import Path - -import numpy as np -import triton_python_backend_utils as pb_utils - -# This tokenizer is `lmdeploy/turbomind/tokenizer.py`. When an LLM is served -# by triton inference server, it has to be converted first by running -# `python lmdeploy/serve/turbomind/deploy.py`. Then -# `lmdeploy/turbomind/tokenizer.py` will be copied to `tokenizer/tokenizer.py` -from .tokenizer.tokenizer import Tokenizer - - -class TritonPythonModel: - """Your Python model must use the same class name. - - Every Python model that is created must have "TritonPythonModel" as the - class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device - ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # Parse model configs - self.model_config = model_config = json.loads(args['model_config']) - - # Parse model output configs - output_config = pb_utils.get_output_config_by_name( - model_config, 'OUTPUT') - - # Convert Triton types to numpy types - self.output_dtype = pb_utils.triton_string_to_numpy( - output_config['data_type']) - - cur_folder = Path(__file__).parent - - self.tokenizer = Tokenizer( - osp.join( - cur_folder, self.model_config['parameters']['tokenizer_path'] - ['string_value'])) - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse. - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # Every Python backend must iterate over everyone of the requests - # and create a pb_utils.InferenceResponse for each of them. - for idx, request in enumerate(requests): - # Get input tensors - tokens_batch = pb_utils.get_input_tensor_by_name( - request, 'TOKENS_BATCH').as_numpy() - sequence_length = pb_utils.get_input_tensor_by_name( - request, 'sequence_length').as_numpy() - skip_special_tokens = pb_utils.get_input_tensor_by_name( - request, 'skip_special_tokens').as_numpy() - - # Postprocessing output data. - outputs = self._postprocessing(tokens_batch.tolist(), - sequence_length, - skip_special_tokens) - - # Create output tensors. You need pb_utils.Tensor - # objects to create pb_utils.InferenceResponse. - output_tensor = pb_utils.Tensor( - 'OUTPUT', - np.array(outputs).astype(self.output_dtype)) - - # Create InferenceResponse. You can set an error here in case - # there was a problem with handling this inference request. - # Below is an example of how you can set errors in inference - # response: - # - # pb_utils.InferenceResponse( - # output_tensors=..., TritonError("An error occurred")) - inference_response = pb_utils.InferenceResponse( - output_tensors=[output_tensor]) - responses.append(inference_response) - - # You should return a list of pb_utils.InferenceResponse. Length - # of this list must match the length of `requests` list. - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - - Implementing `finalize` function is optional. This function allows the - model to perform any necessary clean ups before exit. - """ - print('Cleaning up...') - - def _postprocessing(self, tokens_batch, sequence_length, - skip_special_tokens): - """decode token ids into texts.""" - outputs = [] - for beam_tokens, beam_len, beam_skip_special in zip( - tokens_batch, sequence_length, skip_special_tokens): - for tokens, _len, skip_special in zip(beam_tokens, beam_len, - beam_skip_special): - output = self.tokenizer.decode( - tokens, _len, skip_special_tokens=bool(skip_special)) - output = output.encode('utf8') - outputs.append(output) - return outputs diff --git a/lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt deleted file mode 100644 index 7954a65f8..000000000 --- a/lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt +++ /dev/null @@ -1,41 +0,0 @@ -name: "postprocessing" -backend: "python" -max_batch_size: 1 -input [ - { - name: "TOKENS_BATCH" - data_type: TYPE_UINT32 - dims: [ -1, -1 ] - }, - { - name: "sequence_length" - data_type: TYPE_UINT32 - dims: [ -1 ] - }, - { - name: "skip_special_tokens" - data_type: TYPE_INT32 - dims: [ -1 ] - } -] -output [ - { - name: "OUTPUT" - data_type: TYPE_STRING - dims: [ -1, -1 ] - } -] - -instance_group [ - { - count: 16 - kind: KIND_CPU - } -] - -parameters { - key: "tokenizer_path" - value: { - string_value: "tokenizer/tokenizer.model" - } -} diff --git a/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py b/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py deleted file mode 100644 index 7e659fbae..000000000 --- a/lmdeploy/serve/turbomind/triton_models/preprocessing/1/model.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import json -import os.path as osp -from pathlib import Path - -import numpy as np -import torch -import triton_python_backend_utils as pb_utils -from torch.nn.utils.rnn import pad_sequence - -# This tokenizer is `lmdeploy/turbomind/tokenizer.py`. When an LLM is served -# by triton inference server, it has to be converted first by running -# `python lmdeploy/serve/turbomind/deploy.py`. Then -# `lmdeploy/turbomind/tokenizer.py` will be copied to `tokenizer/tokenizer.py` -from .tokenizer.tokenizer import Tokenizer - - -class TritonPythonModel: - """Your Python model must use the same class name. - - Every Python model that is created must have "TritonPythonModel" as the - class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device - ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # Parse model configs - self.model_config = model_config = json.loads(args['model_config']) - - # Parse model output configs and convert Triton types to numpy types - input_names = ['INPUT_ID', 'REQUEST_INPUT_LEN'] - for input_name in input_names: - setattr( - self, - input_name.lower() + '_dtype', - pb_utils.triton_string_to_numpy( - pb_utils.get_output_config_by_name( - model_config, input_name)['data_type'])) - - cur_folder = Path(__file__).parent - self.tokenizer = Tokenizer( - osp.join( - cur_folder, self.model_config['parameters']['tokenizer_path'] - ['string_value'])) - self.start_id = self.tokenizer.bos_token_id - self.end_id = self.tokenizer.eos_token_id - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse. - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # Every Python backend must iterate over everyone of the requests - # and create a pb_utils.InferenceResponse for each of them. - for idx, request in enumerate(requests): - # Get input tensors - query = pb_utils.get_input_tensor_by_name(request, - 'QUERY').as_numpy() - - # Preprocessing input data. - input_id, request_input_len = self._create_request(query) - - # Create output tensors. You need pb_utils.Tensor - # objects to create pb_utils.InferenceResponse. - input_id_tensor = pb_utils.Tensor( - 'INPUT_ID', - np.array(input_id).astype(self.input_id_dtype)) - request_input_len_tensor = pb_utils.Tensor( - 'REQUEST_INPUT_LEN', - np.array(request_input_len).astype( - self.request_input_len_dtype)) - - # Create InferenceResponse. You can set an error here in case - # there was a problem with handling this inference request. - # Below is an example of how you can set errors in inference - # response: - # - # pb_utils.InferenceResponse( - # output_tensors=..., TritonError("An error occurred")) - inference_response = pb_utils.InferenceResponse( - output_tensors=[input_id_tensor, request_input_len_tensor]) - responses.append(inference_response) - - # You should return a list of pb_utils.InferenceResponse. Length - # of this list must match the length of `requests` list. - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - - Implementing `finalize` function is optional. This function allows the - model to perform any necessary clean ups before exit. - """ - print('Cleaning up...') - - def _create_request(self, query): - """Tokenize prompts and return the token ids and their length. - - Args: - query (List[str]): a list of prompt - Returns: - tuple: token ids and their length - """ - start_ids = [] - for s in query: - _s = s[0].decode() - if _s == '': - start_id = [self.start_id - ] if self.start_id is not None else [-1] - elif _s == '': - start_id = [self.end_id] if self.end_id is not None else [-1] - else: - start_id = self.tokenizer.encode(_s) - start_ids.append(torch.IntTensor(start_id)) - - start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) - start_ids = pad_sequence(start_ids, - batch_first=True, - padding_value=self.end_id) - return start_ids, start_lengths diff --git a/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt b/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt deleted file mode 100644 index 997ba399b..000000000 --- a/lmdeploy/serve/turbomind/triton_models/preprocessing/config.pbtxt +++ /dev/null @@ -1,37 +0,0 @@ -name: "preprocessing" -backend: "python" -max_batch_size: 1 - -input [ - { - name: "QUERY" - data_type: TYPE_STRING - dims: [ -1 ] - } -] -output [ - { - name: "INPUT_ID" - data_type: TYPE_UINT32 - dims: [ -1 ] - }, - { - name: "REQUEST_INPUT_LEN" - data_type: TYPE_UINT32 - dims: [ 1 ] - } -] - -instance_group [ - { - count: 4 - kind: KIND_CPU - } -] - -parameters { - key: "tokenizer_path" - value: { - string_value: "tokenizer/tokenizer.model" - } -} diff --git a/lmdeploy/serve/turbomind/triton_models/tokenizer/placeholder b/lmdeploy/serve/turbomind/triton_models/tokenizer/placeholder deleted file mode 100644 index e69de29bb..000000000 diff --git a/lmdeploy/serve/turbomind/triton_models/weights/config.ini b/lmdeploy/serve/turbomind/triton_models/weights/config.ini deleted file mode 100644 index e69de29bb..000000000 diff --git a/lmdeploy/serve/turbomind/utils.py b/lmdeploy/serve/turbomind/utils.py deleted file mode 100644 index 0a3d8bcf0..000000000 --- a/lmdeploy/serve/turbomind/utils.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Union - -import numpy as np -import tritonclient.grpc as grpcclient -from tritonclient.utils import np_to_triton_dtype - - -def prepare_tensor(name, input_tensor): - """Create grpcclient's InferInput instance according to a given tensor.""" - t = grpcclient.InferInput(name, list(input_tensor.shape), - np_to_triton_dtype(input_tensor.dtype)) - t.set_data_from_numpy(input_tensor) - return t - - -class Preprocessor: - """Tokenize prompts. - - Args: - tritonserver_addr (str): the communication address of the inference - server - """ - - def __init__(self, tritonserver_addr: str): - self.tritonserver_addr = tritonserver_addr - self.model_name = 'preprocessing' - - def __call__(self, *args, **kwargs): - return self.infer(*args, **kwargs) - - def infer(self, prompts: Union[str, List[str]]) -> tuple: - """Tokenize the input prompts. - - Args: - prompts(str | List[str]): user's prompt, or a batch prompts - - Returns: - Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token - ids, ids' length and requested output length - """ - if isinstance(prompts, str): - input0 = [[prompts]] - elif isinstance(prompts, List): - input0 = [[prompt] for prompt in prompts] - else: - assert 0, f'str or List[str] prompts are expected but got ' \ - f'{type(prompts)}' - - input0_data = np.array(input0).astype(object) - inputs = [prepare_tensor('QUERY', input0_data)] - - with grpcclient.InferenceServerClient(self.tritonserver_addr) as \ - client: - result = client.infer(self.model_name, inputs) - output0 = result.as_numpy('INPUT_ID') - output1 = result.as_numpy('REQUEST_INPUT_LEN') - return output0, output1 - - -class Postprocessor: - """De-tokenize prompts. - - Args: - tritonserver_addr (str): the communication address of the inference - server - """ - - def __init__(self, tritonserver_addr: str): - self.tritonserver_addr = tritonserver_addr - - def __call__(self, *args, **kwargs): - return self.infer(*args, **kwargs) - - def infer(self, - output_ids: np.ndarray, - seqlen: np.ndarray, - skip_special_tokens: bool = True): - """De-tokenize tokens for text. - - Args: - output_ids(np.ndarray): tokens' id - seqlen(np.ndarray): sequence length - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - - Returns: - str: decoded tokens - """ - inputs = [ - prepare_tensor('TOKENS_BATCH', output_ids), - prepare_tensor('sequence_length', seqlen), - prepare_tensor('skip_special_tokens', skip_special_tokens) - ] - inputs[0].set_data_from_numpy(output_ids) - inputs[1].set_data_from_numpy(seqlen) - inputs[2].set_data_from_numpy(skip_special_tokens) - model_name = 'postprocessing' - with grpcclient.InferenceServerClient(self.tritonserver_addr) \ - as client: - result = client.infer(model_name, inputs) - output0 = result.as_numpy('OUTPUT') - return output0 diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py index 11d6e8a41..3800cd87d 100644 --- a/lmdeploy/turbomind/chat.py +++ b/lmdeploy/turbomind/chat.py @@ -2,11 +2,12 @@ import os import random +from lmdeploy.archs import get_model_arch from lmdeploy.messages import EngineGenerationConfig, TurbomindEngineConfig from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model from lmdeploy.serve.async_engine import deduce_a_name from lmdeploy.tokenizer import DetokenizeState -from lmdeploy.utils import _stop_words +from lmdeploy.utils import _get_and_verify_max_len, _stop_words log_level = 'ERROR' if os.getenv('TM_LOG_LEVEL') is None: @@ -91,10 +92,10 @@ def main(model_path: str, print('chat_template_config:\n', chat_template_config, sep='', flush=True) model = chat_template_config.chat_template - # engine - if session_len is None: - session_len = model.session_len + _, model_config = get_model_arch(model_path) + session_len = _get_and_verify_max_len(model_config, None) + # engine engine_cfg = TurbomindEngineConfig( max_batch_size=max_batch_size, model_name=model_name, diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 2371d008f..6c93d39c4 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os import os.path as osp -import re import shutil -from pathlib import Path import fire import torch @@ -20,12 +18,6 @@ SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', None] -def get_package_root_path(): - """Get lmdeploy root path.""" - import lmdeploy - return Path(lmdeploy.__file__).parent - - def get_input_model_registered_name(model_path: str, model_format: str): """Get the registered name of a model. The name will be used to access the INPUT_MODELS registry. @@ -52,68 +44,41 @@ def create_workspace(_path: str): print(f'remove workspace in directory {_path}') shutil.rmtree(_path) print(f'create workspace in directory {_path}') - os.makedirs(_path) - - -def copy_triton_model_templates(_path: str): - """copy triton model templates to the specified path. - - Args: - _path (str): the target path - Returns: - str: the path of the triton models - """ - - root = get_package_root_path() - dir_path = osp.join(root, 'serve', 'turbomind') - triton_models_path = osp.join(dir_path, 'triton_models') - dst_path = osp.join(_path, 'triton_models') - print(f'copy triton model templates from "{triton_models_path}" to ' - f'"{dst_path}"') - shutil.copytree(triton_models_path, dst_path, symlinks=True) - service_docker_up_file = osp.join(dir_path, 'service_docker_up.sh') - print(f'copy service_docker_up.sh from "{service_docker_up_file}" to ' - f'"{_path}"') - shutil.copy(osp.join(dir_path, 'service_docker_up.sh'), _path) - return dst_path + weight_path = osp.join(_path, 'triton_models', 'weights') + tokenizer_path = osp.join(_path, 'triton_models', 'tokenizer') + os.makedirs(weight_path) + os.makedirs(tokenizer_path) + return weight_path, tokenizer_path def copy_tokenizer(model_path: str, tokenizer_path: str, - triton_models_path: str, trust_remote_code: bool): + tm_tokenizer_path: str): """Copy tokenizer.""" + if tokenizer_path is not None: assert osp.exists(tokenizer_path), f'{tokenizer_path} does not exists.' - shutil.copy( - tokenizer_path, - osp.join(triton_models_path, - osp.join('tokenizer', osp.basename(tokenizer_path)))) + shutil.copy(tokenizer_path, + osp.join(tm_tokenizer_path, osp.basename(tokenizer_path))) else: from transformers import AutoTokenizer try: - _ = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - except Exception: - assert 0, ( - f'Failed to load tokenizer model from path {model_path}.' - 'please specify tokenizer path by --tokenizer-path') + _ = AutoTokenizer.from_pretrained(model_path, + trust_remote_code=True) + except Exception as e: + assert 0, f'{e}' # move tokenizer model to the target path candidate = ['tokenizer.model', 'qwen.tiktoken', 'merges.txt'] for name in candidate: tmp_path = osp.join(model_path, name) if osp.exists(tmp_path): - shutil.copy(tmp_path, - osp.join(triton_models_path, 'tokenizer', name)) - # move py/json files that are related to tokenizer to the target path + shutil.copy(tmp_path, osp.join(tm_tokenizer_path, name)) + # copy py/json files that are related to tokenizer to the target path for _file in os.listdir(model_path): if _file.endswith('.json') or _file.endswith('.py'): json_path = osp.join(model_path, _file) - shutil.copy(json_path, - osp.join(triton_models_path, 'tokenizer', _file)) - with get_package_root_path() as root_path: - shutil.copy(osp.join(root_path, 'tokenizer.py'), - osp.join(triton_models_path, 'tokenizer')) + shutil.copy(json_path, osp.join(tm_tokenizer_path, _file)) def get_output_model_registered_name_and_config(model_path: str, @@ -205,7 +170,6 @@ def main(model_name: str, tp: int = 1, quant_path: str = None, group_size: int = 0, - trust_remote_code: bool = False, revision: str = None, download_dir: str = None, **kwargs): @@ -226,8 +190,6 @@ def main(model_name: str, quant_path (str): Path of the quantized model, which can be None. group_size (int): a parameter used in AWQ to quantize fp16 weights to 4 bits - trust_remote_code (bool): Whether or not to allow for custom models - defined on the Hub in their own modeling files. Defaults to False revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -279,45 +241,20 @@ def main(model_name: str, cfg.model_name = model_name cfg.tensor_para_size = tp - create_workspace(dst_path) + tm_weight_path, tm_tokenizer_path = create_workspace(dst_path) - triton_models_path = copy_triton_model_templates(dst_path) + copy_tokenizer(model_path, tokenizer_path, tm_tokenizer_path) - copy_tokenizer(model_path, tokenizer_path, triton_models_path, - trust_remote_code) - - weight_path = osp.join(triton_models_path, 'weights') input_model = INPUT_MODELS.get(input_model_name)( model_path=model_path, tokenizer_path=tokenizer_path, ckpt_path=quant_path) output_model = OUTPUT_MODELS.get(output_model_name)( - input_model=input_model, cfg=cfg, to_file=True, out_dir=weight_path) + input_model=input_model, cfg=cfg, to_file=True, out_dir=tm_weight_path) print(f'turbomind model config: {output_model.cfg}') output_model.export() - # update `tensor_para_size` in `triton_models/interactive/config.pbtxt` - with open(osp.join(triton_models_path, 'interactive', 'config.pbtxt'), - 'a') as f: - param = \ - 'parameters {\n key: "tensor_para_size"\n value: {\n ' \ - 'string_value: ' + f'"{tp}"\n' + ' }\n}\n' + \ - 'parameters {\n key: "model_name"\n value: {\n ' \ - 'string_value: ' + f'"{model_name}"\n' + ' }\n}\n' - f.write(param) - - # pack model repository for triton inference server - pack_model_repository(dst_path) - - # update the value of $TP in `service_docker_up.sh` - file_path = osp.join(dst_path, 'service_docker_up.sh') - with open(file_path, 'r') as f: - content = f.read() - content = re.sub('TP=1', f'TP={tp}', content) - with open(file_path, 'w') as f: - f.write(content) - if __name__ == '__main__': fire.Fire(main) diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py index cbbb2ce41..5ba4badb5 100644 --- a/lmdeploy/turbomind/deploy/source_model/__init__.py +++ b/lmdeploy/turbomind/deploy/source_model/__init__.py @@ -4,6 +4,7 @@ from .deepseek_vl import DeepSeekVLModel # noqa: F401 from .deepseek_vl_awq import DeepSeekVLAwqModel # noqa: F401 from .glm4 import Glm4Model # noqa: F401 +from .glm4_awq import Glm4AwqModel # noqa: F401 from .internlm2 import InternLM2AwqModel, InternLM2Model # noqa: F401 from .internvl import InternVLModel # noqa: F401 from .llama import LlamaModel # noqa: F401 diff --git a/lmdeploy/turbomind/deploy/source_model/glm4_awq.py b/lmdeploy/turbomind/deploy/source_model/glm4_awq.py new file mode 100644 index 000000000..f27a5f6ce --- /dev/null +++ b/lmdeploy/turbomind/deploy/source_model/glm4_awq.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp + +import torch + +from .base import INPUT_MODELS +from .glm4 import Glm4Model, Glm4Reader + + +class Glm4AwqReader(Glm4Reader): + """Glm4AwqReader.""" + + attn_layer_patten = r'transformer.encoder.layers.([0-9]+).' + tok_embeddings_key = 'transformer.embedding.word_embeddings.weight' + norm_weight_key = 'transformer.encoder.final_layernorm.weight' + output_weight_key = 'transformer.output_layer.weight' + + def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, + model_cfg: dict): + super().__init__(new_params, unused_params, last_bin, model_cfg) + + def _attn(self, i: int, kind: str, size_dim: int, dim: int = 0): + """Get q, k, v, o kind for layer i.""" + qkv = self.params[f'transformer.encoder.layers.{i}' + f'.self_attention.query_key_value.{kind}'] + attn_head_num = self.model_cfg['num_attention_heads'] + kv_head_num = attn_head_num + if self.model_cfg.get('multi_query_attention', False): + kv_head_num = self.model_cfg['multi_query_group_num'] + HEAD_DIM = int(qkv.shape[size_dim] / (attn_head_num + kv_head_num * 2)) + q, k, v = torch.split(qkv, [ + attn_head_num * HEAD_DIM, kv_head_num * HEAD_DIM, + kv_head_num * HEAD_DIM + ], + dim=size_dim) + o = self.params.get( + f'transformer.encoder.layers.{i}.self_attention.dense.{kind}', + None) + if o is None: # handle the case when qkv has bias but o doesn't + o = torch.zeros_like(q) + return q, k, v, o + + def attn(self, i: int): + """Get q, k, v, o qweight for layer i.""" + return self._attn(i, 'qweight', -1, -1) + + def attn_zero(self, i: int): + """Get q, k, v, o qzeros for layer i.""" + return self._attn(i, 'qzeros', -1, -1) + + def attn_scale(self, i: int): + """Get q, k, v, o scales for layer i.""" + return self._attn(i, 'scales', -1, -1) + + def attn_bias(self, i: int): + """Get q, k, v, o bias for layer i.""" + return self._attn(i, 'bias', -1, 0) + + def _ffn(self, i: int, kind: str): + """Get ffn kind for layer i.""" + up_and_gate = self.params[ + f'transformer.encoder.layers.{i}.mlp.dense_h_to_4h.{kind}'] + up, gate = up_and_gate.chunk(2, dim=-1) + down = self.params[ + f'transformer.encoder.layers.{i}.mlp.dense_4h_to_h.{kind}'] + + return (up, down, gate) + + def ffn(self, i: int): + """Get ffn weight for layer i.""" + return self._ffn(i, 'qweight') + + def ffn_zero(self, i: int): + """Get ffn zero point for layer i.""" + return self._ffn(i, 'qzeros') + + def ffn_scale(self, i: int): + """Get ffn scale for layer i.""" + return self._ffn(i, 'scales') + + +@INPUT_MODELS.register_module(name='glm4-awq') +class Glm4AwqModel(Glm4Model): + """Glm2/3/4 model in hf format.""" + + Reader = Glm4AwqReader + + def __init__(self, model_path: str, tokenizer_path: str, **kwargs): + super().__init__(model_path, tokenizer_path, **kwargs) + config_path = osp.join(self.model_path, 'config.json') + with open(config_path) as f: + self.config = json.load(f) diff --git a/lmdeploy/turbomind/deploy/source_model/internvl.py b/lmdeploy/turbomind/deploy/source_model/internvl.py index 7f5a7c6ff..6b4dc2beb 100644 --- a/lmdeploy/turbomind/deploy/source_model/internvl.py +++ b/lmdeploy/turbomind/deploy/source_model/internvl.py @@ -94,6 +94,11 @@ class InternVLAwqReader(LlamaAwqReader): norm_weight_key = 'language_model.model.norm.weight' output_weight_key = 'language_model.lm_head.weight' + def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, + model_cfg: dict): + model_cfg = model_cfg.get('llm_config') + super().__init__(new_params, unused_params, last_bin, model_cfg) + class InternVL2AwqReader(InternLM2AwqReader): """InternVLReader for InternLM2 model.""" diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py index bf1307c20..ca1af9048 100644 --- a/lmdeploy/turbomind/deploy/source_model/llama.py +++ b/lmdeploy/turbomind/deploy/source_model/llama.py @@ -203,17 +203,39 @@ def model_info(self): rope_scaling = model_arg.get('rope_scaling', None) scaling_factor = 0.0 use_dynamic_ntk = 0 + scaling_type = '' + low_freq_factor = 1.0 + high_freq_factor = 1.0 + original_max_position_embeddings = 0 if isinstance(rope_scaling, dict): - scaling_type = model_arg['rope_scaling'].get('type', '') + llama2_scaling_type = model_arg['rope_scaling'].get('type', '') + llama3_scaling_type = model_arg['rope_scaling'].get( + 'rope_type', '') scaling_factor = model_arg['rope_scaling'].get('factor', '') + low_freq_factor = model_arg['rope_scaling'].get( + 'low_freq_factor', 1.0) + high_freq_factor = model_arg['rope_scaling'].get( + 'high_freq_factor', 1.0) + original_max_position_embeddings = model_arg[ + 'rope_scaling'].get('original_max_position_embeddings', 0) + if llama2_scaling_type and llama3_scaling_type: + raise ValueError( + f'Ambiguous rope_scaling in config: {model_arg}') + scaling_type = llama2_scaling_type if llama2_scaling_type \ + else llama3_scaling_type if scaling_type == 'dynamic': use_dynamic_ntk = 1 - return dict(num_layer=num_layer, - norm_eps=norm_eps, - attn_head_num=attn_head_num, - kv_head_num=kv_head_num, - rope_theta=rope_theta, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk=use_dynamic_ntk, - rope_scaling_factor=scaling_factor) + return dict( + num_layer=num_layer, + norm_eps=norm_eps, + attn_head_num=attn_head_num, + kv_head_num=kv_head_num, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + original_max_position_embeddings=original_max_position_embeddings, + use_dynamic_ntk=use_dynamic_ntk, + rope_scaling_type=scaling_type, + rope_scaling_factor=scaling_factor, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor) diff --git a/lmdeploy/turbomind/deploy/source_model/xcomposer2.py b/lmdeploy/turbomind/deploy/source_model/xcomposer2.py index d8c94b361..a57762302 100644 --- a/lmdeploy/turbomind/deploy/source_model/xcomposer2.py +++ b/lmdeploy/turbomind/deploy/source_model/xcomposer2.py @@ -1,7 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import json -import os.path as osp - import torch from .base import INPUT_MODELS @@ -111,10 +108,9 @@ def _lora_cfg_4khd_7b(self, model_info: dict): def model_info(self): out = super().model_info() - params_path = osp.join(self.model_path, 'config.json') - with open(params_path) as f: - model_arg = json.load(f) - if model_arg['max_length'] == 16384: + from lmdeploy.vl.model.xcomposer2 import ModelType, get_xcomposer_type + model_type, _ = get_xcomposer_type(self.model_path) + if model_type == ModelType.XCOMPOSER2_4KHD: out.update(self._lora_cfg_4khd_7b(out)) else: out.update(self._lora_cfg_7b()) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 525cd2222..be0ad4a2d 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -66,8 +66,12 @@ class TurbomindModelConfig: use_context_fmha: int = 1 quant_policy: int = 0 max_position_embeddings: int = 0 + original_max_position_embeddings: int = 0 + rope_scaling_type: str = '' rope_scaling_factor: float = 0.0 use_dynamic_ntk: int = 0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 1.0 use_logn_attn: int = 0 lora_policy: str = '' lora_r: int = 0 diff --git a/lmdeploy/turbomind/deploy/target_model/plora_w4.py b/lmdeploy/turbomind/deploy/target_model/plora_w4.py index 02d00df6a..55f8bf48d 100644 --- a/lmdeploy/turbomind/deploy/target_model/plora_w4.py +++ b/lmdeploy/turbomind/deploy/target_model/plora_w4.py @@ -82,6 +82,8 @@ def export_transformer_block(self, bin: BaseReader, i: int): w1_qw, w1_sz = convert_s4(w1_qw, w1_qz, w1_s, group_size) w3_qw, w3_sz = convert_s4(w3_qw, w3_qz, w3_s, group_size) + w1_qw = tp_m_s4(w1_qw, tp) + w3_qw = tp_m_s4(w3_qw, tp) self.save_split(w1_qw, f'layers.{i}.feed_forward.w1.qweight', -1) self.save_split(w1_sz, f'layers.{i}.feed_forward.w1.scales_zeros', -1) self.save_split(w3_qw, f'layers.{i}.feed_forward.w3.qweight', -1) diff --git a/lmdeploy/turbomind/deploy/target_model/w4.py b/lmdeploy/turbomind/deploy/target_model/w4.py index 282c7df60..5cb944a2c 100644 --- a/lmdeploy/turbomind/deploy/target_model/w4.py +++ b/lmdeploy/turbomind/deploy/target_model/w4.py @@ -110,12 +110,13 @@ def export_transformer_block(self, bin: BaseReader, i: int): q_qz, k_qz, v_qz, o_qz = get_cuda_tensor(bin.attn_zero(i)) q_s, k_s, v_s, o_s = get_cuda_tensor(bin.attn_scale(i)) - q_qw = transpose_qk_s4(q_qw, group_size) - k_qw = transpose_qk_s4(k_qw, group_size) - q_qz = transpose_qk_s4(q_qz, group_size) - k_qz = transpose_qk_s4(k_qz, group_size) - q_s = permute(q_s, size_per_head) - k_s = permute(k_s, size_per_head) + if self.permute_qk: + q_qw = transpose_qk_s4(q_qw, group_size) + k_qw = transpose_qk_s4(k_qw, group_size) + q_qz = transpose_qk_s4(q_qz, group_size) + k_qz = transpose_qk_s4(k_qz, group_size) + q_s = permute(q_s, size_per_head) + k_s = permute(k_s, size_per_head) qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2) qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2) @@ -132,8 +133,12 @@ def export_transformer_block(self, bin: BaseReader, i: int): q_b, k_b, v_b, o_b = get_cuda_tensor(bin.attn_bias(i)) if q_b is not None: - q_b = permute(q_b, size_per_head) - k_b = permute(k_b, size_per_head) + if self.permute_qk: + q_b = permute(q_b, size_per_head) + k_b = permute(k_b, size_per_head) + else: + q_b = q_b[None, :] + k_b = k_b[None, :] qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1) self.save_split(qkv_b, f'layers.{i}.attention.w_qkv.bias', -1) self.save_split(o_b, f'layers.{i}.attention.wo.bias', copy=True) diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 5f0d2b83a..8ebb93fdf 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -37,7 +37,8 @@ MGMLlamaForCausalLM='llama', MiniGeminiLlamaForCausalLM='llama', # chatglm2/3, glm4 - ChatGLMModel='glm4') + ChatGLMModel='glm4', + ChatGLMForConditionalGeneration='glm4') def is_supported(model_path: str): @@ -87,7 +88,7 @@ def _is_head_dim_128(cfg): # qwen2 0.5b size_per_head is 64, which hasn't been supported # by turbomind yet support_by_turbomind = _is_head_dim_128(cfg) - elif arch == 'ChatGLMModel': + elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'): # chatglm1/2/3 is not working yet support_by_turbomind = cfg.num_layers == 40 if getattr(cfg, 'vision_config', None) is not None: diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index b11ef6c1c..4ae170cda 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -2,9 +2,9 @@ import asyncio import os.path as osp import sys +from concurrent.futures import ThreadPoolExecutor from configparser import ConfigParser from queue import LifoQueue, Queue -from threading import Thread from typing import Dict, Iterable, List, Optional, Union import numpy as np @@ -162,13 +162,12 @@ def _create_weight_func(device_id): rank = self.node_id * self.gpu_count + device_id model_comm.create_shared_weights(device_id, rank) - threads = [] - for device_id in range(self.gpu_count): - t = Thread(target=_create_weight_func, args=(device_id, )) - t.start() - threads.append(t) - for t in threads: - t.join() + with ThreadPoolExecutor(max_workers=self.gpu_count) as executor: + futures = [] + for device_id in range(self.gpu_count): + futures.append(executor.submit(_create_weight_func, device_id)) + for future in futures: + future.result() def _get_model_params(self, model_comm, tm_params): """Get turbomind model params when loading from hf.""" @@ -179,13 +178,12 @@ def _get_params(device_id, que): que.put(out) que = Queue() - threads = [] - for device_id in range(self.gpu_count): - t = Thread(target=_get_params, args=(device_id, que)) - t.start() - threads.append(t) - for t in threads: - t.join() + with ThreadPoolExecutor(max_workers=self.gpu_count) as executor: + futures = [] + for device_id in range(self.gpu_count): + futures.append(executor.submit(_get_params, device_id, que)) + for future in futures: + future.result() for _ in range(self.gpu_count): tensor_map = que.get() @@ -380,18 +378,19 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0): # create model instances model_insts = [None] * self.gpu_count - threads = [] - for device_id in range(self.gpu_count): - t = Thread(target=self._create_model_instance, - args=(device_id, model_insts)) - t.start() - threads.append(t) - for t in threads: - t.join() + with ThreadPoolExecutor(max_workers=self.gpu_count) as executor: + futures = [] + for device_id in range(self.gpu_count): + futures.append( + executor.submit(self._create_model_instance, device_id, + model_insts)) + for future in futures: + future.result() self.model_insts = model_insts self.que = Queue() - self.threads = [None] * self.gpu_count + self.executor: ThreadPoolExecutor = None + self.futures = [None] * self.gpu_count def _create_model_instance(self, device_id, model_insts): rank = self.node_id * self.gpu_count + device_id @@ -411,12 +410,10 @@ def _func(device_id, enque_output): if enque_output: self.que.put((True, output)) + self.executor = ThreadPoolExecutor(self.gpu_count) for device_id in range(self.gpu_count): - t = Thread(target=_func, - args=(device_id, device_id == 0), - daemon=True) - t.start() - self.threads[device_id] = t + f = self.executor.submit(_func, device_id, device_id == 0) + self.futures[device_id] = f def _async_forward_callback(self, result, ctx, que: LifoQueue): que.put((False, result)) @@ -430,12 +427,10 @@ def _func(device_id, enque_output): if enque_output: que.put((True, output)) + self.executor = ThreadPoolExecutor(self.gpu_count) for device_id in range(self.gpu_count): - t = Thread(target=_func, - args=(device_id, device_id == 0), - daemon=True) - t.start() - self.threads[device_id] = t + f = self.executor.submit(_func, device_id, device_id == 0) + self.futures[device_id] = f def _update_generation_config(self, config: EngineGenerationConfig, **kwargs: dict): @@ -772,8 +767,9 @@ async def async_stream_infer(self, yield outputs if finish: - for t in self.threads: - t.join() + for f in self.futures: + f.result() + self.executor.shutdown() break if stream_output and not stop: @@ -884,8 +880,9 @@ def stream_infer(self, yield outputs if finish: - for t in self.threads: - t.join() + for f in self.futures: + f.result() + self.executor.shutdown() while self.que.qsize() > 0: self.que.get() break diff --git a/lmdeploy/version.py b/lmdeploy/version.py index bcb5916b2..28e376f14 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple -__version__ = '0.5.0' +__version__ = '0.5.2.post1' short_version = __version__ diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py index 67e49cffa..1d4393591 100644 --- a/lmdeploy/vl/model/xcomposer2.py +++ b/lmdeploy/vl/model/xcomposer2.py @@ -1,16 +1,43 @@ # Copyright (c) OpenMMLab. All rights reserved. +import enum import warnings from contextlib import contextmanager -from typing import List +from typing import Any, List, Tuple import torch from PIL.Image import Image from transformers import AutoConfig, AutoModelForCausalLM +from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS, VisonModel -from lmdeploy.vl.model.utils import (add_device_hook, add_sys_path, - disable_logging, rewrite_ctx) +from lmdeploy.vl.model.utils import (add_device_hook, disable_logging, + rewrite_ctx) + +logger = get_logger('lmdeploy') + + +class ModelType(enum.Enum): + """Request type.""" + XCOMPOSER2 = enum.auto() + XCOMPOSER2_4KHD = enum.auto() + XCOMPOSER2D5 = enum.auto() + + +def get_xcomposer_type(model_path: str) -> Tuple[ModelType, Any]: + """get xcomposer type.""" + from transformers.dynamic_module_utils import get_class_from_dynamic_module + match_modules = { + 'ixc_utils.Image_transform': ModelType.XCOMPOSER2D5, + 'ixc_utils.HD_transform': ModelType.XCOMPOSER2_4KHD + } + for key, value in match_modules.items(): + try: + module = get_class_from_dynamic_module(key, model_path) + return value, module + except Exception: + pass + return ModelType.XCOMPOSER2, None def _CLIPVisionModel_from_pretrained(vision_tower_name): @@ -66,15 +93,16 @@ def build_model(self): self.vl_model = model # additional components. - with add_sys_path(self.model_path): - try: - # internlm-xcomposer2-4khd-7b - from ixc_utils import HD_transform - self.HD_transform = HD_transform - self._forward_func = self._forward_4khd_7b - except: # noqa - # internlm-xcomposer2-7b - self._forward_func = self._forward_7b + model_type, module = get_xcomposer_type(self.model_path) + logger.info(f'matching type of {model_type}') + if model_type == ModelType.XCOMPOSER2D5: + self.HD_transform = module + self._forward_func = self._forward_2d5 + elif model_type == ModelType.XCOMPOSER2_4KHD: + self.HD_transform = module + self._forward_func = self._forward_4khd_7b + else: + self._forward_func = self._forward_7b from accelerate.utils import get_balanced_memory, infer_auto_device_map max_memory = get_balanced_memory( @@ -108,6 +136,22 @@ def build_model(self): self.model = model.eval() + def _forward_2d5(self, images: List[Image]) -> List[torch.Tensor]: + """internlm-xcomposer2d5-7b vit forward.""" + outputs = [x.convert('RGB') for x in images] + hd_num = 6 if len(images) > 1 else 24 + outputs = [self.HD_transform(x, hd_num=hd_num) for x in outputs] + outputs = [ + self.model.vis_processor(x).unsqueeze(0).to(dtype=torch.half) + for x in outputs + ] + embeds, split = self.model.vit(outputs, self.model.plora_glb_GN, + self.model.plora_sub_GN) + embeds = self.model.vision_proj(embeds) + embeds = torch.split(embeds, split, dim=1) + embeds = [x.squeeze() for x in embeds] + return embeds + def _forward_7b(self, images: List[Image]) -> List[torch.Tensor]: """internlm-xcomposer2-7b vit forward.""" outputs = [x.convert('RGB') for x in images] diff --git a/lmdeploy/vl/templates.py b/lmdeploy/vl/templates.py index 1206d9285..a75938376 100644 --- a/lmdeploy/vl/templates.py +++ b/lmdeploy/vl/templates.py @@ -250,6 +250,8 @@ class InternLMXComposer2TemplateWrapper(VLChatTemplateWrapper): """InternLM-XComposer2 chat template.""" def append_image_token(self, prompt, num_images: int): + logger.warning(f'auto append {IMAGE_TOKEN} at the beginning, ' + 'the user can manually insert the token to prompt') return ' '.join([IMAGE_TOKEN] * num_images) + prompt diff --git a/lmdeploy/vl/tools/__init__.py b/lmdeploy/vl/tools/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/lmdeploy/vl/tools/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/vl/tools/merge_xcomposer2d5_task.py b/lmdeploy/vl/tools/merge_xcomposer2d5_task.py new file mode 100644 index 000000000..34625ad47 --- /dev/null +++ b/lmdeploy/vl/tools/merge_xcomposer2d5_task.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil + +import fire +import torch +from tqdm import tqdm + +from lmdeploy.vl.model.builder import vl_model_with_tokenizer + + +def main(src_path: str, dst_path: str, task: str): + """merge internlm-xcomposer2d5-7b LoRA model weights. + + Args: + src_path (str): the source model path of internlm-xcomposer2d5-7b + dst_path (str): the target model path of merged model + task (str): the task of source model, should choose from + ['web', 'write'] + """ + if os.path.exists(dst_path): + shutil.rmtree(dst_path) + + to_merged = dict(web=['lora_web'], write=['lora_sft', 'lora_dpo']) + keys = to_merged[task] + + # load model + model, _, tokenizer = vl_model_with_tokenizer(src_path) + + # merge lora weight to base model + @torch.inference_mode + def _merge(module: torch.nn.Module, lora_weights): + # merge lora weight first to reduce precision loss + mw = None + for wa, wb in lora_weights: + if mw is None: + mw = (wb.float() @ wa.float()) + else: + mw += (wb.float() @ wa.float()) + ow = module.weight + mw += ow.float() + module.weight.data = mw.half() + + def _extract_lora(module: torch.nn.Module, keys: str): + lora_weights = [] + for key in keys: + lora_a_key = f'{key}_A' + lora_b_key = f'{key}_B' + wa = getattr(module, lora_a_key).weight + wb = getattr(module, lora_b_key).weight + lora_weights.append((wa, wb)) + return lora_weights + + for _, module in tqdm(model.named_modules()): + if type(module).__name__ == 'PLoRA': + lora_weights = _extract_lora(module, keys) + _merge(module, lora_weights) + + # save model + model.save_pretrained(dst_path) + tokenizer.save_pretrained(dst_path) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 339debf03..9564e7e89 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -13,8 +13,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.2.2,>=2.0.0 -torchvision<=0.17.2,>=0.15.0 +torch<=2.3.1,>=2.0.0 +torchvision<=0.18.1,>=0.15.0 transformers -triton>=2.1.0,<=2.2.0; sys_platform == "linux" +triton>=2.1.0,<=2.3.1; sys_platform == "linux" uvicorn diff --git a/src/turbomind/kernels/attention/array_ops.h b/src/turbomind/kernels/attention/array_ops.h index 005dab5a5..3c92dfc79 100644 --- a/src/turbomind/kernels/attention/array_ops.h +++ b/src/turbomind/kernels/attention/array_ops.h @@ -242,7 +242,14 @@ struct FastRoPE { Array inv_freq_; bool is_valid_; - __device__ FastRoPE(int idx, D dims, float base, float ti_scale, std::integral_constant) + __device__ FastRoPE(int idx, + D dims, + float base, + float ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, + std::integral_constant) { is_valid_ = idx < dims; /// TODO: Take this away from device code @@ -251,6 +258,25 @@ struct FastRoPE { for (int i = 0; i < N; i += 2) { inv_freq_[i / 2] = ti_scale * exp2f((idx + i) * scale_factor); } + // clang-format off + /* The [llama3 rope](https://github.com/huggingface/transformers/blob/5f4ee98a7ade33e1c54fdd6181d04ee7b426b392/src/transformers/modeling_rope_utils.py#L298) + * used by llama3.1 equals to the following equation, given the precommuted parameters as: + ```C++ + inv_scaling_factor = 1 / factor; + inv_diff_freq_factor = 1 / (high_freq_factor - low_freq_factor); + alpha = old_context_len / (2 * PI) * inv_diff_freq_factor; + beta = low_freq_factor * inv_diff_freq_factor + ``` + */ + // clang-format on + if (llama3_inv_scaling_factor) { + PRAGMA_UNROLL + for (int i = 0; i < N; i += 2) { + auto freq = inv_freq_[i / 2]; + auto smooth = fmaxf(0.f, fminf(1.f, llama3_alpha * freq - llama3_beta)); + inv_freq_[i / 2] = (1 - smooth) * freq * llama3_inv_scaling_factor + smooth * freq; + } + } } template diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h index a17ae146e..e244de04e 100644 --- a/src/turbomind/kernels/attention/attention_params.h +++ b/src/turbomind/kernels/attention/attention_params.h @@ -60,6 +60,10 @@ struct AttentionParams { float rotary_embedding_base; int max_position_embeddings; float rope_ti_scale; // used for linear RoPE scaling + // the following 3 parameters are used by llama3 + float llama3_inv_scaling_factor; + float llama3_alpha; + float llama3_beta; // log(n) attention bool use_logn_attn; diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h index 2b6e69556..1d6466ba8 100644 --- a/src/turbomind/kernels/attention/attention_universal.h +++ b/src/turbomind/kernels/attention/attention_universal.h @@ -232,6 +232,9 @@ struct AttentionUniversal { params.rotary_embedding_dim, rope_base, params.rope_ti_scale, + params.llama3_inv_scaling_factor, + params.llama3_alpha, + params.llama3_beta, std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 484f1e7d6..984a287ff 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -23,6 +23,9 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -121,7 +124,14 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks, PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, rope_dim, base, rope_ti_scale, std::integral_constant{}); + FastRoPE rope(di, + rope_dim, + base, + rope_ti_scale, + llama3_inv_scaling_factor, + llama3_alpha, + llama3_beta, + std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx; // sequence local @@ -194,6 +204,9 @@ void invokeProcessKV_v2(char** blocks, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_1_alpha, + float llama3_1_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -232,6 +245,9 @@ void invokeProcessKV_v2(char** blocks, rope_base, rope_dim, rope_ti_scale, + llama3_inv_scaling_factor, + llama3_1_alpha, + llama3_1_beta, stride_b, stride_c, stride_h, @@ -263,6 +279,9 @@ void invokeProcessKV_v2(char** blocks, const float* rope_base, \ int rope_dim, \ float rope_ti_scale, \ + float llama3_inv_scaling_factor, \ + float llama3_1_alpha, \ + float llama3_1_beta, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ @@ -290,6 +309,9 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -371,7 +393,14 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k, PRAGMA_UNROLL for (int c = 0; c < ITER_C; ++c) { const int di = offset.x + c * Map::kDeltaC; - FastRoPE rope(di, rope_dim, base, rope_ti_scale, std::integral_constant{}); + FastRoPE rope(di, + rope_dim, + base, + rope_ti_scale, + llama3_inv_scaling_factor, + llama3_alpha, + llama3_beta, + std::integral_constant{}); PRAGMA_UNROLL for (int s = 0; s < ITER_S; ++s) { const int ti = offset.y + s * Map::kDeltaS + token_idx; // sequence local @@ -405,6 +434,9 @@ void invokeFlattenKV_v2(T* k, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -440,6 +472,9 @@ void invokeFlattenKV_v2(T* k, rope_base, rope_dim, rope_ti_scale, + llama3_inv_scaling_factor, + llama3_alpha, + llama3_beta, stride_b, stride_c, stride_h, @@ -468,6 +503,9 @@ void invokeFlattenKV_v2(T* k, const float* rope_base, \ int rope_dim, \ float rope_ti_scale, \ + float llama3_inv_scaling_factor, \ + float llama3_alpha, \ + float llama3_beta, \ int64_t stride_b, \ int64_t stride_c, \ int64_t stride_h, \ diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h index 695f88993..eced5d64e 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h @@ -19,6 +19,9 @@ void invokeProcessKV_v2(char** blocks, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -46,6 +49,9 @@ void invokeProcessKV_v2_(const AttentionParams& params) params.rope_theta, params.rotary_embedding_dim, params.rope_ti_scale, + params.llama3_inv_scaling_factor, + params.llama3_alpha, + params.llama3_beta, 0, // stride b params.stride / params.size_per_head, // stride c 1, // stride h @@ -69,6 +75,9 @@ void invokeFlattenKV_v2(T* k, const float* rope_base, int rope_dim, float rope_ti_scale, + float llama3_inv_scaling_factor, + float llama3_alpha, + float llama3_beta, int64_t stride_b, int64_t stride_c, int64_t stride_h, @@ -95,6 +104,9 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len) nullptr, // params.rope_theta, params.rotary_embedding_dim, params.rope_ti_scale, + params.llama3_inv_scaling_factor, + params.llama3_alpha, + params.llama3_beta, 0, 1, 2 * sum_k_len, diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index e8f519f69..4496b8b4a 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -149,6 +149,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, nullptr, rope_dim, 1., + 0., + 1.0, + 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -175,6 +178,9 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, nullptr, rope_dim, 1., + 0., + 1.0, + 1.0, 2 * head_num * seq_len, 0, seq_len, @@ -531,6 +537,9 @@ int test_attention() nullptr, // DECODING ? nullptr : params.rope_theta, kRoPEDim, 1., + 0., + 1.0, + 1.0, KvHeadNum * kContextLen, 0, kContextLen, diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 4c9b177de..2d1cdd107 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -10,12 +10,16 @@ namespace turbomind { struct LlamaAttentionParams { - int rotary_embedding_dim; - float rotary_embedding_base; - int max_position_embeddings; - float rope_scaling_factor; - bool use_dynamic_ntk; - bool use_logn_attn; + int rotary_embedding_dim; + float rotary_embedding_base; + int max_position_embeddings; + std::string rope_scaling_type; + int original_max_position_embeddings; + float rope_scaling_factor; + float low_freq_factor; + float high_freq_factor; + bool use_dynamic_ntk; + bool use_logn_attn; }; struct EngineParams { diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 81b8e70d7..6a068f26f 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -239,9 +239,16 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa params.rotary_embedding_base = params_.rotary_embedding_base; params.max_position_embeddings = params_.max_position_embeddings; params.rope_ti_scale = 1.f; - if (!params_.use_dynamic_ntk && params_.rope_scaling_factor) { + if (params_.rope_scaling_type == "linear") { params.rope_ti_scale /= params_.rope_scaling_factor; } + if (params_.rope_scaling_type == "llama3") { + const double PI = 3.14159265358979323846; + float inv_diff_freq_factor = 1.0 / (params_.high_freq_factor - params_.low_freq_factor); + params.llama3_inv_scaling_factor = 1.0 / params_.rope_scaling_factor; + params.llama3_alpha = params_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor; + params.llama3_beta = params_.low_freq_factor * inv_diff_freq_factor; + } params.use_logn_attn = params_.use_logn_attn; diff --git a/src/turbomind/triton_backend/CMakeLists.txt b/src/turbomind/triton_backend/CMakeLists.txt index 7aa96e6ca..4311d9d9b 100644 --- a/src/turbomind/triton_backend/CMakeLists.txt +++ b/src/turbomind/triton_backend/CMakeLists.txt @@ -34,250 +34,3 @@ set_property(TARGET TransformerTritonBackend PROPERTY POSITION_INDEPENDENT_CODE install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR}) add_subdirectory(llama) - -# Needn't build triton backend on windows -if (MSVC) - return () -endif() - -# -# Options -# -option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON) -option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) - -set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes") -set(TRITON_PYTORCH_LIB_PATHS "" CACHE PATH "Paths to Torch libraries") - -set(TRITON_BACKEND_REPO_TAG "r22.12" CACHE STRING "Tag for triton-inference-server/backend repo") -set(TRITON_CORE_REPO_TAG "r22.12" CACHE STRING "Tag for triton-inference-server/core repo") -set(TRITON_COMMON_REPO_TAG "r22.12" CACHE STRING "Tag for triton-inference-server/common repo") - -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release) -endif() - -set(USE_TRITONSERVER_DATATYPE "ON") -message("-- Enable USE_TRITONSERVER_DATATYPE") - -# -# Dependencies -# -# FetchContent's composability isn't very good. We must include the -# transitive closure of all repos so that we can override the tag. -# -include(FetchContent) - -FetchContent_Declare( - repo-common - GIT_REPOSITORY https://github.com/triton-inference-server/common.git - GIT_TAG ${TRITON_COMMON_REPO_TAG} - GIT_SHALLOW ON -) -FetchContent_Declare( - repo-core - GIT_REPOSITORY https://github.com/triton-inference-server/core.git - GIT_TAG ${TRITON_CORE_REPO_TAG} - GIT_SHALLOW ON -) -FetchContent_Declare( - repo-backend - GIT_REPOSITORY https://github.com/triton-inference-server/backend.git - GIT_TAG ${TRITON_BACKEND_REPO_TAG} - GIT_SHALLOW ON -) -FetchContent_MakeAvailable(repo-common repo-core repo-backend) - -# -# CUDA -# -if(${TRITON_ENABLE_GPU}) - find_package(CUDAToolkit REQUIRED) -endif() # TRITON_ENABLE_GPU - -# -# Shared library implementing the Triton Backend API -# -configure_file(libtriton_fastertransformer.ldscript libtriton_fastertransformer.ldscript COPYONLY) - -add_library( - triton-turbomind-backend SHARED - libfastertransformer.cc -) - -add_library( - TritonTurboMindBackend::triton-turbomind-backend ALIAS triton-turbomind-backend -) - -find_package(CUDAToolkit REQUIRED) -find_package(CUDA 10.1 REQUIRED) -if (${CUDA_VERSION} GREATER_EQUAL 11.0) - message(STATUS "Add DCUDA11_MODE") - add_definitions("-DCUDA11_MODE") -endif() - -set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) - -target_compile_definitions(triton-turbomind-backend PUBLIC - USE_TRITONSERVER_DATATYPE) - -if (BUILD_MULTI_GPU) - target_compile_definitions(triton-turbomind-backend PUBLIC - BUILD_MULTI_GPU) -endif () - -target_include_directories( - triton-turbomind-backend - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/src - ${TRITON_PYTORCH_INCLUDE_PATHS} - ${Python3_INCLUDE_DIRS} - ${repo-ft_SOURCE_DIR} - ${repo-ft_SOURCE_DIR}/3rdparty/cutlass/include - ${repo-core_SOURCE_DIR}/include - ) - -target_link_directories( - triton-turbomind-backend - PRIVATE - ${CUDA_PATH}/lib64 - ) - -target_compile_features(triton-turbomind-backend PRIVATE cxx_std_14) - -target_compile_options( - triton-turbomind-backend PRIVATE - $<$,$,$>: - -Wall -Wextra -Wno-unused-parameter -Wno-type-limits >#-Werror> -) - -if(${TRITON_ENABLE_GPU}) - target_compile_definitions( - triton-turbomind-backend - PRIVATE TRITON_ENABLE_GPU=1 - ) -endif() # TRITON_ENABLE_GPU - -set_target_properties( - triton-turbomind-backend - PROPERTIES - POSITION_INDEPENDENT_CODE ON - OUTPUT_NAME triton_turbomind - SKIP_BUILD_RPATH TRUE - BUILD_WITH_INSTALL_RPATH TRUE - INSTALL_RPATH_USE_LINK_PATH FALSE - INSTALL_RPATH "$\{ORIGIN\}" - LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastertransformer.ldscript - LINK_FLAGS "-Wl,--no-as-needed,--version-script ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastertransformer.ldscript" -) - -# Need to turn off unused-but-set-variable due to Torchvision -# Need to turn off unknown-pragmas due to ATen OpenMP -set_target_properties( - triton-turbomind-backend - PROPERTIES COMPILE_FLAGS - "-Wno-unknown-pragmas -Wno-unused-but-set-variable" -) - -set(TRITON_PYTORCH_LDFLAGS "") -FOREACH(p ${TRITON_PYTORCH_LIB_PATHS}) - set(TRITON_PYTORCH_LDFLAGS ${TRITON_PYTORCH_LDFLAGS} "-L${p}") -ENDFOREACH(p) - -target_link_libraries( - triton-turbomind-backend - PRIVATE - triton-core-serverapi # from repo-core - triton-core-backendapi # from repo-core - triton-core-serverstub # from repo-core - triton-backend-utils # from repo-backend - transformer-shared # from repo-ft - ${TRITON_PYTORCH_LDFLAGS} - -lcublas - -lcublasLt - -lcudart - -lcurand -) - -if (BUILD_MULTI_GPU) - target_compile_definitions( - triton-turbomind-backend - PUBLIC - BUILD_MULTI_GPU - ) - target_include_directories( - triton-turbomind-backend - PRIVATE - ${MPI_INCLUDE_PATH} - ) - target_link_directories( - triton-turbomind-backend - PRIVATE - ${MPI_Libraries} - /usr/local/mpi/lib - ) - target_link_libraries( - triton-turbomind-backend - PRIVATE - ${NCCL_LIBRARIES} - ${MPI_LIBRARIES} - ) -endif() - -if(${TRITON_ENABLE_GPU}) - target_link_libraries( - triton-turbomind-backend - PRIVATE - CUDA::cudart - ) -endif() # TRITON_ENABLE_GPU - -# -# Install -# -include(GNUInstallDirs) -set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TurboMindBackend) - -install( - TARGETS - triton-turbomind-backend - EXPORT - triton-turbomind-backend-targets - LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/turbomind - ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/turbomind -) - -install( - EXPORT - triton-turbomind-backend-targets - FILE - TritonTurboMindBackendTargets.cmake - NAMESPACE - TritonTurboMindBackend:: - DESTINATION - ${INSTALL_CONFIGDIR} -) - -include(CMakePackageConfigHelpers) -configure_package_config_file( - ${CMAKE_SOURCE_DIR}/cmake/TritonTurboMindBackendConfig.cmake.in - ${CMAKE_CURRENT_BINARY_DIR}/TritonTurboMindBackendConfig.cmake - INSTALL_DESTINATION ${INSTALL_CONFIGDIR} -) - -install( - FILES - ${CMAKE_CURRENT_BINARY_DIR}/TritonTurboMindBackendConfig.cmake - DESTINATION ${INSTALL_CONFIGDIR} -) - -# -# Export from build tree -# -export( - EXPORT triton-turbomind-backend-targets - FILE ${CMAKE_CURRENT_BINARY_DIR}/TritonTurboMindBackendTargets.cmake - NAMESPACE TritonTurboMindBackend:: -) - -export(PACKAGE TritonTurboMindBackend) diff --git a/src/turbomind/triton_backend/libfastertransformer.cc b/src/turbomind/triton_backend/libfastertransformer.cc deleted file mode 100644 index dcdb598ec..000000000 --- a/src/turbomind/triton_backend/libfastertransformer.cc +++ /dev/null @@ -1,1914 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. -// Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// Modified from -// https://github.com/triton-inference-server/fastertransformer_backend/blob/main/src/libfastertransformer.cc - -#include - -#include -#include -#include -#include - -#pragma GCC diagnostic push -// #pragma GCC diagnostic ignored "-Wsign-compare" -#pragma GCC diagnostic ignored "-Wcast-function-type" -#pragma warning(push, 0) -#pragma warning(pop) -#pragma GCC diagnostic pop - -// must include triton libraries first -#include "triton/backend/backend_common.h" -#include "triton/backend/backend_input_collector.h" -#include "triton/backend/backend_memory.h" -#include "triton/backend/backend_model.h" -#include "triton/backend/backend_model_instance.h" -#include "triton/backend/backend_output_responder.h" -#include "triton/core/tritonbackend.h" - -// FT's libraries have dependency with triton's lib -#include "src/turbomind/macro.h" -#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" -#include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h" -#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include "src/turbomind/utils/Tensor.h" -#include "src/turbomind/utils/cuda_bf16_wrapper.h" -#include "src/turbomind/utils/instance_comm.h" -#include "src/turbomind/utils/mpi_utils.h" -#include "src/turbomind/utils/nccl_utils.h" - -std::exception_ptr ptr[8]; - -namespace ft = turbomind; - -namespace triton { -namespace backend { -namespace turbomind_backend { - -#define RESPOND_ALL_AND_RETURN_IF_ERROR(RESPONSES, RESPONSES_COUNT, X) \ - do { \ - TRITONSERVER_Error* raarie_err__ = (X); \ - if (raarie_err__ != nullptr) { \ - SendErrorForResponses(RESPONSES, RESPONSES_COUNT, raarie_err__); \ - return; \ - } \ - } while (false) - -// Cuda Error handling -TRITONSERVER_Error* -ConvertCUDAStatusToTritonError(cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg) -{ - if (cuda_error != cudaSuccess) { - return TRITONSERVER_ErrorNew(code, cudaGetErrorString(cuda_error)); - } - return nullptr; // success -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Ragged Baching - -struct RaggedBatchingParams { - bool is_input_ragged = false; - int32_t max_seq_length = 0; - int32_t max_elements_per_seq = 0; - const int32_t* batch_input_ptr = nullptr; - size_t batch_intput_size = 0; - size_t total_input_elements = 0; -}; - -using RaggedBatchingParam_Map = std::unordered_map; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// ModelState -// -// State associated with a model that is using this backend. An object -// of this class is created and associated with each -// TRITONBACKEND_Model. -// -class ModelState: public BackendModel { -public: - static TRITONSERVER_Error* Create(TRITONBACKEND_Model* triton_model, ModelState** state); - virtual ~ModelState() = default; - - TRITONSERVER_Error* LoadModel(const std::string& artifact_name, - const int32_t node_id, - const int32_t device_id, - const int32_t device_id_start, - const int32_t stream_id, - std::pair, std::vector>& nccl_params, - std::shared_ptr custom_all_reduce_comms, - std::string* model_path, - std::unique_ptr* ft_model_instance); - - int GetGpuSize() - { - return gpu_size; - }; - int GetWorldSize() - { - return world_size; - }; - int GetParallelSize() - { - return tp_pp_size; - }; - int GetInstanceId() - { - return current_model_instance_id++; - }; - int GetInstanceGroupCount() - { - return instance_group_count; - }; - bool SequenceBatchingEnabled() - { - return sequence_batching_enabled; - }; - bool DynamicBatchingEnabled() - { - return dynamic_batching_enabled; - }; - std::shared_ptr GetFtModel() - { - return ft_model; - }; - -private: - ModelState(TRITONBACKEND_Model* triton_model); - TRITONSERVER_Error* AutoCompleteConfig(); - std::string GetParameter(const char* parameter); - int current_model_instance_id = 0; - bool sequence_batching_enabled = false; - bool dynamic_batching_enabled = false; - int instance_group_count = 1; - std::shared_ptr ft_model; - int node_id, gpu_size, world_size, tp_pp_size; - std::vector streams_; - - std::shared_ptr ModelFactory(common::TritonJson::Value& param, - const std::string& model_filename); -}; - -TRITONSERVER_Error* ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) -{ - try { - *state = new ModelState(triton_model); - } - catch (const BackendModelException& ex) { - RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, - TRITONSERVER_ERROR_INTERNAL, - std::string("unexpected nullptr in BackendModelException")); - RETURN_IF_ERROR(ex.err_); - } - - // Auto-complete the configuration if requested, or T5-Encoder - bool auto_complete_config = false; - RETURN_IF_ERROR(TRITONBACKEND_ModelAutoCompleteConfig(triton_model, &auto_complete_config)); - auto_complete_config |= - (*state)->GetParameter("model_type") == "T5-Encoder" || (*state)->GetParameter("model_type") == "bert"; - if (auto_complete_config) { - RETURN_IF_ERROR((*state)->AutoCompleteConfig()); - - triton::common::TritonJson::WriteBuffer json_buffer; - (*state)->ModelConfig().Write(&json_buffer); - - TRITONSERVER_Message* message; - RETURN_IF_ERROR(TRITONSERVER_MessageNewFromSerializedJson(&message, json_buffer.Base(), json_buffer.Size())); - RETURN_IF_ERROR(TRITONBACKEND_ModelSetConfig(triton_model, 1 /* config_version */, message)); - } - - return nullptr; // success -} - -std::string param_get(common::TritonJson::Value& param, const char* field, const std::string& fallback = "") -{ - common::TritonJson::Value key; - std::string value = fallback; - param.MemberAsObject(field, &key); - key.MemberAsString("string_value", &value); - return value; -} - -int param_get_int(common::TritonJson::Value& param, const char* field, int fallback = 0) -{ - int ret = fallback; - try { - ret = std::stoi(param_get(param, field)); - } - catch (std::invalid_argument& ia) { - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, - (std::string("Invalid configuration argument '") + field + "': " + ia.what()).c_str()); - } - return ret; -} - -float param_get_float(common::TritonJson::Value& param, const char* field, float fallback = 0.0) -{ - float ret = fallback; - try { - ret = std::stof(param_get(param, field)); - } - catch (std::invalid_argument& ia) { - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, - (std::string("Invalid configuration argument '") + field + "': " + ia.what()).c_str()); - } - return ret; -} - -bool param_get_bool(common::TritonJson::Value& param, const char* field, bool fallback = false) -{ - return static_cast(param_get_int(param, field, static_cast(fallback))); -} - -std::shared_ptr ModelState::ModelFactory(common::TritonJson::Value& param, - const std::string& model_filename) -{ - std::shared_ptr ft_model; - - const std::string model_dir = param_get( - param, "model_checkpoint_path", JoinPath({RepositoryPath(), std::to_string(Version()), model_filename})); - const std::string model_type = param_get(param, "model_type", "GPT"); - const std::string data_type = param_get(param, "data_type"); - const int tp = param_get_int(param, "tensor_para_size"); - const int pp = param_get_int(param, "pipeline_para_size"); - const int custom_ar = param_get_int(param, "enable_custom_all_reduce"); - - const std::string dt_message = std::string("Invalid configuration argument 'data_type': ") + data_type; - - if (model_type == "Llama") { - if (data_type == "fp16") { - ft_model = std::make_shared>(tp, pp, custom_ar, model_dir); - } - else { -#if ENABLE_FP32 - ft_model = std::make_shared>(tp, pp, custom_ar, model_dir); -#else - THROW_IF_BACKEND_MODEL_ERROR( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, "turbomind is not built with FP32 support")); -#endif - } - } - else { - THROW_IF_BACKEND_MODEL_ERROR( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, ("Unknown model \"" + model_type + "\"").c_str())); - } - - return ft_model; -} - -ModelState::ModelState(TRITONBACKEND_Model* triton_model): BackendModel(triton_model, true) -{ - node_id = ft::mpi::getCommWorldRank(); - int num_nodes = ft::mpi::getCommWorldSize(); - - triton::common::TritonJson::WriteBuffer buffer; - ModelConfig().PrettyWrite(&buffer); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("model configuration:\n") + buffer.Contents()).c_str()); - - common::TritonJson::Value param; - model_config_.MemberAsObject("parameters", ¶m); - - // instance groups - triton::common::TritonJson::Value instance_group, instance_obj, instance_group_count_val, instance_group_kind; - if (!ModelConfig().Find("instance_group", &instance_group) || instance_group.ArraySize() > 1) { - THROW_IF_BACKEND_MODEL_ERROR( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, "Only supports one instance group !")); - } - instance_group.IndexAsObject(0, &instance_obj); - instance_obj.Find("count", &instance_group_count_val); - instance_obj.Find("kind", &instance_group_kind); - std::string instance_group_kind_str; - int64_t instance_group_count_int64 = 1; - instance_group_kind.AsString(&instance_group_kind_str); - instance_group_count_val.AsInt(&instance_group_count_int64); - instance_group_count = (int)instance_group_count_int64; - LOG_MESSAGE( - TRITONSERVER_LOG_INFO, - ("Instance group type: " + instance_group_kind_str + " count: " + std::to_string(instance_group_count_int64)) - .c_str()); - if (instance_group_kind_str != "KIND_CPU") { - THROW_IF_BACKEND_MODEL_ERROR( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, "Instance Group: only KIND_CPU supports!")); - } - - // instance group validation - bool multi_node_enabled = num_nodes > 1; - tp_pp_size = param_get_int(param, "tensor_para_size") * param_get_int(param, "pipeline_para_size"); - gpu_size = ft::getDeviceCount(); - world_size = gpu_size * num_nodes; - int model_instance_size = num_nodes > 1 ? gpu_size : tp_pp_size; - bool multi_model_instance_valid = (multi_node_enabled && tp_pp_size == world_size && instance_group_count == 1) - || (!multi_node_enabled && gpu_size % tp_pp_size == 0 - && model_instance_size * instance_group_count >= gpu_size); - - printf("num_nodes=%d\n", num_nodes); - printf("tp_pp_size=%d\n", tp_pp_size); - printf("gpu_size=%d\n", gpu_size); - printf("world_size=%d\n", world_size); - printf("model_instance_size=%d\n", model_instance_size); - if (!multi_model_instance_valid) { - THROW_IF_BACKEND_MODEL_ERROR( - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, - "1. Number of visible GPUs must be evenly divisble by TP * PP \n" - "2. Number of visible GPUs must be <= instance count * TP * PP \n" - "3. Multi-Node Inference only support one model instance \n")); - } - - int64_t max_batch_size = 0; - model_config_.MemberAsInt("max_batch_size", &max_batch_size); - - // sequence batching - triton::common::TritonJson::Value sequence_batching; - sequence_batching_enabled = ModelConfig().Find("sequence_batching", &sequence_batching); - std::string sequence_batching_log = sequence_batching_enabled ? "enabled" : "disabled"; - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Sequence Batching: ") + sequence_batching_log).c_str()); - // if (sequence_batching_enabled && max_batch_size != 1) { - // THROW_IF_BACKEND_MODEL_ERROR(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, - // "Sequence Batching for interactive text generation: only supports max - // batch size = 1 currently !")); - // } - - // dynamic batching - triton::common::TritonJson::Value dynamic_batching; - dynamic_batching_enabled = ModelConfig().Find("dynamic_batching", &dynamic_batching); - std::string dynamic_batching_log = dynamic_batching_enabled ? "enabled" : "disabled"; - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Dynamic Batching: ") + dynamic_batching_log).c_str()); - if (dynamic_batching_enabled && sequence_batching_enabled) { - THROW_IF_BACKEND_MODEL_ERROR(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED, - "Sequence Batching cannot work with dynamic " - "batching at the same time !")); - } - - std::string model_filename; - model_config_.MemberAsString("default_model_filename", &model_filename); - - if (model_filename == "") { - model_filename = std::to_string(param_get_int(param, "tensor_para_size")) + "-gpu"; - } - - ft_model = ModelFactory(param, model_filename); - - std::cout << ft_model->toString(); - - int total_weight_gpu_size = (instance_group_count * model_instance_size) >= gpu_size ? - gpu_size : - (instance_group_count * model_instance_size); - streams_.resize(instance_group_count * model_instance_size); - - /* create shared weights - assume 8 gpus, 8 model instances, Tensor Para Size 2 - then we will distribute model instances to [0, 1], [2, 3], [4, 5], [6, 7], - [0, 1], [2, 3], [4, 5], [6, 7] GPUs; - two instance instances on GPUs [0, 1] will share the same weights - */ - std::vector threads; - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Before Loading Weights:")).c_str()); - ft::print_mem_usage(); - for (int gid = 0; gid < total_weight_gpu_size; gid++) { - int rank = node_id * gpu_size + gid % tp_pp_size; - threads.push_back(std::thread(&AbstractTransformerModel::createSharedWeights, ft_model, gid, rank)); - } - for (auto& t : threads) { - t.join(); - } - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("After Loading Weights:")).c_str()); - ft::print_mem_usage(); -} - -TRITONSERVER_Error* -ModelState::LoadModel(const std::string& artifact_name, - const int32_t node_id, - const int32_t device_id, - const int32_t device_id_start, - const int32_t stream_id, - std::pair, std::vector>& nccl_params_instance, - std::shared_ptr custom_all_reduce_comms, - std::string* model_path, - std::unique_ptr* ft_model_instance) -{ - LOG_IF_ERROR(ConvertCUDAStatusToTritonError( - cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL, "Failed to set cuda device"), - "Failed to set cuda device"); - - std::string cc_model_filename = artifact_name; - if (cc_model_filename.empty()) { - cc_model_filename = "gpt3-model"; - } - - if (!node_id && !device_id) { - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Before Loading Model:")).c_str()); - } - ft::print_mem_usage(); - - LOG_IF_ERROR(ConvertCUDAStatusToTritonError(cudaStreamCreate(&streams_[stream_id]), - TRITONSERVER_ERROR_INTERNAL, - "Failed to create the stream"), - "Failed to create the stream"); - - const int rank = node_id * GetGpuSize() + device_id - device_id_start; - - auto model_instance = ft_model->createModelInstance( - device_id, rank, streams_[stream_id], nccl_params_instance, custom_all_reduce_comms); - ft_model_instance->reset(model_instance.release()); - - if (!node_id && !device_id) { - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("After Loading Model:")).c_str()); - } - ft::print_mem_usage(); - - return nullptr; // success -} - -TRITONSERVER_Error* ModelState::AutoCompleteConfig() -{ - if (GetParameter("model_type") == "T5-Encoder") { - const std::string data_type = GetParameter("data_type"); - auto& config = ModelConfig(); - common::TritonJson::Value outputs, output, dtype_object; - std::string name; - config.MemberAsArray("output", &outputs); - - std::unordered_map return_type_map{ - {"fp16", "TYPE_FP16"}, {"fp32", "TYPE_FP32"}, {"bf16", "TYPE_BF16"}}; - - std::set outputs_to_modify = {"output_hidden_state", "output_attentions"}; - for (size_t idx = 0; idx < outputs.ArraySize(); idx++) { - outputs.IndexAsObject(idx, &output); - output.MemberAsString("name", &name); - if (outputs_to_modify.find(name) == outputs_to_modify.end()) { - continue; - } - output.Find("data_type", &dtype_object); - dtype_object.SetString(return_type_map[data_type]); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - ("Automatically setting return data_type for \"" + name + "\" to \"" - + return_type_map[data_type] + "\"") - .c_str()); - } - } - else if (GetParameter("model_type") == "bert") { - const std::string data_type = GetParameter("data_type"); - auto& config = ModelConfig(); - common::TritonJson::Value inputs, input, dtype_object; - common::TritonJson::Value outputs, output; - std::string name; - config.MemberAsArray("input", &inputs); - config.MemberAsArray("output", &outputs); - - std::unordered_map return_type_map{ - {"fp16", "TYPE_FP16"}, {"fp32", "TYPE_FP32"}, {"bf16", "TYPE_BF16"}}; - - for (size_t idx = 0; idx < inputs.ArraySize(); idx++) { - inputs.IndexAsObject(idx, &input); - input.MemberAsString("name", &name); - if (name != "input_hidden_state") { - continue; - } - input.Find("data_type", &dtype_object); - dtype_object.SetString(return_type_map[data_type]); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - ("Automatically setting return data_type for " - "\"input_hidden_state\" to \"" - + return_type_map[data_type] + "\"") - .c_str()); - } - - for (size_t idx = 0; idx < outputs.ArraySize(); idx++) { - outputs.IndexAsObject(idx, &output); - output.MemberAsString("name", &name); - if (name != "output_hidden_state") { - continue; - } - output.Find("data_type", &dtype_object); - dtype_object.SetString(return_type_map[data_type]); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - ("Automatically setting return data_type for " - "\"output_hidden_state\" to \"" - + return_type_map[data_type] + "\"") - .c_str()); - } - } - else { - // Auto-complete configuration is not supported since turbomind does - // not store/capture sufficient model metadata so just log error instead. - LOG_MESSAGE(TRITONSERVER_LOG_WARN, - (std::string("skipping model configuration auto-complete for '") + Name() - + "': not supported for turbomind backend") - .c_str()); - } - - return nullptr; // success -} - -std::string ModelState::GetParameter(const char* parameter) -{ - auto& config = ModelConfig(); - common::TritonJson::Value parameters, model_type_obj; - std::string model_type; - config.MemberAsObject("parameters", ¶meters); - parameters.MemberAsObject(parameter, &model_type_obj); - model_type_obj.MemberAsString("string_value", &model_type); - return model_type; -} - -struct stream_callback_ctx_t { - size_t total_batch_size; - TRITONBACKEND_Request** requests; - uint32_t request_count; - std::vector* responses; - std::vector* factories; - BackendModelInstance* model; -}; - -void generate_response_placeholders(std::vector* responses, - std::vector* factories) -{ - TRITONSERVER_Error* err = nullptr; - for (auto factory : *factories) { - TRITONBACKEND_Response* response; - err = TRITONBACKEND_ResponseNewFromFactory(&response, factory); - if (err) { - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response from factory"); - TRITONSERVER_ErrorDelete(err); - } - responses->push_back(response); - } -} - -// -// ModelInstanceState -// -// State associated with a model instance. An object of this class is -// created and associated with each TRITONBACKEND_ModelInstance. -// -class ModelInstanceState: public BackendModelInstance { -public: - static TRITONSERVER_Error* - Create(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, ModelInstanceState** state); - virtual ~ModelInstanceState(); - - // Get the state of the model that corresponds to this instance. - ModelState* StateForModel() const - { - return model_state_; - } - - // Execute... - void ProcessRequests(TRITONBACKEND_Request** requests, const uint32_t request_count); - - std::shared_ptr> - Execute(std::vector* responses, - stream_callback_ctx_t* context, - const uint32_t response_count, - std::shared_ptr> input_tensors); - - void ReadOutputTensors(size_t total_batch_size, - std::shared_ptr> output_tensors, - TRITONBACKEND_Request** requests, - const uint32_t request_count, - std::vector* responses); - - int GetModelInstanceCount() - { - return model_instance_count_; - }; - int GetModelInstanceId() - { - return model_instance_id_; - }; - -private: - ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance); - TRITONSERVER_Error* ValidateInputs(); - TRITONSERVER_Error* ValidateOutputs(); - - void SetInputTensors(size_t total_batch_size, - TRITONBACKEND_Request** requests, - const uint32_t request_count, - std::vector* responses, - BackendInputCollector* collector, - std::vector* input_names, - std::shared_ptr>* input_tensors, - std::vector* input_memories, - bool* cuda_copy); - - void BroadcastInputTensors(std::shared_ptr>* input_tensors); - - ModelState* model_state_; - - // model instance id - int model_instance_count_ = 1; - int model_instance_id_ = 0; - int model_instance_gpu_size_ = 1; - int model_instance_device_id_start_ = 0; - - // output tensor stream - cudaStream_t output_stream_; - - // tensor parallel + pipeline parallel - int gpu_size_ = 1; - int world_size_ = 1; - int tp_pp_size_ = 1; - - // Should we use the streaming API? - bool is_decoupled_ = false; - - // The full path to the FT model file. - std::string model_path_; - - std::vector> ft_model_instance_; - - std::unique_ptr instance_comm_; - - // inter-node broadcast buffer - std::vector bcast_buffers; - - // Map from configuration name for an input to the index of - // that input in the model. - std::unordered_map input_index_map_; - - // Map from configuration name for an output to the index of - // that output in the model. - std::unordered_map output_dtype_map_; - - std::pair, std::vector> nccl_params_; - - // custom all reduce comms - std::vector> custom_all_reduce_comms_; -}; - -TRITONSERVER_Error* ModelInstanceState::Create(ModelState* model_state, - TRITONBACKEND_ModelInstance* triton_model_instance, - ModelInstanceState** state) -{ - try { - *state = new ModelInstanceState(model_state, triton_model_instance); - } - catch (const BackendModelInstanceException& ex) { - RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, - TRITONSERVER_ERROR_INTERNAL, - std::string("unexpected nullptr in BackendModelInstanceException")); - RETURN_IF_ERROR(ex.err_); - } - - return nullptr; // success -} - -int ThreadLoadModel(ModelState* model_state, - const std::string& artifact_name, - const int32_t node_id, - const int32_t device_id, - const int32_t device_id_start, - const int32_t stream_id, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comms, - std::string* model_path, - std::unique_ptr* ft_model_instance) -{ - THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(artifact_name, - node_id, - device_id, - device_id_start, - stream_id, - nccl_params, - custom_all_reduce_comms, - model_path, - ft_model_instance)); - return 0; -} - -ModelInstanceState::ModelInstanceState(ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance): - BackendModelInstance(model_state, triton_model_instance), model_state_(model_state) -{ - int node_id = ft::mpi::getCommWorldRank(); - int num_nodes = ft::mpi::getCommWorldSize(); - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("Model name ") + ArtifactFilename()).c_str()); - - triton::common::TritonJson::Value transaction_policy; - is_decoupled_ = false; - model_state_->ModelConfig().MemberAsObject("model_transaction_policy", &transaction_policy); - transaction_policy.MemberAsBool("decoupled", &is_decoupled_); - - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string("Use ") + (is_decoupled_ ? "DECOUPLED (streaming)" : "COUPLED (classic)") + " API.").c_str()); - - THROW_IF_BACKEND_INSTANCE_ERROR(ValidateInputs()); - THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs()); - - // NOTE: model instance params - model_instance_id_ = model_state->GetInstanceId(); - model_instance_count_ = model_state->GetInstanceGroupCount(); - tp_pp_size_ = model_state->GetParallelSize(); - gpu_size_ = model_state->GetGpuSize(); - world_size_ = model_state->GetWorldSize(); - - model_instance_gpu_size_ = num_nodes > 1 ? gpu_size_ : tp_pp_size_; - ft_model_instance_.resize(model_instance_gpu_size_); - std::vector threads; - - std::shared_ptr shared_ft_model = model_state->GetFtModel(); - - // NOTE: CPU_KIND only, the backend fully controls how to distribute models to - // GPUs - model_instance_device_id_start_ = (model_instance_id_ * model_instance_gpu_size_) % gpu_size_; - // create output tensor stream - LOG_IF_ERROR(ConvertCUDAStatusToTritonError(cudaSetDevice(model_instance_device_id_start_), - TRITONSERVER_ERROR_INTERNAL, - "Failed to set cuda device"), - "Failed to set cuda device"); - LOG_IF_ERROR(ConvertCUDAStatusToTritonError( - cudaStreamCreate(&output_stream_), TRITONSERVER_ERROR_INTERNAL, "Failed to create the stream"), - "Failed to create the stream"); - - // create nccl params - nccl_params_ = shared_ft_model->createNcclParams(node_id, model_instance_device_id_start_, num_nodes > 1); - - shared_ft_model->createCustomComms(&custom_all_reduce_comms_, world_size_); - std::string model_instance_gpu_ids = "[ "; - for (int gid = model_instance_device_id_start_; gid < model_instance_device_id_start_ + model_instance_gpu_size_; - gid++) { - model_instance_gpu_ids += (std::to_string(gid) + " "); - threads.push_back(std::thread(ThreadLoadModel, - model_state, - ArtifactFilename(), - node_id, - gid, - model_instance_device_id_start_, - model_instance_id_ * model_instance_gpu_size_ + gid, - nccl_params_, - custom_all_reduce_comms_[gid - model_instance_device_id_start_], - &model_path_, - &ft_model_instance_[gid - model_instance_device_id_start_])); - } - model_instance_gpu_ids += "]"; - - for (auto& t : threads) { - t.join(); - } - - instance_comm_ = shared_ft_model->createInstanceComm(tp_pp_size_); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("Model instance is created on GPU ") + model_instance_gpu_ids).c_str()); -} - -ModelInstanceState::~ModelInstanceState() -{ -#ifdef TRITON_ENABLE_GPU -#endif // TRITON_ENABLE_GPU - for (auto bcast_buffer : bcast_buffers) { - free(bcast_buffer); - } -} - -TRITONSERVER_Error* ModelInstanceState::ValidateInputs() -{ - triton::common::TritonJson::Value ios, bios; - // input - std::string name, data_type; - triton::common::TritonJson::Value jshape; - // batch input - std::string kind, target_name, source_input; - triton::common::TritonJson::Value target_name_array, source_input_array; - model_state_->ModelConfig().MemberAsArray("input", &ios); - model_state_->ModelConfig().MemberAsArray("batch_input", &bios); - - std::vector valid_batch_input; - - // batch input - for (size_t size = 0; size < bios.ArraySize(); size++) { - triton::common::TritonJson::Value batch_input; - bios.IndexAsObject(size, &batch_input); - batch_input.MemberAsString("kind", &kind); - batch_input.MemberAsArray("target_name", &target_name_array); - batch_input.MemberAsString("data_type", &data_type); - batch_input.MemberAsArray("source_input", &source_input_array); - target_name_array.IndexAsString(0, &target_name); - source_input_array.IndexAsString(0, &source_input); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("Get batch input kind: " + kind + ", target_name: " + target_name - + ", data_type: " + data_type + ", source_input: " + source_input) - .c_str())); - - if (kind == "BATCH_ITEM_SHAPE" && data_type == "TYPE_INT32" && source_input + "_item_shape" == target_name) { - valid_batch_input.emplace_back(std::move(source_input)); - } - } - - // input - for (size_t size = 0; size < ios.ArraySize(); size++) { - triton::common::TritonJson::Value input; - ios.IndexAsObject(size, &input); - input.MemberAsString("name", &name); - input.MemberAsString("data_type", &data_type); - input.MemberAsArray("dims", &jshape); - - triton::common::TritonJson::Value allow_ragged_batch_json; - bool allow_ragged_batch = false; - if (input.Find("allow_ragged_batch", &allow_ragged_batch_json)) { - RETURN_IF_ERROR(allow_ragged_batch_json.AsBool(&allow_ragged_batch)); - } - - if (allow_ragged_batch - && std::find(valid_batch_input.begin(), valid_batch_input.end(), name) == valid_batch_input.end()) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string("Ragged Batch [ " + name + " ] needs the corresponding batch_input item shape !").c_str()); - } - - std::vector shape; - for (size_t size = 0; size < jshape.ArraySize(); size++) { - int64_t value = 0; - jshape.IndexAsInt(size, &value); - shape.push_back(value); - } - - std::string str_shape = "["; - for (uint i = 0; i < shape.size(); i++) { - str_shape = str_shape + std::to_string(shape[i]); - if (i != shape.size() - 1) { - str_shape = str_shape + ", "; - } - else { - str_shape = str_shape + "]"; - } - } - - std::string allow_ragged_batch_str = allow_ragged_batch ? "true" : "false"; - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("Get input name: " + name + ", type: " + data_type + ", shape: " + str_shape - + ", allow_ragged_batch: " + allow_ragged_batch_str) - .c_str())); - } - return nullptr; // success -} - -TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() -{ - triton::common::TritonJson::Value ios; - RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios)); - - std::string name, data_type; - triton::common::TritonJson::Value jshape; - model_state_->ModelConfig().MemberAsArray("output", &ios); - for (size_t size = 0; size < ios.ArraySize(); size++) { - triton::common::TritonJson::Value input; - ios.IndexAsObject(size, &input); - input.MemberAsString("name", &name); - input.MemberAsString("data_type", &data_type); - input.MemberAsArray("dims", &jshape); - - std::vector shape; - for (size_t size = 0; size < jshape.ArraySize(); size++) { - int64_t value = 0; - jshape.IndexAsInt(size, &value); - shape.push_back(value); - } - - std::string str_shape = "["; - for (uint i = 0; i < shape.size(); i++) { - str_shape = str_shape + std::to_string(shape[i]); - if (i != shape.size() - 1) { - str_shape = str_shape + ", "; - } - else { - str_shape = str_shape + "]"; - } - } - - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string("Get output name: " + name + ", type: " + data_type + ", shape: " + str_shape).c_str())); - } - - return nullptr; // success -} - -void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests, const uint32_t request_count) -{ - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() + " with " + std::to_string(request_count) - + " requests") - .c_str()); - uint64_t exec_start_ns = 0; - SET_TIMESTAMP(exec_start_ns); - - const int max_batch_size = model_state_->MaxBatchSize(); - - // For each request collect the total batch size for this inference - // execution. The batch-size, number of inputs, and size of each - // input has already been checked so don't need to do that here. - size_t total_batch_size = 0; - - // bool sequence_batching_enabled = model_state_->SequenceBatchingEnabled(); - // size_t real_batch_dim = (int) sequence_batching_enabled; - constexpr size_t real_batch_dim = 0; - - // only one batch slot per model instance when sequence_batching enabled - for (size_t i = 0; i < request_count; i++) { - // If we get a nullptr request then something is badly wrong. Fail - // and release all requests. - if (requests[i] == nullptr) { - RequestsRespondWithError( - requests, - request_count, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string("null request given to TurboMind backend for '" + Name() + "'").c_str())); - return; - } - - if (max_batch_size > 0) { - // Retrieve the batch size from one of the inputs, if the model - // supports batching, the first dimension size is batch size - int index = 0; - while (true) { - TRITONBACKEND_Input* input; - TRITONSERVER_Error* err_0 = TRITONBACKEND_RequestInputByIndex(requests[i], index, &input); - if (err_0 == nullptr) { - const char* input_name; - const int64_t* shape; - TRITONSERVER_Error* err_1 = - TRITONBACKEND_InputProperties(input, &input_name, nullptr, &shape, nullptr, nullptr, nullptr); - std::string input_name_str = std::string(input_name); - if (err_1 == nullptr) { - if (input_name_str != "START" && input_name_str != "END" && input_name_str != "READY") { - total_batch_size += shape[real_batch_dim]; - break; - } - index++; - } - else { - RequestsRespondWithError(requests, request_count, err_1); - return; - } - } - else { - RequestsRespondWithError(requests, request_count, err_0); - return; - } - } - } - else { - total_batch_size += 1; - } - } - - // If there are no valid payloads then no need to run the inference. - if (total_batch_size == 0) { - return; - } - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("get total batch_size = ") + std::to_string(total_batch_size)).c_str()); - - // Make sure the maximum batch size is not exceeded. The - // total_batch_size must be 1 for models that don't support batching - // (i.e. max_batch_size == 0). If max_batch_size is exceeded then - // scheduler has done something badly wrong so fail and release all - // requests. - if ((total_batch_size != 1) && (total_batch_size > (size_t)max_batch_size)) { - RequestsRespondWithError( - requests, - request_count, - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, - std::string("batch size " + std::to_string(total_batch_size) + " for '" + Name() - + "', max allowed is " + std::to_string(max_batch_size)) - .c_str())); - return; - } - - // At this point we are committed to running inference with all - // 'requests'. Create a response for each request. During input - // processing if there is an error with any request that error will - // be sent immediately with the corresponding response (and the - // response unique_ptr will then be nullptr). The request object - // itself will not be released until after all inferencing is done - // (below) as we may need to access the request object when - // determine how to process outputs (for example, even if we don't - // need the outputs for a request that has an error, we do need to - // know the size of those outputs associated with the request so we - // can skip them in the output tensors). - // - // When operating in the decoupled mode, responses should be created - // from factories. Here, we instantiate a factory for each request and - // generate the first response. At each new result from the model the - // generated response is filled, sent, and another response is created - // from the factory. The last response is send just like in the - // non-decoupled mode. - std::vector responses; - responses.reserve(request_count); - std::vector factories; - - for (size_t i = 0; i < request_count; i++) { - if (is_decoupled_) { - TRITONBACKEND_ResponseFactory* factory; - auto err = TRITONBACKEND_ResponseFactoryNew(&factory, requests[i]); - if (err == nullptr) { - factories.emplace_back(factory); - } - else { - factories.emplace_back(nullptr); - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response factory"); - TRITONSERVER_ErrorDelete(err); - } - } - else { - TRITONBACKEND_Response* response; - auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); - if (err == nullptr) { - responses.emplace_back(response); - } - else { - responses.emplace_back(nullptr); - LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); - TRITONSERVER_ErrorDelete(err); - } - } - } - - std::vector input_names; - std::shared_ptr> input_tensors = - std::make_shared>(); - std::vector input_memories; - bool cuda_copy = false; - if (is_decoupled_) { - generate_response_placeholders(&responses, &factories); - } - BackendInputCollector collector(requests, - request_count, - &responses, - model_state_->TritonMemoryManager(), - model_state_->EnablePinnedInput(), - CudaStream()); - SetInputTensors(total_batch_size, - requests, - request_count, - &responses, - &collector, - &input_names, - &input_tensors, - &input_memories, - &cuda_copy); - - // Wait for any in-flight input tensor copies to complete. -#ifdef TRITON_ENABLE_GPU - if (cuda_copy) { - cudaStreamSynchronize(CudaStream()); - } -#endif - - uint64_t compute_start_ns = 0; - SET_TIMESTAMP(compute_start_ns); - - stream_callback_ctx_t context = {total_batch_size, requests, request_count, &responses, &factories, this}; - - auto output_tensors = Execute(&responses, &context, request_count, input_tensors); - - uint64_t compute_end_ns = 0; - SET_TIMESTAMP(compute_end_ns); - - // Free BackendMemory used for inputs - for (BackendMemory* mem : input_memories) { - delete mem; - } - input_memories.clear(); - - ReadOutputTensors(total_batch_size, output_tensors, requests, request_count, &responses); - - uint64_t exec_end_ns = 0; - SET_TIMESTAMP(exec_end_ns); - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("get response size = ") + std::to_string(responses.size())).c_str()); - - // Send all the responses that haven't already been sent because of - // an earlier error. Note that the responses are not set to nullptr - // here as we need that indication below to determine if the request - // we successful or not. - for (auto& response : responses) { - if (response != nullptr) { - LOG_IF_ERROR(TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), - "failed to send TurboMind backend response"); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("response is sent")).c_str()); - } - else { - LOG_MESSAGE(TRITONSERVER_LOG_WARN, (std::string("response is nullptr")).c_str()); - } - } - - // Report statistics for each request. - for (uint32_t r = 0; r < request_count; ++r) { - auto& request = requests[r]; - LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(TritonModelInstance(), - request, - (responses[r] != nullptr) /* success */, - exec_start_ns, - compute_start_ns, - compute_end_ns, - exec_end_ns), - "failed reporting request statistics"); - - LOG_IF_ERROR(TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), - "failed releasing request"); - } - - // Report the entire batch statistics. - LOG_IF_ERROR( - TRITONBACKEND_ModelInstanceReportBatchStatistics( - TritonModelInstance(), total_batch_size, exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), - "failed reporting batch request statistics"); -} - -void streaming_callback(std::shared_ptr> output_tensors, void* ctx) -{ - stream_callback_ctx_t* context = reinterpret_cast(ctx); - ModelInstanceState* model = reinterpret_cast(context->model); - - std::vector* responses = context->responses; - - model->ReadOutputTensors( - context->total_batch_size, output_tensors, context->requests, context->request_count, responses); - - for (auto& response : *responses) { - if (response != nullptr) { - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("start to send streaming response")).c_str()); - LOG_IF_ERROR(TRITONBACKEND_ResponseSend(response, 0, nullptr), "failed to send TurboMind backend response"); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("streaming response is sent")).c_str()); - } - else { - LOG_MESSAGE(TRITONSERVER_LOG_WARN, (std::string("streaming response is nullptr")).c_str()); - } - } - responses->clear(); - generate_response_placeholders(responses, context->factories); -} - -int ThreadForward(std::unique_ptr* ft_model_instance, - std::shared_ptr>* input_tensors, - std::shared_ptr>* output_tensors, - ft::AbstractInstanceComm* instance_comm, - std::exception_ptr* exception_ptr, - const int device_id, - const int use_stream_cb, - stream_callback_ctx_t* context) -{ - LOG_IF_ERROR(ConvertCUDAStatusToTritonError( - cudaSetDevice(device_id), TRITONSERVER_ERROR_INTERNAL, "Failed to set cuda device"), - "Failed to set cuda device"); - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("Start to forward")).c_str()); - if (use_stream_cb) { - (*ft_model_instance)->registerCallback(streaming_callback, (void*)context); - } - *output_tensors = (*ft_model_instance)->forward(*input_tensors, instance_comm); - if (use_stream_cb) { - (*ft_model_instance)->unRegisterCallback(); - } - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("Stop to forward")).c_str()); - - if ((*output_tensors)->count("error_message")) { - *exception_ptr = *((std::exception_ptr*)((*output_tensors)->at("error_message").data)); - } - return 0; -} - -void triton_check_inputs(std::shared_ptr> output_tensors, const char* filename) -{ - auto& output = output_tensors->at("output_ids"); - auto shape = output.shape; - assert(shape.size() == 3); - assert(output.type == TYPE_UINT32); - auto batch_size = shape[0]; - auto length = shape[2]; - std::string fName = filename; - auto file = std::ofstream(fName, std::ios::out); - if (!file.is_open()) {} - else { - for (size_t i = 0; i < batch_size; i++) { - for (size_t j = 0; j < length; j++) { - file << ((uint32_t*)output.data)[i * length + j] << " "; - } - file << std::endl; - } - } -} - -void ModelInstanceState::BroadcastInputTensors(std::shared_ptr>* input_tensors) -{ - int node_id = ft::mpi::getCommWorldRank(); - - uint32_t input_count = node_id ? 0 : (*input_tensors)->size(); - ft::mpi::bcast(&input_count, 1, ft::mpi::MPI_TYPE_UINT32_T, 0, ft::mpi::COMM_WORLD); - if (input_count > bcast_buffers.size()) { - bcast_buffers.resize(input_count); - } - - if (node_id) { - for (uint input_index = 0; input_index < input_count; input_index++) { - std::vector batchn_shape; - int64_t shape_size = 0; - int64_t buffer_size = 1; - ft::mpi::bcast(&shape_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - for (int s_id = 0; s_id < shape_size; s_id++) { - int64_t val; - ft::mpi::bcast(&val, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - batchn_shape.push_back(val); - buffer_size *= val; - } - int64_t data_type_size = 1; - ft::mpi::bcast(&data_type_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - buffer_size *= data_type_size; - bcast_buffers[input_index] = (char*)realloc(bcast_buffers[input_index], buffer_size); - char* input_buffer = bcast_buffers[input_index]; - ft::mpi::bcast(input_buffer, buffer_size, ft::mpi::MPI_TYPE_BYTE, 0, ft::mpi::COMM_WORLD); - - int64_t name_size = 0; - ft::mpi::bcast(&name_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - char char_name[1024] = {0}; - ft::mpi::bcast(char_name, name_size, ft::mpi::MPI_TYPE_CHAR, 0, ft::mpi::COMM_WORLD); - uint32_t data_type_num = 0; - ft::mpi::bcast(&data_type_num, 1, ft::mpi::MPI_TYPE_UINT32_T, 0, ft::mpi::COMM_WORLD); - TRITONSERVER_DataType triton_data_type = TRITONSERVER_DataType(data_type_num); - - (*input_tensors) - ->insert({std::string(char_name), - Tensor{TRITONSERVER_MEMORY_CPU, triton_data_type, batchn_shape, input_buffer}}); - } - } - else { - int input_index = 0; - for (auto it = (*input_tensors)->begin(); it != (*input_tensors)->end(); ++it) { - std::vector batchn_shape = it->second.shape; - int64_t shape_size = batchn_shape.size(); - int64_t buffer_size = 1; - ft::mpi::bcast(&shape_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - for (int s_id = 0; s_id < shape_size; s_id++) { - int64_t val = batchn_shape[s_id]; - ft::mpi::bcast(&val, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - buffer_size *= val; - } - - ft::Tensor tmp{ - ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, nullptr}; // TODO change the getDataTypeByteNum function to static - int64_t data_type_size = tmp.getTypeSize(triton::Tensor::convertTritonTypeToFt(it->second.type)); - ft::mpi::bcast(&data_type_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - buffer_size *= data_type_size; - - ft::mpi::bcast( - const_cast(it->second.data), buffer_size, ft::mpi::MPI_TYPE_BYTE, 0, ft::mpi::COMM_WORLD); - - std::string name = it->first; - int64_t name_size = name.size(); - ft::mpi::bcast(&name_size, 1, ft::mpi::MPI_TYPE_INT64_T, 0, ft::mpi::COMM_WORLD); - bcast_buffers[input_index] = (char*)realloc(bcast_buffers[input_index], name_size); - char* char_name = bcast_buffers[input_index]; - int64_t length = (int64_t)name.copy(char_name, name_size); - ft::FT_CHECK(length == name_size); - ft::mpi::bcast(char_name, name_size, ft::mpi::MPI_TYPE_CHAR, 0, ft::mpi::COMM_WORLD); - - uint32_t data_type_num = (uint32_t)(it->second.type); - ft::mpi::bcast(&data_type_num, 1, ft::mpi::MPI_TYPE_UINT32_T, 0, ft::mpi::COMM_WORLD); - input_index++; - } - } -} - -std::shared_ptr> -ModelInstanceState::Execute(std::vector* responses, - stream_callback_ctx_t* context, - const uint32_t response_count, - std::shared_ptr> input_tensors) -{ - int node_id = ft::mpi::getCommWorldRank(); - - if (node_id == 0) { - // Debug: input array - // triton_check_inputs(input_tensors, "triton_in"); - } - if (node_id) { - input_tensors = std::make_shared>(); - } - - ft::mpi::barrier(); - - BroadcastInputTensors(&input_tensors); - std::vector threads; - std::shared_ptr> output_tensors_list[model_instance_gpu_size_]; - std::exception_ptr exception_ptr[model_instance_gpu_size_]; - for (int gid = model_instance_device_id_start_; gid < model_instance_device_id_start_ + model_instance_gpu_size_; - gid++) { - int instance_local_id = gid - model_instance_device_id_start_; - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("before ThreadForward " + std::to_string(gid))).c_str()); - threads.push_back(std::thread(ThreadForward, - &ft_model_instance_[instance_local_id], - &input_tensors, - &output_tensors_list[instance_local_id], - instance_comm_.get(), - &exception_ptr[instance_local_id], - gid, - is_decoupled_ && gid == model_instance_device_id_start_, - context)); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("after ThreadForward " + std::to_string(gid))).c_str()); - } - - for (auto& t : threads) { - t.join(); - } - - try { - for (int gid = model_instance_device_id_start_; - gid < model_instance_device_id_start_ + model_instance_gpu_size_; - gid++) { - int instance_local_id = gid - model_instance_device_id_start_; - if (exception_ptr[instance_local_id]) { - std::rethrow_exception(exception_ptr[instance_local_id]); - } - } - } - catch (std::exception& ex) { - SendErrorForResponses(responses, - response_count, - TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, - ("TurboMind execute failure: " + std::string(ex.what())).c_str())); - } - auto output_tensors = output_tensors_list[0]; - return output_tensors; -} - -void ModelInstanceState::SetInputTensors( - size_t total_batch_size, - TRITONBACKEND_Request** requests, - const uint32_t request_count, - std::vector* responses, - BackendInputCollector* collector, - std::vector* input_names, - std::shared_ptr>* input_tensors, - std::vector* input_memories, - bool* cuda_copy) -{ - const int max_batch_size = model_state_->MaxBatchSize(); - // bool sequence_batching_enabled = model_state_->SequenceBatchingEnabled(); - bool dynamic_batching_enabled = model_state_->DynamicBatchingEnabled() || model_state_->SequenceBatchingEnabled(); - - // All requests must have equally-sized input tensors so use any - // request as the representative for the input tensors. - uint32_t input_count; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, TRITONBACKEND_RequestInputCount(requests[0], &input_count)); - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("get input count = ") + std::to_string(input_count)).c_str()); - - // Process batch input if any - RaggedBatchingParam_Map batch_input_param_map; - - if (dynamic_batching_enabled) { - // Handle batch inputs for ragged batching - for (const auto& batch_input : model_state_->BatchInputs()) { - std::vector shape; - collector->BatchInputShape(batch_input, &shape); - - auto batch_input_kind = batch_input.BatchInputKind(); - auto batch_input_name = batch_input.TargetNames()[0]; - - // we only take care of the ragged input_ids - // Assume the first dimension (length) are different and others are the - // same BATCH_ITEM_SHAPE [num_requests (batches), num_dims (excluding - // batch dimension)] - if (batch_input_kind == BatchInput::Kind::BATCH_ITEM_SHAPE - && (batch_input_name == "input_ids_item_shape" - || batch_input_name == "request_prompt_embedding_item_shape")) { - RaggedBatchingParams param{}; - - size_t num_feature_dimensions = (size_t)shape[1]; - const char* dst_buffer = nullptr; - size_t dst_buffer_byte_size; - TRITONSERVER_MemoryType dst_memory_type; - int64_t dst_memory_type_id; - - // Batch inputs are always created on CPU - RESPOND_ALL_AND_SET_NULL_IF_ERROR((*responses), - responses->size(), - collector->ProcessBatchInput(batch_input, - nullptr, - 0, - {{TRITONSERVER_MEMORY_CPU, 0}}, - &dst_buffer, - &dst_buffer_byte_size, - &dst_memory_type, - &dst_memory_type_id)); - - param.batch_input_ptr = reinterpret_cast(dst_buffer); - - // concat all feature dimensions - param.batch_intput_size = (dst_buffer_byte_size / sizeof(int32_t)) / num_feature_dimensions; - if (num_feature_dimensions > 1) { - BackendMemory* batch_item_shape_memory; - RESPOND_ALL_AND_RETURN_IF_ERROR(responses, - request_count, - BackendMemory::Create(model_state_->TritonMemoryManager(), - {BackendMemory::AllocationType::CPU}, - 0, - dst_buffer_byte_size / num_feature_dimensions, - &batch_item_shape_memory)); - int32_t* batch_item_shape_memory_ptr = - reinterpret_cast(batch_item_shape_memory->MemoryPtr()); - for (size_t idx = 0; idx < param.batch_intput_size; idx++) { - int32_t concat_dimensions = 1; - for (size_t dim_idx = 0; dim_idx < num_feature_dimensions; dim_idx++) { - concat_dimensions *= param.batch_input_ptr[idx * num_feature_dimensions + dim_idx]; - // dim0 is seq length dimension - if (dim_idx == 0) { - param.max_seq_length = - std::max(param.max_seq_length, param.batch_input_ptr[idx * num_feature_dimensions]); - } - } - batch_item_shape_memory_ptr[idx] = concat_dimensions; - } - param.batch_input_ptr = reinterpret_cast(batch_item_shape_memory_ptr); - } - else { - param.max_seq_length = - *std::max_element(param.batch_input_ptr, param.batch_input_ptr + param.batch_intput_size); - } - - // check if padding is needed - param.is_input_ragged = std::any_of(param.batch_input_ptr, - param.batch_input_ptr + param.batch_intput_size, - [&](int x) { return x != param.batch_input_ptr[0]; }); - - // calculate statistics of elements - if (param.is_input_ragged) { - param.max_elements_per_seq = - *std::max_element(param.batch_input_ptr, param.batch_input_ptr + param.batch_intput_size); - param.total_input_elements = - std::accumulate(param.batch_input_ptr, param.batch_input_ptr + param.batch_intput_size, 0); - batch_input_param_map.insert({batch_input_name, param}); - // verbose logging for debugging - if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { - std::string value_str = "[ "; - for (size_t i = 0; i < param.batch_intput_size; i++) { - value_str += std::to_string(param.batch_input_ptr[i]) + " "; - } - value_str += "]"; - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("collect batch input name: ") + batch_input_name + "\n size: " - + std::to_string(dst_buffer_byte_size) + " bytes\n value: " + value_str - + "\n max sequence length: " + std::to_string(param.max_seq_length) - + "\n max elements per sequence: " + std::to_string(param.max_elements_per_seq)) - .c_str()); - } - } - } - } - } - - // Process user-defined inputs - for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) { - TRITONBACKEND_Input* input; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input)); - - const char* input_name; - TRITONSERVER_DataType input_datatype; - const int64_t* input_shape; - uint32_t input_dims_count; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, - request_count, - TRITONBACKEND_InputProperties( - input, &input_name, &input_datatype, &input_shape, &input_dims_count, nullptr, nullptr)); - - input_names->emplace_back(input_name); - - std::string input_name_str = std::string(input_name); - - // Pad input ids from different requests - RaggedBatchingParams param = batch_input_param_map[input_name_str + "_item_shape"]; - if (batch_input_param_map.find(input_name_str + "_item_shape") != batch_input_param_map.end() - && batch_input_param_map[input_name_str + "_item_shape"].is_input_ragged) { - RaggedBatchingParams param = batch_input_param_map[input_name_str + "_item_shape"]; - - const int64_t total_batch_size_int64 = (int64_t)total_batch_size; - const int64_t max_elements_per_seq_int64 = (int64_t)param.max_elements_per_seq; - const size_t padded_input_ids_buffer_size = - GetByteSize(input_datatype, std::vector{total_batch_size_int64, max_elements_per_seq_int64}); - // Always host memory - BackendMemory* padded_input_memory; - BackendMemory* request_input_memory; - RESPOND_ALL_AND_RETURN_IF_ERROR(responses, - request_count, - BackendMemory::Create(model_state_->TritonMemoryManager(), - {BackendMemory::AllocationType::CPU}, - 0, - padded_input_ids_buffer_size, - &padded_input_memory)); - RESPOND_ALL_AND_RETURN_IF_ERROR(responses, - request_count, - BackendMemory::Create(model_state_->TritonMemoryManager(), - {BackendMemory::AllocationType::CPU}, - 0, - padded_input_ids_buffer_size, - &request_input_memory)); - - memset(padded_input_memory->MemoryPtr(), 0, padded_input_ids_buffer_size); - - collector->ProcessTensor( - input_name, - request_input_memory->MemoryPtr(), - GetByteSize(input_datatype, std::vector{(int64_t)param.total_input_elements}), - request_input_memory->MemoryType(), - request_input_memory->MemoryTypeId()); - - int64_t accumulated_elements_offset = 0; - - char* padded_input_ids_ptr = padded_input_memory->MemoryPtr(); - char* base_input_ids = request_input_memory->MemoryPtr(); - - // copy each request buffer to padded buffer - for (int64_t single_batch_idx = 0; single_batch_idx < total_batch_size_int64; single_batch_idx++) { - int32_t sequence_elements = param.batch_input_ptr[single_batch_idx]; - std::memcpy(padded_input_ids_ptr - + GetByteSize(input_datatype, - std::vector{single_batch_idx, max_elements_per_seq_int64}), - base_input_ids - + GetByteSize(input_datatype, std::vector{accumulated_elements_offset}), - GetByteSize(input_datatype, std::vector{sequence_elements})); - - accumulated_elements_offset += sequence_elements; - } - - // modify batch dimension shape, and sequence length dimension shape after - // padding - std::vector batchn_shape(input_shape, input_shape + input_dims_count); - if (max_batch_size != 0) { - batchn_shape[0] = total_batch_size; - batchn_shape[1] = (size_t)param.max_seq_length; - // assume all non-seq-length dimensions have the same shape - if (input_dims_count > 2) { - batchn_shape[2] = (size_t)(param.max_elements_per_seq / param.max_seq_length); - } - } - (*input_tensors) - ->insert({std::string(input_name), - triton::Tensor{TRITONSERVER_MEMORY_CPU, input_datatype, batchn_shape, padded_input_ids_ptr}}); - - continue; - } - - // bool start_end_ready_flag = (input_name_str == "START" || input_name_str - // == "END" - // || input_name_str == "READY"); - - // int shape_dims_start = (int) (sequence_batching_enabled && - // !start_end_ready_flag); - - // The shape for the entire input patch, [total_batch_size, ...] - std::vector batchn_shape(input_shape, input_shape + input_dims_count); - if (max_batch_size != 0) { - batchn_shape[0] = total_batch_size; - } - - std::vector batchn_shape_2(input_shape, input_shape + input_dims_count); - if (max_batch_size != 0) { - batchn_shape_2[0] = total_batch_size; - } - - // std::vector batchn_shape( - // input_shape + shape_dims_start, input_shape + input_dims_count); - // if (max_batch_size != 0 && !start_end_ready_flag) { - // batchn_shape[0] = total_batch_size; - // } - - // std::vector batchn_shape_2( - // input_shape + shape_dims_start, input_shape + input_dims_count); - // if (max_batch_size != 0 && !start_end_ready_flag) { - // batchn_shape_2[0] = total_batch_size; - // } - - // The input must be in contiguous CPU/GPU memory. - const int64_t batchn_byte_size = GetByteSize(input_datatype, batchn_shape); - - // Always host memory - BackendMemory* input_memory; - RESPOND_ALL_AND_RETURN_IF_ERROR(responses, - request_count, - BackendMemory::Create(model_state_->TritonMemoryManager(), - {BackendMemory::AllocationType::CPU}, - 0, - batchn_byte_size, - &input_memory)); - input_memories->push_back(input_memory); - - TRITONSERVER_MemoryType memory_type = input_memory->MemoryType(); - int64_t memory_type_id = input_memory->MemoryTypeId(); - char* input_buffer = input_memory->MemoryPtr(); - - collector->ProcessTensor(input_name, input_buffer, batchn_byte_size, memory_type, memory_type_id); - - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string("collect name: ") + input_name + " size: " + std::to_string(batchn_byte_size) + " bytes") - .c_str()); - (*input_tensors) - ->insert({std::string(input_name), - triton::Tensor{TRITONSERVER_MEMORY_CPU, input_datatype, batchn_shape_2, input_buffer}}); - } - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("the data is in ") + (*cuda_copy ? std::string("GPU") : std::string("CPU"))).c_str()); - // Finalize... - *cuda_copy |= collector->Finalize(); - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("the data is in ") + (*cuda_copy ? std::string("GPU") : std::string("CPU"))).c_str()); -} - -void ModelInstanceState::ReadOutputTensors(size_t total_batch_size, - std::shared_ptr> output_tensors, - TRITONBACKEND_Request** requests, - const uint32_t request_count, - std::vector* responses) -{ - BackendOutputResponder responder(requests, - request_count, - responses, - model_state_->MaxBatchSize(), - model_state_->TritonMemoryManager(), - model_state_->EnablePinnedInput(), - output_stream_); - - bool cuda_copy = false; - // bool sequence_batching_enabled = model_state_->SequenceBatchingEnabled(); - std::vector> string_buffers; - - int idx = 0; - for (auto it = output_tensors->begin(); it != output_tensors->end(); ++it) { - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string("Get output_tensors ") + std::to_string(idx) + std::string(": ") + std::string(it->first)) - .c_str()); - idx++; - auto& output = it->second; - - // Verify output datatype matches datatype from model config - TRITONSERVER_DataType output_dtype = output.type; - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string(" output_type: ") + TRITONSERVER_DataTypeString(output_dtype)).c_str()); - - const char* output_buffer = static_cast(output.data); - - // Set output shape - // std::vector batchn_shape = sequence_batching_enabled ? - // std::vector{1} : - // std::vector{}; - std::vector batchn_shape; - if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { - // std::string batch_shape_str = sequence_batching_enabled ? " output - // shape: [1, " : - // " output shape: ["; - std::string batch_shape_str = " output shape: ["; - for (uint i = 0; i < output.shape.size(); i++) { - batchn_shape.push_back(output.shape[i]); - batch_shape_str = batch_shape_str + std::to_string(output.shape[i]); - if (i != output.shape.size() - 1) { - batch_shape_str = batch_shape_str + ", "; - } - else { - batch_shape_str = batch_shape_str + "]"; - } - } - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, batch_shape_str.c_str()); - } - else { - batchn_shape.insert(batchn_shape.end(), output.shape.begin(), output.shape.end()); - } - - responder.ProcessTensor(it->first, - output_dtype, - batchn_shape, - output_buffer, - TRITONSERVER_MEMORY_CPU, - model_instance_device_id_start_); - } - - // Finalize and wait for any pending buffer copies. - cuda_copy |= responder.Finalize(); - -#ifdef TRITON_ENABLE_GPU - if (cuda_copy) { - cudaStreamSynchronize(output_stream_); - } -#endif // TRITON_ENABLE_GPU - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("PERFORMED GPU copy: ") + (cuda_copy ? std::string("YES") : std::string("NO"))).c_str()); -} - -///////////// - -extern "C" { - -TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) -{ - int provided; - ft::mpi::initThread(nullptr, nullptr, ft::mpi::THREAD_MULTIPLE, &provided); - const char* cname; - RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); - std::string name(cname); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); - - // Check the backend API version that Triton supports vs. what this - // backend was compiled against. - uint32_t api_version_major, api_version_minor; - RETURN_IF_ERROR(TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("Triton TRITONBACKEND API version: ") + std::to_string(api_version_major) + "." - + std::to_string(api_version_minor)) - .c_str()); - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("'") + name - + "' TRITONBACKEND API version: " + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." - + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) - .c_str()); - - if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) - || (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - (std::string("Triton TRITONBACKEND API version: ") + std::to_string(api_version_major) + "." - + std::to_string(api_version_minor) + " does not support '" + name + "' TRITONBACKEND API version: " - + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) - .c_str()); - } - return nullptr; // success -} - -TRITONSERVER_Error* TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) -{ - const char* cname; - RETURN_IF_ERROR(TRITONBACKEND_ModelName(model, &cname)); - std::string name(cname); - - uint64_t version; - RETURN_IF_ERROR(TRITONBACKEND_ModelVersion(model, &version)); - - LOG_MESSAGE( - TRITONSERVER_LOG_INFO, - (std::string("TRITONBACKEND_ModelInitialize: ") + name + " (version " + std::to_string(version) + ")").c_str()); - - // Create a ModelState object and associate it with the - // TRITONBACKEND_Model. - ModelState* model_state; - RETURN_IF_ERROR(ModelState::Create(model, &model_state)); - RETURN_IF_ERROR(TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); - - return nullptr; // success -} - -TRITONSERVER_Error* TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) -{ - void* vstate; - RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); - ModelState* model_state = reinterpret_cast(vstate); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelFinalize: delete model state"); - - delete model_state; - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelFinalize: MPI Finalize"); - - ft::mpi::finalize(); - - return nullptr; // success -} - -TRITONSERVER_Error* TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) -{ - int node_id = ft::mpi::getCommWorldRank(); - - const char* cname; - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceName(instance, &cname)); - std::string name(cname); - - // Get the model state associated with this instance's model. - TRITONBACKEND_Model* model; - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); - - void* vmodelstate; - RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); - ModelState* model_state = reinterpret_cast(vmodelstate); - - // Create a ModelInstanceState object and associate it with the - // TRITONBACKEND_ModelInstance. - ModelInstanceState* instance_state; - RETURN_IF_ERROR(ModelInstanceState::Create(model_state, instance, &instance_state)); - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(instance, reinterpret_cast(instance_state))); - - int model_instance_id = instance_state->GetModelInstanceId(); - int model_instance_count = instance_state->GetModelInstanceCount(); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, - (std::string("TRITONBACKEND_ModelInstanceInitialize: ") + name + " (count " - + std::to_string(model_instance_count) + ")" + " (instance_id " + std::to_string(model_instance_id) - + ")") - .c_str()); - - if (node_id) { - while (true) { - instance_state->Execute( - nullptr, nullptr, 0, std::shared_ptr>(nullptr)); - } - } - - return nullptr; // success -} - -TRITONSERVER_Error* TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) -{ - void* vstate; - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); - ModelInstanceState* instance_state = reinterpret_cast(vstate); - - LOG_MESSAGE(TRITONSERVER_LOG_INFO, "TRITONBACKEND_ModelInstanceFinalize: delete instance state"); - - delete instance_state; - - return nullptr; // success -} - -TRITONSERVER_Error* TRITONBACKEND_ModelInstanceExecute(TRITONBACKEND_ModelInstance* instance, - TRITONBACKEND_Request** requests, - const uint32_t request_count) -{ - // Triton will not call this function simultaneously for the same - // 'instance'. But since this backend could be used by multiple - // instances from multiple models the implementation needs to handle - // multiple calls to this function at the same time (with different - // 'instance' objects). Suggested practice for this is to use only - // function-local and model-instance-specific state (obtained from - // 'instance'), which is what we do here. - ModelInstanceState* instance_state; - RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, reinterpret_cast(&instance_state))); - ModelState* model_state = instance_state->StateForModel(); - - // This backend specifies BLOCKING execution policy. That means that - // we should not return from this function until execution is - // complete. Triton will automatically release 'instance' on return - // from this function so that it is again available to be used for - // another call to TRITONBACKEND_ModelInstanceExecute. - - LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, - (std::string("model ") + model_state->Name() + ", instance " + instance_state->Name() + ", executing " - + std::to_string(request_count) + " requests") - .c_str()); - - // At this point we accept ownership of 'requests', which means that - // even if something goes wrong we must still return success from - // this function. If something does go wrong in processing a - // particular request then we send an error response just for the - // specific request. - instance_state->ProcessRequests(requests, request_count); - - return nullptr; // success -} - -} // extern "C" - -} // namespace turbomind_backend -} // namespace backend -} // namespace triton diff --git a/src/turbomind/triton_backend/libtriton_fastertransformer.ldscript b/src/turbomind/triton_backend/libtriton_fastertransformer.ldscript deleted file mode 100644 index 26d2fbb33..000000000 --- a/src/turbomind/triton_backend/libtriton_fastertransformer.ldscript +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -{ - global: - TRITONBACKEND_*; - local: *; -}; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 2b27addb8..c5b0bf884 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -208,11 +208,16 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, // rotary embedding parameters attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f); + attn_params_.rope_scaling_type = reader.Get("llama", "rope_scaling_type", ""); attn_params_.rope_scaling_factor = reader.GetFloat("llama", "rope_scaling_factor", 0.f); + attn_params_.low_freq_factor = reader.GetFloat("llama", "low_freq_factor", 1.0); + attn_params_.high_freq_factor = reader.GetFloat("llama", "high_freq_factor", 1.0); attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0); attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0); attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0); + attn_params_.original_max_position_embeddings = reader.GetInteger("llama", "original_max_position_embeddings", 0); + engine_params_.max_batch_size = reader.GetInteger("llama", "max_batch_size", 0); engine_params_.max_context_token_num = reader.GetInteger("llama", "max_context_token_num", 0); engine_params_.session_len = reader.GetInteger("llama", "session_len", 0); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc index f9381a03f..e3ce79826 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc @@ -21,7 +21,6 @@ #include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h" #include "src/turbomind/macro.h" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include "src/turbomind/triton_backend/triton_utils.hpp" #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/constant.h" #include "src/turbomind/utils/cuda_utils.h" diff --git a/src/turbomind/triton_backend/triton_utils.hpp b/src/turbomind/triton_backend/triton_utils.hpp deleted file mode 100644 index a87dd7d6f..000000000 --- a/src/turbomind/triton_backend/triton_utils.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include "src/turbomind/utils/Tensor.h" - -namespace ft = turbomind; - -template -void move_tensor_H2D(const triton::Tensor& tensor, - T*& d_ptr, - const std::unique_ptr>* allocator) -{ - if (tensor.where == triton::MEMORY_GPU) { - return; - } - - size_t tensor_size = 1; - for (auto t : tensor.shape) { - tensor_size *= t; - } - - cudaStream_t stream = (*allocator)->returnStream(); - - d_ptr = (T*)((*allocator)->reMalloc(d_ptr, sizeof(T) * tensor_size, false)); - ft::check_cuda_error(cudaMemcpyAsync(d_ptr, (T*)tensor.data, sizeof(T) * tensor_size, cudaMemcpyDefault, stream)); -} - -template -ft::Tensor as_GPU_tensor(const triton::Tensor& tensor, T* d_ptr) -{ - return ft::Tensor{ft::MEMORY_GPU, - triton::Tensor::convertTritonTypeToFt(tensor.type), - tensor.shape, - tensor.where == triton::MEMORY_CPU ? d_ptr : tensor.data}; -} - -inline ft::Tensor as_CPU_tensor(const triton::Tensor& tensor) -{ - ft::FT_CHECK(tensor.where == triton::MEMORY_CPU); - return ft::Tensor{ft::MEMORY_CPU, triton::Tensor::convertTritonTypeToFt(tensor.type), tensor.shape, tensor.data}; -} diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index 2a5d01cd0..d995e2a9b 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -167,6 +167,7 @@ class Allocator: public IAllocator { check_cuda_error(cudaGetDeviceCount(&device_count)); cudaMemPool_t mempool; check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); +#if TM_ENABLE_CUSTOM_ALL_REDUCE cudaMemAccessDesc desc = {}; int peer_access_available = 0; for (int i = 0; i < device_count; i++) { @@ -184,6 +185,7 @@ class Allocator: public IAllocator { desc.flags = cudaMemAccessFlagsProtReadWrite; check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); } +#endif // set memory pool threshold to avoid shrinking the pool uint64_t setVal = UINT64_MAX; check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py index eef633423..3cc1c1efb 100644 --- a/tests/test_lmdeploy/test_model.py +++ b/tests/test_lmdeploy/test_model.py @@ -20,6 +20,8 @@ ('deepseek-ai/deepseek-coder-6.7b-instruct', ['deepseek-coder']), ('deepseek-ai/deepseek-vl-7b-chat', ['deepseek-vl']), ('deepseek-ai/deepseek-moe-16b-chat', ['deepseek']), + ('internlm/internlm-xcomposer2-4khd-7b', ['internlm-xcomposer2']), + ('internlm/internlm-xcomposer2d5-7b', ['internlm-xcomposer2d5']), ('tiiuae/falcon-7b', ['falcon']), ('workspace', ['base'])]) @pytest.mark.parametrize('suffix', ['', '-w4', '-4bit', '-16bit']) def test_best_match_model(model_path_and_name, suffix): @@ -82,31 +84,27 @@ def test_vicuna(): def test_internlm_chat(): prompt = 'hello, can u introduce yourself' - model = MODELS.get('internlm-chat-7b')(capability='completion') + model = MODELS.get('internlm')(capability='completion') assert model.get_prompt(prompt, sequence_start=True) == prompt assert model.get_prompt(prompt, sequence_start=False) == prompt assert model.stop_words is not None assert model.system == '<|System|>:' - assert model.session_len == 2048 - model = MODELS.get('internlm-chat-7b')(capability='chat', - system='Provide answers in Python') + model = MODELS.get('internlm')(capability='chat', + system='Provide answers in Python') assert model.get_prompt(prompt, sequence_start=True) != prompt assert model.get_prompt(prompt, sequence_start=False) != prompt assert model.system == 'Provide answers in Python' - model = MODELS.get('internlm-chat-7b')(capability='voice') + model = MODELS.get('internlm')(capability='voice') _prompt = None with pytest.raises(AssertionError): _prompt = model.get_prompt(prompt, sequence_start=True) assert _prompt is None - model = MODELS.get('internlm-chat-7b-8k')() - assert model.session_len == 8192 - def test_messages2prompt4internlm2_chat(): - model = MODELS.get('internlm2-chat-7b')() + model = MODELS.get('internlm2')() # Test with a single message messages = [ { @@ -165,16 +163,38 @@ def test_messages2prompt4internlm2_chat(): assert actual_prompt == expected_prompt +def test_llama3_1(): + model = MODELS.get('llama3_1')() + messages = [ + dict(role='user', + content='Can you check the top 5 trending songs on spotify?') + ] + tools = [{ + 'name': 'spotify_trending_songs', + 'description': 'Get top trending songs on Spotify', + 'parameters': { + 'n': { + 'param_type': 'int', + 'description': 'Number of trending songs to get', + 'required': True + } + }, + }] + actual_prompt = model.messages2prompt(messages, tools=tools) + expected_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\n# Tool Instructions\n- Always execute python code in messages that you share.\n- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n\nYou have access to the following functions:\n\nUse the function \'spotify_trending_songs\' to: Get top trending songs on Spotify\n{"name": "spotify_trending_songs", "description": "Get top trending songs on Spotify", "parameters": {"n": {"param_type": "int", "description": "Number of trending songs to get", "required": true}}}\n\n\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- Function calls MUST follow the specified format\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line"\n- Always add your sources when using search results to answer the user query\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCan you check the top 5 trending songs on spotify?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa + assert actual_prompt == expected_prompt + + def test_baichuan(): prompt = 'hello, can u introduce yourself' - model = MODELS.get('baichuan-7b')(capability='completion') + model = MODELS.get('baichuan2')(capability='completion') assert model.get_prompt(prompt, sequence_start=True) == prompt assert model.get_prompt(prompt, sequence_start=False) == prompt assert model.stop_words is None - model = MODELS.get('baichuan-7b')(capability='chat') + model = MODELS.get('baichuan2')(capability='chat') _prompt = model.get_prompt(prompt, sequence_start=True) - assert _prompt == prompt + assert _prompt == '' + prompt + '' def test_llama2(): @@ -209,16 +229,16 @@ def test_llama3(): def test_qwen(): prompt = 'hello, can u introduce yourself' - model = MODELS.get('qwen-7b')(capability='completion') + model = MODELS.get('qwen')(capability='completion') assert model.get_prompt(prompt, sequence_start=True) == prompt assert model.get_prompt(prompt, sequence_start=False) == prompt assert model.stop_words is not None - model = MODELS.get('qwen-7b')(capability='chat') + model = MODELS.get('qwen')(capability='chat') assert model.get_prompt(prompt, sequence_start=True) != prompt assert model.get_prompt(prompt, sequence_start=False) != prompt - model = MODELS.get('qwen-7b')(capability='voice') + model = MODELS.get('qwen')(capability='voice') _prompt = None with pytest.raises(AssertionError): _prompt = model.get_prompt(prompt, sequence_start=True) @@ -327,7 +347,13 @@ def test_chatglm3(): def test_glm4(): - model = MODELS.get('glm4')() + model_path_and_name = 'THUDM/glm-4-9b-chat' + deduced_name = best_match_model(model_path_and_name) + assert deduced_name == 'glm4' + + model = MODELS.get(deduced_name)() + # check stop words + assert model.stop_words == ['<|user|>', '<|endoftext|>', '<|observation|>'] messages = [{ 'role': 'system', 'content': 'you are a helpful assistant' @@ -342,7 +368,7 @@ def test_glm4(): 'content': 'AGI is?' }] from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained('THUDM/glm-4-9b-chat', + tokenizer = AutoTokenizer.from_pretrained(model_path_and_name, trust_remote_code=True) ref = tokenizer.apply_chat_template(messages, tokenize=False) res = model.messages2prompt(messages) @@ -403,3 +429,29 @@ def test_internvl2(): 'assistant\nI am an AI<|im_end|>\n<|im_start|>assistant\n' res = model.messages2prompt(messages) assert res == expected + + +def test_codegeex4(): + model_path_and_name = 'THUDM/codegeex4-all-9b' + deduced_name = best_match_model(model_path_and_name) + assert deduced_name == 'codegeex4' + model = MODELS.get(deduced_name)() + messages = [{ + 'role': 'system', + 'content': 'you are a helpful assistant' + }, { + 'role': 'user', + 'content': 'who are you' + }, { + 'role': 'assistant', + 'content': 'I am an AI' + }, { + 'role': 'user', + 'content': 'AGI is?' + }] + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path_and_name, + trust_remote_code=True) + ref = tokenizer.apply_chat_template(messages, tokenize=False) + res = model.messages2prompt(messages) + assert res.startswith(ref) From a9ec3fac0a1c46c2a1e65fb91dd7eab6433dab33 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 30 Jul 2024 14:28:27 +0800 Subject: [PATCH 17/89] support llava for llama --- .../pytorch/backends/cuda/graph_runner.py | 3 + lmdeploy/pytorch/models/llama.py | 10 +- lmdeploy/pytorch/models/llava.py | 96 ++++++++++--------- lmdeploy/pytorch/models/module_map.py | 2 +- lmdeploy/pytorch/models/patch.py | 11 +++ 5 files changed, 73 insertions(+), 49 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 8da9a83cb..9fa057b88 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -213,6 +213,9 @@ def check_enable_graph(self): if self.backend_config.eager_mode: return False + if getattr(self.model, 'support_cuda_graph', False): + return False + # TODO: should we enable cudagraph awq now? # hf_config = self.model_config.hf_config # quantization_config = getattr(hf_config, diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 0a8ea55d5..b113c8b1b 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -289,6 +289,10 @@ def forward( return hidden_states + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + class LlamaForCausalLM(nn.Module): @@ -321,6 +325,10 @@ def forward( logits = logits.float() return logits + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], @@ -337,7 +345,7 @@ def prepare_inputs_for_generation( if vision_embeddings is not None and len(vision_embeddings) > 0: if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index 8d5384592..ff6658cac 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -1,67 +1,69 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, List import torch from torch import nn -from transformers.modeling_outputs import CausalLMOutputWithPast +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -class PatchedLlavaLlamaForCausalLM(nn.Module): +from .patch import get_rewrite_cls - def forward(self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - return_dict: Optional[bool] = None, - **kwargs) -> Union[Tuple, CausalLMOutputWithPast]: - context = self.context.context - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) +class PatchedLlavaForConditionalGeneration(nn.Module): - if vision_embeddings is not None and len(vision_embeddings) > 0: - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) - outputs = self.model.forward(input_ids=None, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - position_ids=position_ids, - return_dict=return_dict, - attention_mask=attention_mask) - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - return CausalLMOutputWithPast(logits=logits) + support_cuda_graph = True + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + self.ctx_mgr = ctx_mgr + language_model_cls = get_rewrite_cls(origin.language_model) + self.language_model = language_model_cls(origin.language_model, + ctx_mgr) -class PatchedLlavaForConditionalGeneration(nn.Module): + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) + + def get_input_embeddings(self): + """get input embeddings.""" + return self.language_model.get_input_embeddings() - def forward(self, - input_ids: torch.LongTensor = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - **kwargs) -> Union[Tuple, CausalLMOutputWithPast]: - """llava hf.""" - context = self.context.context + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata # get inputs from context vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) - return self.language_model.forward(input_ids=None, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - position_ids=position_ids) + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 91ea7071e..fb3cde714 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -321,7 +321,7 @@ # llava MODULE_MAP.update({ 'llava.model.language_model.llava_llama.LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaLlamaForCausalLM', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', 'llava.model.language_model.llava_llama.LlavaLlamaModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', 'llava.model.language_model.llava_mistral.LlavaMistralForCausalLM': diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 79c22c6eb..08f3198a9 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -111,6 +111,17 @@ def get_origin_mod(self): return model +def get_rewrite_cls(model: torch.nn.Module, module_map: Dict[str, str] = None): + """get rewrite cls.""" + if module_map is None: + module_map = _get_module_map() + rewrite_qualname = _find_rewrite_module_qualname(model, + module_map=module_map) + if rewrite_qualname is None: + return None + return _class_from_qualname(rewrite_qualname) + + def _patch(model: torch.nn.Module, module_map: Dict[str, str]) -> torch.nn.Module: """patch the model with rewrite module. From 5158b96fc16c5dcf19a43ce6c8f81f323d586139 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 31 Jul 2024 11:05:50 +0800 Subject: [PATCH 18/89] fix support cudagraph flag --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 9fa057b88..04e6e1690 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -213,7 +213,7 @@ def check_enable_graph(self): if self.backend_config.eager_mode: return False - if getattr(self.model, 'support_cuda_graph', False): + if not getattr(self.model, 'support_cuda_graph', False): return False # TODO: should we enable cudagraph awq now? From 580cdd0a995d6fd31a1d522cc530534539e9b5d9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 31 Jul 2024 18:48:06 +0800 Subject: [PATCH 19/89] support lora cudagraph --- lmdeploy/pytorch/adapter/adapter.py | 64 ++++++++- .../pytorch/backends/cuda/graph_runner.py | 17 ++- lmdeploy/pytorch/backends/cuda/slora.py | 123 +++++++++--------- lmdeploy/pytorch/backends/slora.py | 14 +- lmdeploy/pytorch/engine/engine.py | 10 +- lmdeploy/pytorch/engine/model_agent.py | 50 ++++--- lmdeploy/pytorch/layers/linear.py | 23 +++- lmdeploy/pytorch/model_inputs.py | 49 +------ 8 files changed, 202 insertions(+), 148 deletions(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index 97749be9a..815412d65 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -77,6 +77,7 @@ def get_indexed_lora_linears(model: torch.nn.Module): def update_lora_linears(lora_linears: Dict, weight_maps: List['AdapterWeightMap'], + gpu_caches: List[List], device: str = 'cuda'): """update lora linears.""" @@ -87,14 +88,43 @@ def __update_linear(linear, idx, target_name, adapter_names): linear.adapter_info.lora_A = None linear.adapter_info.lora_B = None + num_adapters = len(weight_maps) + target_modules = weight_maps[0].target_modules + num_offs = len(weight_maps[0].rank_offset) + max_rank = weight_maps[0].max_rank + target_map = dict((name, idx) for idx, name in enumerate(target_modules)) + all_ranks = [weight_maps[idx].rank for idx in range(num_adapters)] + all_ranks = torch.tensor(all_ranks, device=device).t().contiguous() + all_scaling = [weight_maps[idx].scaling for idx in range(num_adapters)] + all_scaling = torch.tensor(all_scaling, device=device).t().contiguous() + all_boffs = torch.zeros(num_adapters, + num_offs, + dtype=torch.int64, + device=device) + all_boffs = all_boffs.unflatten(-1, (-1, max_rank)) + adapter_names = [weight_map.adapter_name for weight_map in weight_maps] for idx, lora_linear in lora_linears.items(): + a_cache, b_cache = gpu_caches[idx] for target, linear in lora_linear.items(): + target_id = target_map[target] + ranks = all_ranks[target_id] + scalings = all_scaling[target_id] + boffs = all_boffs[:, target_id] + linear.post_init( + ranks, + scalings, + boffs, + a_cache, + b_cache, + max_rank=max_rank, + ) __update_linear(linear, idx, target_name=target, adapter_names=adapter_names) + return all_boffs.flatten(1, 2) @dataclass @@ -102,7 +132,7 @@ class LoRALinearInfo: """lora linear info.""" ranks: Dict[str, int] scalings: Dict[str, int] - target_names: List[str] + adapter_names: List[str] in_features: int out_features: int rank_stride: int = field(default=0, init=False) @@ -122,7 +152,7 @@ def from_loralinear(cls, linear: torch.nn.Module): return cls( ranks=ranks, scalings=scalings, - target_names=list(ranks.keys()), + adapter_names=list(ranks.keys()), in_features=in_features, out_features=out_features, ) @@ -182,8 +212,10 @@ def get_loralinear_info(model: torch.nn.Module): @dataclass class AdapterWeightMap: + adapter_id: int adapter_name: str rank: List[int] + scaling: List[int] rank_offset: np.ndarray max_rank: int target_modules: List[str] @@ -234,6 +266,7 @@ def cache_adapter(self, lora_linears: Dict, caches: List[List[Tensor]]): class SchedulerAdapter: """lora adapter.""" + adapter_id: int adapter_name: str rank: List[int] scaling: List[int] @@ -248,8 +281,9 @@ class SchedulerAdapter: _active: bool = field(default=False, init=False) @classmethod - def new(cls, adapter_name: str, linear_infos: Dict[str, LoRALinearInfo], - block_numel: int, max_rank: int): + def new(cls, adapter_id: int, adapter_name: str, + linear_infos: Dict[str, LoRALinearInfo], block_numel: int, + max_rank: int): """new.""" target_modules = list(linear_infos.keys()) @@ -281,6 +315,7 @@ def new(cls, adapter_name: str, linear_infos: Dict[str, LoRALinearInfo], block_idx_per_rank = np.concatenate(block_idx_per_rank) ret = cls( + adapter_id=adapter_id, adapter_name=adapter_name, rank=rank, scaling=scaling, @@ -322,14 +357,24 @@ def build_weight_map(self): """build weight map.""" assert self.rank_offset is not None return AdapterWeightMap( + adapter_id=self.adapter_id, adapter_name=self.name, rank=self.rank, + scaling=self.scaling, rank_offset=self.rank_offset, max_rank=self.max_rank, target_modules=self.target_modules, ) +def _get_adapter_names(linear_infos: Dict[str, LoRALinearInfo]): + """get adapter names.""" + adapter_names = set() + for info in linear_infos.values(): + adapter_names.update(info.adapter_names) + return list(adapter_names) + + class AdapterManager: """adapter manager.""" @@ -339,6 +384,11 @@ def __init__(self, linear_infos: Dict[str, LoRALinearInfo], self.block_numel = block_numel self._adapters: Dict[str, SchedulerAdapter] = dict() + adapter_names = _get_adapter_names(linear_infos) + self.adapter_id_map = dict( + (name, idx + 1) for idx, name in enumerate(adapter_names)) + self.adapter_id_map[None] = 0 + self.max_rank = self._get_max_rank() self._add_non_adapter() @@ -353,7 +403,9 @@ def _get_max_rank(self): def _add_non_adapter(self): """add non adapter.""" - self.add_adapter(None) + adapter = self.add_adapter(None) + rank_offset = adapter.inblock_offset.copy() + adapter.update_rank_offset(rank_offset) def _register_adapter(self, adapter: SchedulerAdapter): """register adapter.""" @@ -371,7 +423,9 @@ def num_adapters(self): def add_adapter(self, adapter_name: str): """add adapter.""" + adapter_id = self.adapter_id_map[adapter_name] adapter = SchedulerAdapter.new( + adapter_id, adapter_name, self.linear_infos, self.block_numel, diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 04e6e1690..4f4f47b91 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -89,7 +89,21 @@ def make_buffers(self): dtype=torch.int64, device=device) self.input_buffers['local_adapter_ids'] = torch.zeros( - max_batches, dtype=torch.int64, device=device) - 1 + max_batches, dtype=torch.int64, device=device) + + def _fill_context(self): + """fill context.""" + context = self.ctx_mgr.current_context() + local_adapter_ids = context.local_adapter_ids + if local_adapter_ids is not None: + batch_size = local_adapter_ids.size(0) + self.input_buffers['local_adapter_ids'].fill_(0) + self.input_buffers[ + 'local_adapter_ids'][:batch_size] = local_adapter_ids + context.local_adapter_ids = self.input_buffers['local_adapter_ids'] + context.q_seqlens = self.input_buffers['q_seqlens'] + context.kv_seqlens = self.input_buffers['kv_seqlens'] + context.q_start_loc = self.input_buffers['q_start_loc'] def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, attn_metadata: Any, @@ -154,6 +168,7 @@ def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, 'inputs_embeds'] new_inputs.update(kwargs) + self._fill_context() return new_inputs def capture(self, **kwargs): diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py index d313d6cbd..c57bca29f 100644 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -16,16 +16,10 @@ class PackedLoRAInput: """packed lora input.""" x: torch.Tensor - a_cache: torch.Tensor - b_cache: torch.Tensor q_start_loc: torch.Tensor q_seqlens: torch.Tensor adapter_ids: torch.Tensor - scaling: torch.Tensor - rank_offset: torch.Tensor - ranks: torch.Tensor max_seq_len: int - max_rank: int is_decoding: bool @@ -36,37 +30,40 @@ def __init__(self, adapter_info: AdapterInfo, ctx_mgr: StepContextManager, colwise: bool = True): + super().__init__() self.base_slice = adapter_info.base_slice self.ctx_mgr = ctx_mgr self.colwise = colwise - def _make_packed_lora_input(self, x, target_name: str, layer_idx: int): + def post_init( + self, + ranks: torch.Tensor, + scalings: torch.Tensor, + rank_offsets: torch.Tensor, + a_cache: torch.Tensor, + b_cache: torch.Tensor, + max_rank: int, + ): + """post init.""" + self.ranks = ranks + self.scalings = scalings + self.rank_offsets = rank_offsets + self.a_cache = a_cache + self.b_cache = b_cache + self.max_rank = max_rank + + def _make_packed_lora_input(self, x): """make PackedLoRAInput.""" context = self.ctx_mgr.current_context() - adapter_param = context.adapter_params[target_name] # adapter cache - ranks = adapter_param.ranks - scaling = adapter_param.scalings - rank_offset = adapter_param.rank_offsets - max_rank = adapter_param.max_rank - k_cache, v_cache = context.kv_caches[layer_idx] - cache_len = k_cache.size(0) - a_cache = k_cache.view(cache_len, -1) - b_cache = v_cache.view(cache_len, -1) max_q_seq_length = x.numel() // x.size(-1) return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), - a_cache=a_cache, - b_cache=b_cache, q_start_loc=context.q_start_loc, q_seqlens=context.q_seqlens, adapter_ids=context.local_adapter_ids, - scaling=scaling, - rank_offset=rank_offset, - ranks=ranks, max_seq_len=max_q_seq_length, - max_rank=max_rank, is_decoding=context.is_decoding) def _forward_rowwise(self, @@ -82,39 +79,39 @@ def _forward_rowwise(self, out_size //= world_size if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, - lora_input.a_cache, + self.a_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, + rank_offset=self.rank_offsets, + ranks=self.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank) + max_rank=self.max_rank) lora_out = mbgmm_b(xa, - lora_input.b_cache, + self.b_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, + scaling=self.scalings, + rank_offset=self.rank_offsets, + ranks=self.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, + max_rank=self.max_rank, out_size=out_size) else: xa = mbgmv_a(lora_input.x, - lora_input.a_cache, + self.a_cache, adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank) + rank_offset=self.rank_offsets, + ranks=self.ranks, + max_rank=self.max_rank) lora_out = mbgmv_b(xa, - lora_input.b_cache, + self.b_cache, adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, + scaling=self.scalings, + rank_offset=self.rank_offsets, + ranks=self.ranks, + max_rank=self.max_rank, out_size=out_size) if is_tp: @@ -151,63 +148,63 @@ def __gather_xa(xa): if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, - lora_input.a_cache, + self.a_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, + rank_offset=self.rank_offsets, + ranks=self.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, + max_rank=self.max_rank, rank_step=world_size) gathered_xa = __gather_xa(xa) - if len(lora_input.ranks) > 1: + if len(self.ranks) > 1: gathered_xa = rearange_all_gather( gathered_xa, b_start_loc=lora_input.q_start_loc, b_seq_lens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - ranks=lora_input.ranks, + ranks=self.ranks, world_size=world_size, max_seq_len=lora_input.max_seq_len, output=gathered_xa) lora_out = mbgmm_b(gathered_xa, - lora_input.b_cache, + self.b_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, + scaling=self.scalings, + rank_offset=self.rank_offsets, + ranks=self.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, + max_rank=self.max_rank, out_size=out_size) else: xa = mbgmv_a(lora_input.x, - lora_input.a_cache, + self.a_cache, adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, + rank_offset=self.rank_offsets, + ranks=self.ranks, + max_rank=self.max_rank, rank_step=world_size) gathered_xa = __gather_xa(xa) - if len(lora_input.ranks) > 1: + if len(self.ranks) > 1: gathered_xa = rearange_all_gather( gathered_xa, b_start_loc=lora_input.q_start_loc, b_seq_lens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - ranks=lora_input.ranks, + ranks=self.ranks, world_size=world_size, max_seq_len=lora_input.max_seq_len, output=gathered_xa) lora_out = mbgmv_b(gathered_xa, - lora_input.b_cache, + self.b_cache, adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, + scaling=self.scalings, + rank_offset=self.rank_offsets, + ranks=self.ranks, + max_rank=self.max_rank, out_size=out_size) lora_out = lora_out.reshape(sliced_base.shape) @@ -222,7 +219,7 @@ def forward(self, layer_idx: int, is_tp: bool = True): """forward.""" - lora_input = self._make_packed_lora_input(x, target_name, layer_idx) + lora_input = self._make_packed_lora_input(x) if self.colwise and is_tp: return self._forward_colwise(lora_input, base_output) else: diff --git a/lmdeploy/pytorch/backends/slora.py b/lmdeploy/pytorch/backends/slora.py index fc7681fd5..18f8c9106 100644 --- a/lmdeploy/pytorch/backends/slora.py +++ b/lmdeploy/pytorch/backends/slora.py @@ -34,9 +34,21 @@ def from_lora_linear(mod: nn.Module, base_slice: slice = None): ) -class SLoRAImpl(ABC): +class SLoRAImpl(ABC, nn.Module): """slora implementation api.""" + def post_init( + self, + ranks: torch.Tensor, + scalings: torch.Tensor, + rank_offsets: torch.Tensor, + a_cache: torch.Tensor, + b_cache: torch.Tensor, + max_rank: int, + ): + """post init.""" + raise NotImplementedError + @abstractmethod def forward(self, x: torch.Tensor, diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 68b31b73c..78c81f471 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -63,6 +63,9 @@ def _paging_adapters(adapters: dict, model_agent: AutoModelAgent, scheduler: Scheduler): adapters = adapters or dict() weight_maps = [] + adapter_manager = scheduler.adapter_manager + non_adapter = adapter_manager.get_adapter(None) + weight_maps.append(non_adapter.build_weight_map()) for name in adapters: weight_map = scheduler.add_adapter(name) weight_map.rank_offset = torch.tensor(weight_map.rank_offset) @@ -80,8 +83,7 @@ def _tensorlize_block_offsets(block_offsets): def _get_adapter_ids(seqs: SeqList, adapters: AdapterList): """get adapter ids.""" - adapter_names_map = dict( - (ada.name, idx) for idx, ada in enumerate(adapters)) + adapter_names_map = dict((ada.name, ada.adapter_id) for ada in adapters) adapter_ids = [adapter_names_map[seq.adapter_name] for seq in seqs] return adapter_ids @@ -143,10 +145,6 @@ def __init__(self, eager_mode=engine_config.eager_mode, device_type=engine_config.device_type, ) - if (engine_config.adapters is not None - and not backend_config.eager_mode): - logger.warning('LoRA adapter require eager_mode=True') - backend_config.eager_mode = True with get_device_manager().context(self.device_context): self.model_agent = AutoModelAgent.from_pretrained( diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index cc9d29e93..41a0f47c2 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -148,6 +148,8 @@ def model_forward( world_size=world_size, kv_caches=cache_engine.gpu_cache, ) + if inputs.adapter_info is not None: + inputs.adapter_info.update_offsets(model.rank_offsets) with ctx_mgr.context(context): input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, @@ -211,29 +213,6 @@ def _remove_unused_modules(hf_model: torch.nn.Module, model_cfg: ModelConfig): return hf_model -def _unparam_lora_weight(model: torch.nn.Module): - """unparam lora weight. - - We don't want to move weight of lora to gpu. - """ - from peft.tuners.lora import Linear as LoRALinear - - def _tensorize_weight(linear): - """tensorize weight.""" - w = linear.weight - del linear.weight - linear.weight = w.data - - for _, mod in model.named_modules(): - if isinstance(mod, LoRALinear): - lora_A = mod.lora_A - lora_B = mod.lora_B - for linear in lora_A.values(): - _tensorize_weight(linear) - for linear in lora_B.values(): - _tensorize_weight(linear) - - SwapMap = Dict[int, int] @@ -364,9 +343,6 @@ def _build_model(self, patched_model = update_model(hf_model) - if adapters: - _unparam_lora_weight(patched_model) - return patched_model def get_loralinear_info(self): @@ -387,9 +363,18 @@ def paging_adapters(self, weight_maps: List[AdapterWeightMap]): cpu_caches = [(kcache.view(num_blocks, -1), vcache.view(num_blocks, -1)) for kcache, vcache in cpu_caches] + gpu_caches = self.cache_engine.gpu_cache + num_gpu_blocks = self.cache_engine.num_gpu_blocks + gpu_caches = [(kcache.view(num_gpu_blocks, + -1), vcache.view(num_gpu_blocks, -1)) + for kcache, vcache in gpu_caches] for weight_map in weight_maps: weight_map.cache_adapter(lora_linears, cpu_caches) - update_lora_linears(lora_linears, weight_maps, device='cuda') + rank_offsets = update_lora_linears(lora_linears, + weight_maps, + gpu_caches, + device='cuda') + self.patched_model.rank_offsets = rank_offsets def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): @@ -601,9 +586,18 @@ def __paging(weight_maps): cpu_caches = [(kcache.view(num_blocks, -1), vcache.view(num_blocks, -1)) for kcache, vcache in cpu_caches] + gpu_caches = cache_engine.gpu_cache + num_gpu_blocks = cache_engine.num_gpu_blocks + gpu_caches = [(kcache.view(num_gpu_blocks, + -1), vcache.view(num_gpu_blocks, -1)) + for kcache, vcache in gpu_caches] for weight_map in weight_maps: weight_map.cache_adapter(lora_linears, cpu_caches) - update_lora_linears(lora_linears, weight_maps, device='cuda') + rank_offsets = update_lora_linears(lora_linears, + weight_maps, + gpu_caches, + device='cuda') + patched_model.rank_offsets = rank_offsets weight_maps = __get_weight_map() diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/layers/linear.py index c7a35c4a4..093ead71d 100644 --- a/lmdeploy/pytorch/layers/linear.py +++ b/lmdeploy/pytorch/layers/linear.py @@ -68,6 +68,25 @@ def __init__(self, self.layer_idx = None self.is_tp = is_tp + def post_init( + self, + ranks: torch.Tensor, + scalings: torch.Tensor, + rank_offsets: torch.Tensor, + a_cache: torch.Tensor, + b_cache: torch.Tensor, + max_rank: int, + ): + """post init.""" + self.impl.post_init( + ranks, + scalings, + rank_offsets, + a_cache, + b_cache, + max_rank, + ) + def forward(self, x, base_output=None): """forward of loraA@loraB.""" return self.impl.forward(x, base_output, self.target_name, @@ -99,15 +118,15 @@ def __init__(self, def forward(self, x): """w4a16 forward.""" + is_tp = False if self.colwise else self.is_tp if self.lora_adapters is None: - is_tp = False if self.colwise else self.is_tp return self.impl.forward(x, is_tp) out = self.impl.forward(x, False) if self.lora_adapters is not None: for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) - if self.is_tp: + if is_tp: dist.all_reduce(out) return out diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 63829f790..dcd805412 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -12,57 +12,27 @@ @dataclass class AdapterInfo: - ranks: torch.LongTensor - scalings: torch.Tensor + adapter_ids: torch.LongTensor rank_offsets: torch.LongTensor - target_modules: List[str] - max_rank_per_target: List[int] - max_rank: int @classmethod def from_adapters(cls, adapters: List[SchedulerAdapter]): """from adapters.""" if len(adapters) == 0: return None - target_modules = adapters[0].target_modules - max_rank = adapters[0].max_rank - ranks = [ada.rank for ada in adapters] - scalings = [ada.scaling for ada in adapters] + adapter_ids = [ada.adapter_id for ada in adapters] + adapter_ids = torch.tensor(adapter_ids) rank_offsets = [torch.from_numpy(ada.rank_offset) for ada in adapters] - ranks = torch.tensor(ranks) - scalings = torch.tensor(scalings) rank_offsets = torch.stack(rank_offsets) - max_rank_per_target = ranks.max(0)[0].tolist() return cls( - ranks=ranks, - scalings=scalings, + adapter_ids=adapter_ids, rank_offsets=rank_offsets, - target_modules=target_modules, - max_rank=max_rank, - max_rank_per_target=max_rank_per_target, ) - def split_by_targets(self): - """split by targets.""" - ret = dict() - max_rank = self.max_rank - for idx, target in enumerate(self.target_modules): - r = self.ranks[:, idx] - scaling = self.scalings[:, idx] - r_off_start = idx * max_rank - r_off_end = r_off_start + max_rank - rank_offset = self.rank_offsets[:, r_off_start:r_off_end] - max_rank_per_target = [self.max_rank_per_target[idx]] - ret[target] = AdapterInfo( - r, - scaling, - rank_offset, - target_modules=[target], - max_rank=max_rank_per_target[0], - max_rank_per_target=max_rank_per_target, - ) - return ret + def update_offsets(self, rank_offsets: torch.LongTensor): + """update rank offsets.""" + rank_offsets[self.adapter_ids] = self.rank_offsets def to_device(self, device: str): """to device.""" @@ -285,10 +255,6 @@ def new( kv_seqlens = q_seqlens + history_seqlens kv_seqlens -= inputs.num_ignored_history - adapter_params = None - if inputs.adapter_info is not None: - adapter_params = inputs.adapter_info.split_by_targets() - ret = StepContext( input_ids=inputs.input_ids, block_offsets=inputs.block_offsets, @@ -303,7 +269,6 @@ def new( is_decoding=inputs.is_decoding, world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, - adapter_params=adapter_params, ) ret = get_backend().update_step_context(ret) From 449f9474866da9d7df9d732cd83ee41184c787bd Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 1 Aug 2024 14:31:28 +0800 Subject: [PATCH 20/89] support logit softcapping --- lmdeploy/pytorch/backends/attention.py | 4 ++-- lmdeploy/pytorch/backends/cuda/attention.py | 4 +++- lmdeploy/pytorch/layers/attention.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index d7adf9a09..e07bed5b9 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -31,7 +31,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, - logical_softcapping: float = None, + logit_softcapping: float = None, **kwargs, ) -> None: if scale is None: @@ -50,7 +50,7 @@ def __init__( self.v_head_size = v_head_size self.alibi_scale = alibi_scale self.sliding_window = sliding_window - self.logical_softcapping = logical_softcapping + self.logit_softcapping = logit_softcapping @abstractmethod def forward( diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 459e5459f..75f2d3b05 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -21,7 +21,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, - logical_softcapping: float = None, + logit_softcapping: float = None, **kwargs, ): super().__init__( @@ -32,6 +32,7 @@ def __init__( v_head_size, alibi_scale, sliding_window, + logit_softcapping, **kwargs, ) @@ -90,6 +91,7 @@ def forward( max_seqlen=max_q_seqlen, window_size=self.sliding_window, sm_scale=self.scale, + logit_softcapping=self.logit_softcapping, ) return attn_output diff --git a/lmdeploy/pytorch/layers/attention.py b/lmdeploy/pytorch/layers/attention.py index e14f5ac8f..a7a7326bd 100644 --- a/lmdeploy/pytorch/layers/attention.py +++ b/lmdeploy/pytorch/layers/attention.py @@ -18,7 +18,7 @@ def __init__( v_head_size: int = None, alibi_scale: float = None, sliding_window: int = None, - logical_softcapping: float = None, + logit_softcapping: float = None, **kwargs, ): super().__init__() @@ -34,7 +34,7 @@ def __init__( v_head_size, alibi_scale, sliding_window, - logical_softcapping, + logit_softcapping, **kwargs, ) From 0e16e697bf092e22b30b79e340e5531fcc825a9a Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 1 Aug 2024 18:25:41 +0800 Subject: [PATCH 21/89] support transformers 4.43 --- .../pytorch/kernels/cuda/apply_rotary_pos_emb.py | 2 +- lmdeploy/pytorch/models/llama.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index 0330aa94f..e8141a45f 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -34,7 +34,7 @@ BLOCK_KH=torch.int32, BLOCK_N=torch.int32, )) -@triton.jit +@triton.jit(do_not_specialize=('seq_len', )) def apply_rotary_pos_emb_qk_kernel( Q, K, diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index b113c8b1b..484292dc4 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -249,10 +249,18 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': emb_type = EmbeddingType.DynamicNTKScaling scaling_factor = getattr(rotary_emb, 'scaling_factor', 1.0) + config = origin.config + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + scaling_factor = 1.0 + if config.rope_scaling is not None: + scaling_factor = config.rope_scaling.get('scaling_factor', + scaling_factor) self.rotary_emb = build_rotary_embedding( - rotary_emb.dim, - rotary_emb.max_position_embeddings, - rotary_emb.base, + rope_dim, + rope_max_pos_emb, + rope_base, scaling_factor, emb_type, ) From e6a3048b236c51c16e3c3042df89a3c6451ec6cb Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 2 Aug 2024 14:36:03 +0800 Subject: [PATCH 22/89] fix ut --- tests/pytorch/paging/test_block_trie.py | 3 ++- tests/pytorch/paging/test_scheduler.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 557ee27d6..06829f4c7 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -23,7 +23,8 @@ def num_gpu_blocks(self): @pytest.fixture def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(block_size=block_size, + yield CacheConfig(max_batches=256, + block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks, enable_prefix_caching=True) diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index adc77775e..2e7c1e1a0 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -22,7 +22,8 @@ def num_gpu_blocks(self): @pytest.fixture def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(block_size=block_size, + yield CacheConfig(max_batches=256, + block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks) From 93d374683dbdf63bf06bf8a1c98da089605e955c Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 5 Aug 2024 11:37:15 +0800 Subject: [PATCH 23/89] fix dynamic ntk cudagraph --- .../backends/default/rotary_embedding.py | 85 ++++++++++++------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index eb8a52f11..dd4278ef8 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -6,6 +6,28 @@ RotaryEmbeddingImpl) +def _rotary_embedding_fwd(position_ids: torch.Tensor, inv_freq: torch.Tensor, + scaling_factor: float, dtype: torch.dtype, + device_type: torch.device): + """rotary embedding forward.""" + position_ids = position_ids.float() / scaling_factor + inv_freq_expanded = inv_freq[None, :, + None].float().expand(position_ids.shape[0], + -1, 1) + position_ids_expanded = position_ids[:, None, :] + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = device_type if isinstance( + device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) + emb = freqs.repeat(1, 1, 2) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + class RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module): """base rotary embedding.""" @@ -21,27 +43,17 @@ def __init__(self, 0, self.dim, 2, dtype=torch.int64).float() / self.dim)) self.register_buffer('inv_freq', inv_freq, persistent=False) - def forward(self, x, position_ids): + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): """forward.""" - # x: [bs, num_attention_heads, seq_len, head_size] + device_type = x.device.type + dtype = x.dtype if self.inv_freq.device != x.device: self.inv_freq = self.inv_freq.to(x.device) - position_ids = position_ids.float() / self.scaling_factor - inv_freq_expanded = self.inv_freq[None, :, None].float().expand( - position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :] - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance( - device_type, str) and device_type != 'mps' else 'cpu' - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() - @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return _rotary_embedding_fwd(position_ids, + self.inv_freq, + scaling_factor=self.scaling_factor, + dtype=dtype, + device_type=device_type) class LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl): @@ -58,20 +70,33 @@ def __init__(self, super().__init__(dim, base, scaling_factor) self.max_position_embeddings = max_position_embeddings - def forward(self, x, position_ids): + def _ntk_inv_freq(self, seq_len: torch.Tensor): + """ntk_inv_freq.""" + device = seq_len.device + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base**(torch.arange( + 0, self.dim, 2, dtype=torch.int64, device=device).float() / + self.dim)) + return inv_freq + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): """forward.""" + device_type = x.device.type + dtype = x.dtype seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / - self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / - (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange( - 0, self.dim, 2, dtype=torch.int64).float().to(x.device) / - self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - cos, sin = super().forward(x, position_ids) + ntk_inv_freq = self._ntk_inv_freq(seq_len) + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + inv_freq = torch.where(seq_len > self.max_position_embeddings, + ntk_inv_freq, self.inv_freq) + + cos, sin = _rotary_embedding_fwd(position_ids, + inv_freq, + scaling_factor=self.scaling_factor, + dtype=dtype, + device_type=device_type) return cos, sin From 2dfcc6fe4684b758c946b56724027ee7747180bc Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 5 Aug 2024 17:12:07 +0800 Subject: [PATCH 24/89] add moe support --- lmdeploy/pytorch/backends/base.py | 2 + .../pytorch/backends/cuda/layer_backend.py | 3 + lmdeploy/pytorch/backends/cuda/moe.py | 94 +++ .../pytorch/backends/default/layer_backend.py | 3 + lmdeploy/pytorch/backends/default/moe.py | 31 + lmdeploy/pytorch/backends/moe.py | 49 ++ lmdeploy/pytorch/layers/moe.py | 38 ++ lmdeploy/pytorch/models/llama.py | 3 +- lmdeploy/pytorch/models/mixtral.py | 569 +++++++++--------- lmdeploy/pytorch/models/module_map.py | 18 +- 10 files changed, 513 insertions(+), 297 deletions(-) create mode 100644 lmdeploy/pytorch/backends/cuda/moe.py create mode 100644 lmdeploy/pytorch/backends/default/moe.py create mode 100644 lmdeploy/pytorch/backends/moe.py create mode 100644 lmdeploy/pytorch/layers/moe.py diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index 66b27821d..9e813782d 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -23,6 +23,8 @@ class LayerType(Enum): RMSNormW8A8 = auto() MultinomialSampling = auto() LinearW4A16 = auto() + SoftmaxTopK = auto() + FusedMoE = auto() class LayersBackend(ABC): diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/layer_backend.py index f7cbe5fd8..c14b9ffa1 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/layer_backend.py @@ -53,6 +53,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): logger.debug( f'Op {layer_type} fallback to default implementation.') return super().get_layer_impl_builder(layer_type) + elif layer_type == LayerType.FusedMoE: + from .moe import TritonFusedMoEBuilder + return TritonFusedMoEBuilder else: logger.debug( f'Op {layer_type} fallback to default implementation.') diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py new file mode 100644 index 000000000..e0b7771da --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch + +from lmdeploy.pytorch.kernels.cuda import fused_moe + +from ..moe import FusedMoEBuilder, FusedMoEImpl + + +class TritonFusedMoEImpl(FusedMoEImpl): + """triton fused moe implementation.""" + + def __init__(self, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + top_k: int, + renormalize: bool = False): + super().__init__() + self.top_k = top_k + self.renormalize = renormalize + self.register_buffer('gate_up_weights', gate_up_weights) + self.register_buffer('down_weights', down_weights) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + """forward.""" + return fused_moe(hidden_states, + self.gate_up_weights, + self.down_weights, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + renormalize=self.renormalize) + + +def _merge_mlp(gates: List[torch.Tensor], ups: List[torch.Tensor], + downs: List[torch.Tensor]): + """merge experts.""" + num_experts = len(gates) + + def __get_meta(): + gate = gates[0] + down = downs[0] + ffn_dim = gate.weight.size(0) + hidden_dim = down.weight.size(0) + dtype = gate.weight.dtype + device = gate.weight.device + return ffn_dim, hidden_dim, dtype, device + + def __copy_assign_param(param, weight): + """copy assign.""" + weight.copy_(param.data) + param.data = weight + + ffn_dim, hidden_dim, dtype, device = __get_meta() + + gate_up_weights = torch.empty(num_experts, + ffn_dim * 2, + hidden_dim, + device=device, + dtype=dtype) + down_weights = torch.empty(num_experts, + hidden_dim, + ffn_dim, + device=device, + dtype=dtype) + for exp_id in range(num_experts): + gate = gates[exp_id] + up = ups[exp_id] + down = downs[exp_id] + __copy_assign_param(gate.weight, gate_up_weights[exp_id, :ffn_dim]) + __copy_assign_param(up.weight, gate_up_weights[exp_id, ffn_dim:]) + __copy_assign_param(down.weight, down_weights[exp_id]) + + torch.cuda.empty_cache() + return gate_up_weights, down_weights + + +class TritonFusedMoEBuilder(FusedMoEBuilder): + """triton fused moe builder.""" + + @staticmethod + def build_from_mlp(gates: List[torch.Tensor], + ups: List[torch.Tensor], + downs: List[torch.Tensor], + top_k: int, + renormalize: bool = False): + """build from mlp.""" + gate_up_weights, down_weights = _merge_mlp(gates, ups, downs) + return TritonFusedMoEImpl(gate_up_weights, + down_weights, + top_k=top_k, + renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index 8b445db27..587c2be32 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -36,6 +36,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.LinearW4A16: from .awq_modules import DefaultLinearW4A16Builder return DefaultLinearW4A16Builder + elif layer_type == LayerType.SoftmaxTopK: + from .moe import DefaultSoftmaxTopKBuilder + return DefaultSoftmaxTopKBuilder else: raise RuntimeError(f'{layer_type} not supported.') diff --git a/lmdeploy/pytorch/backends/default/moe.py b/lmdeploy/pytorch/backends/default/moe.py new file mode 100644 index 000000000..d67626006 --- /dev/null +++ b/lmdeploy/pytorch/backends/default/moe.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from ..moe import SoftmaxTopKBuilder, SoftmaxTopKImpl + + +class DefaultSoftmaxTopKImpl(SoftmaxTopKImpl, nn.Module): + """RMS norm implementation api.""" + + def __init__(self, top_k: int, dim: int = -1): + super().__init__() + self.top_k = top_k + self.dim = dim + + def forward(self, x: torch.Tensor): + """forward.""" + routing_weights = torch.softmax(x, dim=self.dim, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, + self.top_k, + dim=self.dim) + return topk_weights, topk_ids + + +class DefaultSoftmaxTopKBuilder(SoftmaxTopKBuilder): + """RMS norm implementation builder.""" + + @staticmethod + def build(top_k: int, dim: int = -1): + """build.""" + return DefaultSoftmaxTopKImpl(top_k, dim) diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py new file mode 100644 index 000000000..269a3810d --- /dev/null +++ b/lmdeploy/pytorch/backends/moe.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import List + +import torch +from torch import nn + + +class SoftmaxTopKImpl(ABC): + """Softmax topk implementation api.""" + + @abstractmethod + def forward(self, x: torch.Tensor): + """forward.""" + raise NotImplementedError + + +class SoftmaxTopKBuilder(ABC): + """Softmax topk implementation builder.""" + + @staticmethod + @abstractmethod + def build(top_k: int, dim: int = -1): + """build.""" + raise NotImplementedError + + +class FusedMoEImpl(ABC, nn.Module): + """fused moe implementation.""" + + @abstractmethod + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + """forward.""" + raise NotImplementedError + + +class FusedMoEBuilder(ABC): + """fused moe builder.""" + + @staticmethod + @abstractmethod + def build_from_mlp(gates: List[torch.Tensor], + ups: List[torch.Tensor], + downs: List[torch.Tensor], + top_k: int, + renormalize: bool = False): + """build from mlp.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/layers/moe.py b/lmdeploy/pytorch/layers/moe.py new file mode 100644 index 000000000..dfdf49d4d --- /dev/null +++ b/lmdeploy/pytorch/layers/moe.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from torch import nn + +from ..backends import LayerType, get_backend + + +class SoftmaxTopK(nn.Module): + """softmax topk.""" + + def __init__(self, top_k: int, dim: int = -1): + super().__init__() + self.top_k = top_k + impl_builder = get_backend().get_layer_impl_builder( + LayerType.SoftmaxTopK) + self.impl = impl_builder.build(top_k, dim) + + def forward(self, x: torch.Tensor): + """forward.""" + return self.impl.forward(x) + + +def build_moe_from_mlp( + gates: List[nn.Linear], + ups: List[nn.Linear], + downs: List[nn.Linear], + top_k: int, + renormalize: bool = False, +): + """build moe from mlp.""" + impl_builder = get_backend().get_layer_impl_builder(LayerType.FusedMoE) + return impl_builder.build_from_mlp(gates, + ups, + downs, + top_k=top_k, + renormalize=renormalize) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 484292dc4..cc33d5255 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -248,7 +248,6 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): emb_type = EmbeddingType.LinearScaling elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': emb_type = EmbeddingType.DynamicNTKScaling - scaling_factor = getattr(rotary_emb, 'scaling_factor', 1.0) config = origin.config rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings @@ -340,7 +339,7 @@ def get_input_embeddings(self): def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], - inputs_embeds: torch.Tensor = None, + inputs_embeds: Optional[torch.Tensor] = None, context: StepContext = None, ): """prepare input.""" diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index b1be9b81d..84807d543 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -1,162 +1,128 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..kernels.fused_moe import fused_moe +from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, build_rotary_embedding) +from lmdeploy.pytorch.layers.linear import (build_colwise_linear, + build_merged_colwise_linear, + build_rowwise_linear) +from lmdeploy.pytorch.layers.moe import SoftmaxTopK, build_moe_from_mlp +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager + from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) -class PatchedMixtralAttention(nn.Module): - """Rewrite module of MixtralAttention.""" +class MixtralAttention(nn.Module): + """mixtral attention.""" - def _load_weights(self, loader, rank: int, world_size: int, + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + is_tp = world_size > 1 + self.ctx_mgr = ctx_mgr + self.num_heads = origin.num_heads // world_size + self.num_kv_heads = origin.num_key_value_heads // world_size + self.head_dim = origin.head_dim + + # qkv + self.qkv_proj = build_merged_colwise_linear( + origin.q_proj, + origin.k_proj, + origin.v_proj, + ctx_mgr=ctx_mgr, + is_tp=is_tp, + ) + del origin.q_proj, origin.k_proj, origin.v_proj + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.window_size = origin.config.sliding_window or -1 + self.attn_fwd = Attention( + self.num_heads, + self.head_dim, + num_kv_heads=self.num_kv_heads, + v_head_size=self.head_dim, + sliding_window=self.window_size, + ) + + self.o_proj = build_rowwise_linear( + origin.o_proj, + ctx_mgr=ctx_mgr, + is_tp=is_tp, + ) + + @staticmethod + def _load_weights(mod, loader, rank: int, world_size: int, device: torch.device): """load weights.""" for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), + colwise_parallelize_linear(getattr(mod, mod_name), loader, rank=rank, world_size=world_size, prefix=mod_name) - rowwise_parallelize_linear(self.o_proj, + rowwise_parallelize_linear(mod.o_proj, loader, rank=rank, world_size=world_size, prefix='o_proj') - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - attention_mask: Optional[torch.Tensor] = None, - world_size: int = 1, + attn_metadata: Any = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """default rewrite.""" - - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - hidden_size = num_heads * self.head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - if hasattr(self, 'rotary_emb'): - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=query_states, - k_embed=key_states) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, self.head_dim) - key_states = key_states.view(-1, num_kv_heads, self.head_dim) - value_states = value_states.view(-1, num_kv_heads, self.head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - # fill kv cache - fill_kv_cache( + """Rewrite of LlamaAttention.forward.""" + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + qkv_states = qkv_states.unflatten(-1, (-1, self.head_dim)) + query_states, key_states, value_states = qkv_states.split( + ( + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + ), + dim=1, + ) + + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - # page attention - attn_output = query_states - window_size = self.config.sliding_window or -1 - paged_attention_fwd( + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - window_size=window_size, + attn_metadata, + inplace=True, ) - - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of MistralAttention.forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - attention_mask=attention_mask, - world_size=world_size, - ) + return attn_output -class PatchedMixtralBLockSparseTop2MLP(nn.Module): +class MixtralBLockSparseTop2MLP(nn.Module): def _load_weights(self, loader, rank: int, world_size: int, device: torch.device): @@ -173,198 +139,233 @@ def _load_weights(self, loader, rank: int, world_size: int, world_size=world_size, prefix='w2') - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs - - -class PatchedMixtralSparseMoeBlock(nn.Module): - - def _update_model_fn(self): - """update model.""" - num_experts = self.num_experts - - def __get_meta(): - exp = self.experts[0] - ffn_dim = exp.w1.weight.size(0) - hidden_dim = exp.w2.weight.size(0) - dtype = exp.w1.weight.dtype - device = exp.w1.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - for exp_id, exp in enumerate(self.experts): - __copy_assign_param(exp.w1.weight, - gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(exp.w3.weight, gate_up_weights[exp_id, - ffn_dim:]) - __copy_assign_param(exp.w2.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """rewrite moe forward.""" +class MixtralSparseMoeBlock(nn.Module): + """mixtral sparse moe block.""" + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + is_tp = world_size > 1 + self.is_tp = is_tp + self.top_k = origin.top_k + self.gate = build_colwise_linear( + origin.gate, + ctx_mgr=ctx_mgr, + is_tp=is_tp, + ) + self.softmax_topk = SoftmaxTopK(self.top_k) + + gates = [exp.w1 for exp in origin.experts] + ups = [exp.w3 for exp in origin.experts] + downs = [exp.w2 for exp in origin.experts] + self.fused_moe = build_moe_from_mlp(gates, + ups, + downs, + top_k=self.top_k, + renormalize=True) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights = torch.softmax(router_logits, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, - self.top_k, - dim=-1) - del routing_weights - out_states = fused_moe(hidden_states, - self.gate_up_weights, - self.down_weights, - topk_weights, - topk_ids, - topk=self.top_k, - renormalize=True) + topk_weights, topk_ids = self.softmax_topk(router_logits) + out_states = self.fused_moe( + hidden_states, + topk_weights, + topk_ids, + ) out_states = out_states.reshape(batch_size, sequence_length, -1) + + if self.is_tp: + dist.all_reduce(out_states) return out_states, router_logits -class PatchedMixtralModel(nn.Module): +class MixtralDecoderLayer(nn.Module): + """mixtral decoder layer.""" + + def __init__(self, origin: nn.Module, layer_idx: int, + ctx_mgr: StepContextManager): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = MixtralAttention(origin.self_attn, ctx_mgr) + self.block_sparse_moe = MixtralSparseMoeBlock(origin.block_sparse_moe, + ctx_mgr) + + # norm + input_layernorm = origin.input_layernorm + is_w8a8 = hasattr(input_layernorm, 'from_float') + self.input_layernorm = RMSNorm( + input_layernorm.weight, + input_layernorm.variance_epsilon, + is_w8a8=is_w8a8, + ) + post_attention_layernorm = origin.post_attention_layernorm + is_w8a8 = hasattr(post_attention_layernorm, 'from_float') + self.post_attention_layernorm = RMSNorm( + post_attention_layernorm.weight, + post_attention_layernorm.variance_epsilon, + is_w8a8=is_w8a8, + ) - def _continuous_batching_forward( + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states, _ = self.block_sparse_moe(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class MixtralModel(nn.Module): + """mixtral model.""" + + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + self.ctx_mgr = ctx_mgr + self.embed_tokens = origin.embed_tokens + self.layers = nn.ModuleList([ + MixtralDecoderLayer(layer, idx, ctx_mgr) + for idx, layer in enumerate(origin.layers) + ]) + norm = origin.norm + is_w8a8 = hasattr(norm, 'from_float') + self.norm = RMSNorm(norm.weight, + norm.variance_epsilon, + is_w8a8=is_w8a8) + + emb_type = EmbeddingType.LinearScaling + config = origin.config + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + scaling_factor = 1.0 + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type, + ) + + def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" + """Rewrite of LlamaModel.forward.""" + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - from transformers.modeling_outputs import MoeModelOutputWithPast + hidden_states = inputs_embeds + residual = None + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + for idx, decoder_layer in enumerate(self.layers): - output_attentions = (output_attentions if output_attentions is not None - else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) - if use_cache is None: - use_cache = self.config.use_cache + hidden_states, _ = self.norm(hidden_states, residual) - return_dict = (return_dict if return_dict is not None else - self.config.use_return_dict) + return hidden_states - assert ( - position_ids is not None - ), 'position_ids can not be none when using continuous batching mode.' - assert position_ids.dim() == 2 + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # Attention mask is not necessary in continuous batching - attention_mask = None +class MixtralForCausalLM(nn.Module): + """mixture model for causalLM.""" - hidden_states = inputs_embeds + support_cuda_graph = True - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + super().__init__() + self.ctx_mgr = ctx_mgr + self.model = MixtralModel(origin.model, ctx_mgr) + self.lm_head = build_rowwise_linear(origin.lm_head) - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states, ) + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) - past_key_value = (past_key_values[idx] - if past_key_values is not None else None) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[2 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v for v in - [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) - - return MoeModelOutputWithPast(last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits='') - - def forward(self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index fb3cde714..09c5b6369 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -225,21 +225,17 @@ # mixtral MODULE_MAP.update({ 'transformers.models.mixtral.modeling_mixtral.MixtralAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', 'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', 'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', - 'transformers.models.mixtral.modeling_mixtral.MixtralModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralModel', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', 'transformers.models.mixtral.modeling_mixtral.MixtralBLockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralBLockSparseTop2MLP', 'transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP', - 'transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', - 'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlock', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralBLockSparseTop2MLP', + 'transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM', }) # dbrx From 93c64ee06bb57053bbd23799ca5b3c8bcf83f80c Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 6 Aug 2024 15:40:16 +0800 Subject: [PATCH 25/89] add custom module support --- lmdeploy/messages.py | 1 + lmdeploy/pytorch/config.py | 1 + lmdeploy/pytorch/engine/engine.py | 7 ++-- lmdeploy/pytorch/engine/model_agent.py | 42 +++++++---------------- lmdeploy/pytorch/models/module_map.py | 2 ++ lmdeploy/pytorch/models/patch.py | 46 +++++++++++++++++++++++++- 6 files changed, 65 insertions(+), 34 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index a4054449d..2fdcc35dd 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -213,6 +213,7 @@ class PytorchEngineConfig: enable_prefix_caching: bool = False device_type: str = 'cuda' eager_mode: bool = False + custom_module_map: str = None download_dir: str = None revision: str = None diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index d0d7ea368..652d93864 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -102,6 +102,7 @@ class ModelConfig: unused_modules: List[str] = None auto_model_cls: Any = AutoModelForCausalLM cogvlm_style: bool = False + custom_module_map: str = None def get_head_size(self): """get head size.""" diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 78c81f471..7889a4bc0 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -20,7 +20,7 @@ from ..model_inputs import AdapterInfo, ModelInputs, VisionModelInputs from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs -from .model_agent import AutoModelAgent +from .model_agent import AutoModelAgent, build_model_agent from .request import Request, RequestManager, RequestType, Response logger = get_logger('lmdeploy') @@ -147,13 +147,14 @@ def __init__(self, ) with get_device_manager().context(self.device_context): - self.model_agent = AutoModelAgent.from_pretrained( + self.model_agent = build_model_agent( model_path, cache_config=cache_config, backend_config=backend_config, trust_remote_code=trust_remote_code, adapters=adapters, - tp=tp) + tp=tp, + custom_module_map=engine_config.custom_module_map) cache_config = self.model_agent.cache_config self.adapter_manager = self._build_adapter_manager(adapters) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 41a0f47c2..a18a72483 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -19,7 +19,7 @@ from ..config import BackendConfig, CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager from ..model_inputs import ModelInputs -from ..models.patch import patch, update_model +from ..models.patch import patch, update_custom_module_map, update_model from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine @@ -257,22 +257,6 @@ def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, """ raise NotImplementedError('Not implemented.') - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path: str, - cache_config: CacheConfig, - backend_config: BackendConfig, - trust_remote_code: bool, - adapters: Dict[str, str] = None, - tp: int = 1): - """from pretrained.""" - return build_model_agent(pretrained_model_name_or_path, - cache_config=cache_config, - backend_config=backend_config, - trust_remote_code=trust_remote_code, - adapters=adapters, - tp=tp) - class BaseModelAgent(AutoModelAgent): """Base model agent. @@ -341,6 +325,9 @@ def _build_model(self, if adapters: _load_adapters(hf_model, adapters) + custom_module_map = self.model_config.custom_module_map + if custom_module_map is not None: + update_custom_module_map(custom_module_map) patched_model = update_model(hf_model) return patched_model @@ -422,16 +409,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, return output -def _get_model_memory_usage(model: torch.nn.Module) -> int: - """get model memory usage.""" - size = 0 - for _, param in model.named_parameters(): - size += param.element_size() * param.numel() - for _, buf in model.named_buffers(): - size += buf.element_size() * param.numel() - return size - - def _create_device_map(model: torch.nn.Module, world_size: int, device_map: dict = None): @@ -460,7 +437,7 @@ def _tp_build_model( backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, - trust_remote_code=True, + trust_remote_code: bool = True, ): """build tensor parallel model.""" from accelerate import init_empty_weights @@ -513,6 +490,9 @@ def _broadcast_config(cache_config): model.eval() model.config.use_cache = True + custom_module_map = model_config.custom_module_map + if custom_module_map is not None: + update_custom_module_map(custom_module_map) patched_model = patch(model) load_model_weights(patched_model, model_path, @@ -615,7 +595,7 @@ def _tp_model_loop( backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, - trust_remote_code=True, + trust_remote_code: bool = True, ): """Start model loops for tensor parallel model inference. @@ -960,10 +940,12 @@ def build_model_agent(model_path: str, backend_config: BackendConfig, trust_remote_code: bool, adapters: Dict[str, str] = None, - tp: int = 1): + tp: int = 1, + custom_module_map: str = None): """create model agent.""" model_config = ModelConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code) + model_config.custom_module_map = custom_module_map if tp == 1: model_agent = BaseModelAgent(model_path, model_config=model_config, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 09c5b6369..ef1f85420 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -383,3 +383,5 @@ 'modeling_internlm2.InternLM2FlashAttention2': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', }) + +CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 08f3198a9..ed0284c66 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import importlib import inspect +import os.path as osp import re +import sys from copy import copy from typing import Any, Dict @@ -10,7 +12,8 @@ from lmdeploy.utils import get_logger from ..devices import get_device_manager -from .module_map import DEVICE_SPECIAL_MODULE_MAP, MODULE_MAP +from .module_map import (CUSTOM_MODULE_MAP, DEVICE_SPECIAL_MODULE_MAP, + MODULE_MAP) logger = get_logger('lmdeploy') @@ -157,6 +160,8 @@ def _get_module_map(): if device_type != 'cuda': device_map = DEVICE_SPECIAL_MODULE_MAP.get(device_type, dict()) module_map.update(device_map) + # add custom module map + module_map.update(CUSTOM_MODULE_MAP) return module_map @@ -191,3 +196,42 @@ def update_model(model: torch.nn.Module): model_cls = _class_from_qualname(rewrite_qualname) return model_cls(model, ctx_mgr) + + +def update_custom_module_map(module_map_path: str): + """moad custom module map from file.""" + from importlib.machinery import SourceFileLoader + + from lmdeploy.pytorch.models.module_map import LMDEPLOY_PYTORCH_MODEL_PATH + assert osp.exists(module_map_path), ( + f'custom module map path: "{module_map_path}" not exists.') + + module_map_path = osp.abspath(module_map_path) + folder = osp.split(module_map_path)[0] + sys.path.append(folder) + custom_mod = SourceFileLoader('map_mod', module_map_path).load_module() + sys.modules[f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod'] = custom_mod + + new_mod_map = dict() + has_map = False + if hasattr(custom_mod, 'MODULE_MAP'): + has_map = True + mod_map = custom_mod.MODULE_MAP + assert isinstance(mod_map, Dict) + new_mod_map.update(mod_map) + + if hasattr(custom_mod, 'CUSTOM_MODULE_MAP'): + has_map = True + mod_map = custom_mod.CUSTOM_MODULE_MAP + assert isinstance(mod_map, Dict) + new_mod_map.update(mod_map) + + if not has_map: + raise RuntimeError(f'Found no map in "{module_map_path}".') + + for k, v in new_mod_map.items(): + if '.' not in v: + v = f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod.{v}' + new_mod_map[k] = v + + CUSTOM_MODULE_MAP.update(new_mod_map) From 362263578d754b806806f3fba407d490864c87dc Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 14 Aug 2024 20:55:22 +0800 Subject: [PATCH 26/89] optimize awq kernel --- lmdeploy/pytorch/backends/cuda/awq_modules.py | 10 +- lmdeploy/pytorch/kernels/cuda/awq_kernels.py | 454 ++++++++++++++---- lmdeploy/pytorch/kernels/cuda/fused_moe.py | 17 +- 3 files changed, 373 insertions(+), 108 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index f46dce6e9..3198590a2 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -21,20 +21,18 @@ def wq_gemm_forward( """wq gemm forward.""" from awq.modules.linear.gemm import awq_ext - from lmdeploy.pytorch.kernels.cuda.awq_kernels import \ - awq_dequantize_weights + from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear out_shape = x.shape[:-1] + (out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: x = x.half() - FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 1024 + FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 64 + x = x.flatten(0, -2) if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_dequantize_weights(qweight, scales, qzeros) - out = torch.matmul(x, out) + out = awq_linear(x, qweight, scales, qzeros) else: - x = x.flatten(0, -2) if not x.is_contiguous(): x = x.contiguous() out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8) diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py index 9475e93c6..13b9841e9 100644 --- a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py @@ -1,120 +1,384 @@ # Copyright (c) OpenMMLab. All rights reserved. import triton -import triton.language as tl -from torch import Tensor +from triton import language as tl from .triton_utils import get_kernel_meta, wrap_jit_func +def get_cuda_autotune_config(): + return [ + # most used + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + # # other + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 256, + # 'BLOCK_SIZE_K': 64, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=3, + # num_warps=8), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 64, + # 'BLOCK_SIZE_N': 256, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 128, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 64, + # 'BLOCK_SIZE_N': 128, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 32, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 64, + # 'BLOCK_SIZE_N': 32, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=5, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 32, + # 'BLOCK_SIZE_N': 64, + # 'BLOCK_SIZE_K': 32, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=5, + # num_warps=2), + # # Good config for fp8 inputs. + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 256, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=3, + # num_warps=8), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 256, + # 'BLOCK_SIZE_N': 128, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=3, + # num_warps=8), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 256, + # 'BLOCK_SIZE_N': 64, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 64, + # 'BLOCK_SIZE_N': 256, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 128, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 64, + # 'BLOCK_SIZE_K': 64, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + # triton.Config( + # { + # 'BLOCK_SIZE_M': 128, + # 'BLOCK_SIZE_N': 32, + # 'BLOCK_SIZE_K': 64, + # 'GROUP_SIZE_M': 8 + # }, + # num_stages=4, + # num_warps=4), + ] + + @triton.jit -def _get_unpacked_order(offs_n, elem_per_int): +def _get_unpacked_order(offs_n, elem_per_int: tl.constexpr): """get unpacked order.""" origin_order = offs_n % elem_per_int - unpacked_order = (origin_order % 2) * 4 + origin_order // 2 + unpacked_order = (origin_order & 1) * 4 + origin_order // 2 return unpacked_order +@triton.jit +def _broadcast_pack(weight, width: tl.constexpr): + """broadcast pack.""" + broadcast_tmp = tl.arange(0, width) + BLOCK_SIZE_K: tl.constexpr = weight.shape[0] + BLOCK_SIZE_QN: tl.constexpr = weight.shape[1] + BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width + weight = tl.broadcast(weight[:, :, None], broadcast_tmp[None, None, :]) + weight = tl.reshape(weight, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + return weight + + @triton.jit def _unpack_weight(weight, order): """unpack weight.""" + weight = _broadcast_pack(weight, 8) weight = weight >> (order * 4) - weight = weight & 0xf + # cast to float16 + immLut = (0xf0 & 0xcc) | 0xaa + BOTTOM_MASK = 0xf + I4s_TO_F16s_MAGIC_NUM = 0x6400 + FP16_TOP_MAGIC_NUM = 0x6400 + weight = tl.inline_asm_elementwise( + """lop3.b32 $1, $1, $2, $3, $4; + sub.f16x2 $1, $1, $5; + mov.b32 {$0, _}, $1;""", + '=h, r, n, n, n, r', [ + weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, + FP16_TOP_MAGIC_NUM + ], + dtype=tl.float16, + is_pure=False, + pack=1) return weight +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['M_NEXT_P2', 'N', 'K'], +) @wrap_jit_func @triton.jit -def _dequantize_weights_kernel( - QWeight, - Scales, - Qzeros, - Out, - in_size: tl.constexpr, - out_size: tl.constexpr, - group_size: tl.constexpr, - elem_per_int: tl.constexpr, - stride_wk: tl.constexpr, - stride_wn: tl.constexpr, - stride_sk: tl.constexpr, - stride_sn: tl.constexpr, - stride_zk: tl.constexpr, - stride_zn: tl.constexpr, - stride_ok: tl.constexpr, - stride_on: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_N: tl.constexpr, +def awq_linear_kernel( + a_ptr, + qw_ptr, + s_ptr, + qz_ptr, + c_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + stride_am, + stride_ak: tl.constexpr, # + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, # + stride_sk: tl.constexpr, + stride_sn: tl.constexpr, # + stride_zk: tl.constexpr, + stride_zn: tl.constexpr, # + stride_cm, + stride_ck: tl.constexpr, + stride_cn: tl.constexpr, + # Meta-parameters + M_NEXT_P2: tl.constexpr, + Q_GROUP_SIZE: tl.constexpr, + SPLIT_K_ITERS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # ): - """dequantize weight kernel.""" - k_block_id = tl.program_id(0) - n_block_id = tl.program_id(1) - - offs_k = tl.arange(0, BLOCK_K) + k_block_id * BLOCK_K - offs_n = tl.arange(0, BLOCK_N) + n_block_id * BLOCK_N - offs_wk = offs_k - offs_wn = offs_n // elem_per_int - offs_sk = offs_k // group_size - offs_sn = offs_n - offs_zk = offs_sk - offs_zn = offs_wn - mask_k = offs_k < in_size - mask_n = offs_n < out_size - mask = mask_k[:, None] and mask_n[None, :] - unpacked_order = _get_unpacked_order(offs_n, elem_per_int) - - qw_ptr = QWeight + offs_wk[:, - None] * stride_wk + offs_wn[None, :] * stride_wn - s_ptr = Scales + offs_sk[:, - None] * stride_sk + offs_sn[None, :] * stride_sn - qz_ptr = Qzeros + offs_zk[:, - None] * stride_zk + offs_zn[None, :] * stride_zn - - out_dtype = Out.dtype.element_ty - qw = tl.load(qw_ptr, mask=mask) - s = tl.load(s_ptr, mask=mask).to(out_dtype) - qz = tl.load(qz_ptr, mask=mask) - - # unpack w, z - w = _unpack_weight(qw, unpacked_order) - z = _unpack_weight(qz, unpacked_order) - - o = (w - z).to(s.dtype) * s - o_ptr = Out + offs_k[:, None] * stride_ok + offs_n[None, :] * stride_on - tl.store(o_ptr, o, mask=mask) - - -def awq_dequantize_weights(qweight: Tensor, scales: Tensor, qzeros: Tensor): - """dequantize weights of awq.""" - in_size = qweight.size(0) - out_size = scales.size(1) - group_size = in_size // qzeros.size(0) - elem_per_int = out_size // qzeros.size(1) - - output = scales.new_empty(in_size, out_size, dtype=scales.dtype) - - BLOCK_K = 128 - BLOCK_N = 64 - kernel_meta = get_kernel_meta(qweight) - grid = (triton.cdiv(in_size, BLOCK_K), triton.cdiv(out_size, BLOCK_N)) - _dequantize_weights_kernel[grid]( + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + ELEM_PER_INT = 8 + if Q_GROUP_SIZE > BLOCK_SIZE_K: + GROUP_SIZE_K: tl.constexpr = BLOCK_SIZE_K + else: + GROUP_SIZE_K: tl.constexpr = Q_GROUP_SIZE + K_PER_GROUP: tl.constexpr = Q_GROUP_SIZE // GROUP_SIZE_K + + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + split_kid = tl.program_id(axis=1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8 + offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN) + offs_k = tl.arange(0, GROUP_SIZE_K) + unpacked_order = _get_unpacked_order(offs_bn, ELEM_PER_INT) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak) + qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk + + offs_wn[None, :] * stride_wn) + s_ptrs = s_ptr + offs_bn * stride_sn + qz_ptrs = qz_ptr + offs_wn * stride_zn + + # split k + NUM_K_BLOCKS = K // GROUP_SIZE_K + K_PER_SPLIT = tl.cdiv(NUM_K_BLOCKS, SPLIT_K_ITERS) + k_start = split_kid * K_PER_SPLIT + k_last = min(k_start + K_PER_SPLIT, NUM_K_BLOCKS) + a_ptrs += k_start * GROUP_SIZE_K * stride_ak + qw_ptrs += k_start * GROUP_SIZE_K * stride_wk + qg_id = k_start // K_PER_GROUP + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + s = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty) + zs = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty) + + # prefetch + next_qw = tl.load(qw_ptrs) + qw_ptrs += GROUP_SIZE_K * stride_wk + + for k in range(k_start, k_last): + a = tl.load(a_ptrs) + qw = next_qw + if k + 1 < k_last: + next_qw = tl.load(qw_ptrs) + w = _unpack_weight(qw, unpacked_order) + + if k == k_start or k % K_PER_GROUP == 0: + s = tl.load(s_ptrs + qg_id * stride_sk)[None, :] + qz = tl.load(qz_ptrs + qg_id * stride_zk)[None, :] + qg_id += 1 + z = _unpack_weight(qz, unpacked_order) + zs = -z * s + b = w * s + zs + + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + + # Advance the ptrs to the next K block. + a_ptrs += GROUP_SIZE_K * stride_ak + qw_ptrs += GROUP_SIZE_K * stride_wk + + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, + None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if stride_ck > 0: + c_ptrs += split_kid * stride_ck + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +def awq_linear(x, qweight, scales, qzeros): + """awq linear.""" + M = x.size(0) + K = qweight.size(0) + N = scales.size(1) + SPLIT_K_ITERS = 4 + group_size = K // scales.size(0) + + def grid(META): + """grid.""" + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), SPLIT_K_ITERS) + + out = scales.new_empty(M, SPLIT_K_ITERS, N) + M_NEXT_P2 = triton.next_power_of_2(M) + + kernel_meta = get_kernel_meta(x) + awq_linear_kernel[grid]( + # Pointers to matrices + x, qweight, scales, qzeros, - output, - in_size, - out_size, - group_size, - elem_per_int, - qweight.stride(0), - qweight.stride(1), - scales.stride(0), - scales.stride(1), - qzeros.stride(0), - qzeros.stride(1), - output.stride(0), - output.stride(1), - BLOCK_K=BLOCK_K, - BLOCK_N=BLOCK_N, - **kernel_meta, - ) - - return output + out, + # Matrix dimensions + M, + N, + K, + stride_am=x.stride(0), + stride_ak=x.stride(1), # + stride_wk=qweight.stride(0), + stride_wn=qweight.stride(1), # + stride_sk=scales.stride(0), + stride_sn=scales.stride(1), # + stride_zk=qzeros.stride(0), + stride_zn=qzeros.stride(1), # + stride_cm=out.stride(0), + stride_ck=out.stride(1), + stride_cn=out.stride(2), + # Meta-parameters + M_NEXT_P2=M_NEXT_P2, + Q_GROUP_SIZE=group_size, + SPLIT_K_ITERS=SPLIT_K_ITERS, + **kernel_meta) + + return out.sum(1) diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index 1e6d3ba96..37a4ddca9 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -79,7 +79,7 @@ def get_cuda_autotune_config(): @triton.autotune( configs=get_cuda_autotune_config(), - key=['N', 'K'], + key=['N', 'K', 'M_NP2'], ) @wrap_jit_func(type_hint=dict( A=torch.Tensor, @@ -130,6 +130,7 @@ def fused_moe_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + M_NP2: tl.constexpr, ENABLE_WEIGHTS: tl.constexpr, top_k: tl.constexpr, expert_offset: tl.constexpr, @@ -169,9 +170,11 @@ def fused_moe_kernel( offs_am = offs_sid a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) # deepseek has 160 experts, exp index would overflow int32 - exp_off = tl.full((1, ), stride_be, dtype=tl.int64) * exp_id + exp_off = stride_be * exp_id.to(tl.int64) b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) @@ -198,11 +201,9 @@ def fused_moe_kernel( if reindex_c: offs_cm = sid else: - offs_cm = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = mask_sid[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) + offs_cm = offs_sid + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :] + tl.store(c_ptrs, c, mask=mask_sid[:, None]) def fused_moe_kernel_launcher( @@ -224,6 +225,7 @@ def fused_moe_kernel_launcher( if num_tokens is None: num_tokens = A.size(0) + M_NP2 = triton.next_power_of_2(num_tokens) E, N, K = B.shape def _grid_fn(META): @@ -263,6 +265,7 @@ def _grid_fn(META): reindex_a=reindex_a, reindex_c=reindex_c, GROUP_SIZE_M=GROUP_SIZE_M, + M_NP2=M_NP2, **kernel_meta, ) From 973d22248a8f29404d10ac9dd14f3dbc8e430900 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 16 Aug 2024 19:57:03 +0800 Subject: [PATCH 27/89] optimize attention --- .../pytorch/backends/cuda/graph_runner.py | 2 + lmdeploy/pytorch/engine/engine.py | 2 +- .../pytorch/kernels/cuda/pagedattention.py | 488 +++++------------- 3 files changed, 124 insertions(+), 368 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 4f4f47b91..4c41115e0 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -123,7 +123,9 @@ def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, self.input_buffers['position_ids'][:, :num_tokens] = position_ids self.input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets + self.input_buffers['q_seqlens'].zero_() self.input_buffers['q_seqlens'][:batch_size] = q_seqlens + self.input_buffers['kv_seqlens'].zero_() self.input_buffers['kv_seqlens'][:batch_size] = kv_seqlens self.input_buffers['q_start_loc'][:batch_size] = q_start_loc if inputs_embeds is not None: diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 7889a4bc0..4982ec7a8 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -122,7 +122,7 @@ def __init__(self, max_batches=engine_config.max_batch_size, max_session_len=engine_config.session_len, eviction_type=engine_config.eviction_type, - prefill_interval=engine_config.prefill_interval) + prefill_interval=16) # block_size = 1 to enable unified paging adapters = engine_config.adapters diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 722eae6fc..16a83965d 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -17,204 +17,6 @@ assert TRITON_VERSION >= version.parse('2.1.0') -@triton.jit -def _load_block_offsets(offset_ptr, block_id, BLOCK: tl.constexpr): - """load block offsets.""" - offs_n = tl.arange(0, BLOCK) - return tl.load(offset_ptr + block_id) * BLOCK + offs_n - - -@triton.autotune(configs=[ - triton.Config({}, num_stages=1, num_warps=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=1, num_warps=4), -], - key=['BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV']) -@wrap_jit_func(type_hint=dict( - Q=torch.Tensor, - K=torch.Tensor, - V=torch.Tensor, - sm_scale=float, - KV_seqlens=torch.Tensor, - Block_offsets=torch.Tensor, - Acc_out=torch.Tensor, - stride_qbs=int, - stride_qh=int, - stride_qd=int, - stride_kbs=int, - stride_kh=int, - stride_kd=int, - stride_vbs=int, - stride_vh=int, - stride_vd=int, - stride_ok=int, - stride_obs=int, - stride_oh=int, - stride_od=int, - stride_boffb=int, - kv_group_num=torch.int32, - block_per_cta=torch.int32, - window_size=torch.int32, - head_size=torch.int32, - head_size_v=torch.int32, - shared_kv=bool, - BLOCK_DMODEL=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_N=torch.int32, -)) -@triton.jit -def _fwd_split_kernel( - Q, - K, - V, - sm_scale, - KV_seqlens, - Block_offsets, - Acc_out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_ok, - stride_obs, - stride_oh, - stride_od, - stride_boffb, - kv_group_num, - block_per_cta, - window_size: tl.constexpr, - head_size: tl.constexpr, - head_size_v: tl.constexpr, - shared_kv: tl.constexpr, - logit_softcapping: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DV: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """first step kernel of split k attention.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - split_k_id = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - q_seqlen = 1 - kv_seqlen = tl.load(KV_seqlens + cur_batch) - history_len = kv_seqlen - q_seqlen - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - mask_d = offs_d < head_size - offs_dv = tl.arange(0, BLOCK_DV) - mask_dv = offs_dv < head_size_v - off_q = (cur_batch * stride_qbs + cur_head * stride_qh + - offs_d * stride_qd) - off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd) - off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd) - - q = tl.load(Q + off_q, mask=mask_d, other=0).to(tl.float32) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_offset_ptrs = Block_offsets + cur_batch * stride_boffb - - # initialize pointer to m and l - m_i = -float('inf') - l_i = float(0) - acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - - kv_len_per_prog = block_per_cta * BLOCK_N - loop_start = kv_len_per_prog * split_k_id - loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen) - - # load block offset - # dirty - start_block_id = loop_start // BLOCK_N - if window_size > 0: - start_block_id = tl.maximum(history_len - window_size, - loop_start) // BLOCK_N - kv_min_loc = tl.maximum(history_len - window_size, 0) - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, BLOCK_N) - - loop_start = start_block_id * BLOCK_N - for start_n in range(loop_start, loop_end, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - mask = (start_n + offs_n[:, None]) < kv_seqlen - - # -- compute qk ---- - k = tl.load( - k_ptrs + b_offset[:, None] * stride_kbs, - mask=mask & mask_d[None, :], - other=0.0, - ) - - if shared_kv: - v = k - else: - v = tl.load( - v_ptrs + b_offset[:, None] * stride_vbs, - mask=mask & mask_dv[None, :], - other=0.0, - ) - - # prefetch b_offset - if start_n + BLOCK_N < loop_end: - start_block_id += 1 - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - BLOCK_N) - - qk = tl.sum(q[None, :] * k, 1) - qk *= sm_scale - if logit_softcapping > 0.0: - qk = qk / logit_softcapping - qk = tl.math.tanh(qk) - qk = qk * logit_softcapping - # NOTE: inf - inf = nan, and nan will leads to error - qk_mask = history_len >= (start_n + offs_n) - if window_size > 0: - qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc) - qk = tl.where( - qk_mask, - qk, - -float('inf'), - ) - - # -- compute p, m_i and l_i - m_i_new = tl.maximum(m_i, tl.max(qk, 0)) - p = tl.exp(qk - m_i_new) - alpha = tl.exp(m_i - m_i_new) - l_i_new = alpha * l_i + tl.sum(p, 0) - - # -- update output accumulator -- - # scale acc - acc = acc * alpha - - # update acc - p_new = p.to(v.dtype) - acc += tl.sum(p_new[:, None] * v, 0) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # initialize pointers to output - off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + - cur_head * stride_oh + offs_dv * stride_od) - tl.store(Acc_out + off_acc, acc, mask=mask_dv) - - off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + - cur_head * stride_oh + head_size_v) - tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i) - tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i) - - @triton.autotune(configs=[ triton.Config({}, num_stages=1, num_warps=16), triton.Config({}, num_stages=1, num_warps=8), @@ -264,28 +66,30 @@ def _fwd_grouped_split_kernel( KV_seqlens, Block_offsets, Acc_out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_ok, - stride_obs, - stride_oh, - stride_od, + stride_qbs: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kp: tl.constexpr, + stride_kbs: tl.constexpr, + stride_kh: tl.constexpr, + stride_kd: tl.constexpr, + stride_vp: tl.constexpr, + stride_vbs: tl.constexpr, + stride_vh: tl.constexpr, + stride_vd: tl.constexpr, + stride_ok: tl.constexpr, + stride_obs: tl.constexpr, + stride_oh: tl.constexpr, + stride_od: tl.constexpr, stride_boffb, kv_group_num: tl.constexpr, - block_per_cta, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, num_heads_q: tl.constexpr, shared_kv: tl.constexpr, logit_softcapping: tl.constexpr, + SPLIT_K: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, @@ -297,9 +101,12 @@ def _fwd_grouped_split_kernel( cur_kv_head = tl.program_id(1) split_k_id = tl.program_id(2) - heads_per_cta = min(BLOCK_H, kv_group_num) - cur_head = cur_kv_head * heads_per_cta + tl.arange(0, BLOCK_H) - mask_h = cur_head < cur_kv_head * heads_per_cta + heads_per_cta + if BLOCK_H < kv_group_num: + HEAD_PER_CTA: tl.constexpr = BLOCK_H + else: + HEAD_PER_CTA: tl.constexpr = kv_group_num + cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H) + mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA mask_h = mask_h & (cur_head < num_heads_q) q_seqlen = 1 @@ -310,10 +117,14 @@ def _fwd_grouped_split_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) mask_d = offs_d < head_size + offs_d = offs_d % head_size offs_dv = tl.arange(0, BLOCK_DV) mask_dv = offs_dv < head_size_v - off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd) + offs_dv = offs_dv % head_size_v + off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + + offs_n[None, :] * stride_kbs) + off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + + offs_n[:, None] * stride_vbs) off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) @@ -325,12 +136,14 @@ def _fwd_grouped_split_kernel( if BLOCK_DMODEL1 != 0: offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1) mask_d1 = offs_d1 < head_size + offs_d1 = offs_d1 % head_size off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0) - off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd) + off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + + offs_n[None, :] * stride_kbs) k1_ptrs = K + off_k1 block_offset_ptrs = Block_offsets + cur_batch * stride_boffb @@ -340,7 +153,9 @@ def _fwd_grouped_split_kernel( l_i = tl.zeros([BLOCK_H], dtype=tl.float32) acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) - kv_len_per_prog = block_per_cta * BLOCK_N + num_total_blocks = tl.cdiv(kv_seqlen, BLOCK_N) + BLOCK_PER_CTA = tl.cdiv(num_total_blocks, SPLIT_K) + kv_len_per_prog = BLOCK_PER_CTA * BLOCK_N loop_start = kv_len_per_prog * split_k_id loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen) @@ -351,41 +166,26 @@ def _fwd_grouped_split_kernel( start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N kv_min_loc = tl.maximum(history_len - window_size, 0) - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_block_id) loop_start = start_block_id * BLOCK_N for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - mask = (start_n + offs_n) < kv_seqlen - # -- compute qk ---- - k = tl.load( - k_ptrs + b_offset[None, :] * stride_kbs, - mask=mask[None, :] & mask_d[:, None], - other=0.0, - ) + k = tl.load(k_ptrs + b_offset * stride_kp) if BLOCK_DMODEL1 != 0: - k1 = tl.load( - k1_ptrs + b_offset[None, :] * stride_kbs, - mask=mask[None, :] & mask_d1[:, None], - other=0.0, - ) + k1 = tl.load(k1_ptrs + b_offset * stride_kp) if shared_kv: v = tl.trans(k) else: - v = tl.load( - v_ptrs + b_offset[:, None] * stride_vbs, - mask=mask[:, None] & mask_dv[None, :], - other=0.0, - ) + v = tl.load(v_ptrs + b_offset * stride_vp) # prefetch b_offset if start_n + BLOCK_N < loop_end: start_block_id += 1 - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_block_id) qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -408,8 +208,8 @@ def _fwd_grouped_split_kernel( # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.exp(qk - m_i_new[:, None]) - alpha = tl.exp(m_i - m_i_new) + p = tl.math.fast_expf(qk - m_i_new[:, None]) + alpha = tl.math.fast_expf(m_i - m_i_new) l_i_new = alpha * l_i + tl.sum(p, 1) # -- update output accumulator -- @@ -482,7 +282,7 @@ def _reduce_split_kernel( l_k = tl.load(Acc + offs_mi + 1) m_max = tl.max(m_k, 0) - alpha = tl.exp(m_k - m_max) + alpha = tl.math.fast_expf(m_k - m_max) acc_k = acc_k * alpha[:, None] l_k = l_k * alpha @@ -537,20 +337,22 @@ def _fwd_kernel( KV_seqlens, Block_offsets, Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, + stride_qbs: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kp: tl.constexpr, + stride_kbs: tl.constexpr, + stride_kh: tl.constexpr, + stride_kd: tl.constexpr, + stride_vp: tl.constexpr, + stride_vbs: tl.constexpr, + stride_vh: tl.constexpr, + stride_vd: tl.constexpr, + stride_obs: tl.constexpr, + stride_oh: tl.constexpr, + stride_od: tl.constexpr, stride_boffb, - kv_group_num, + kv_group_num: tl.constexpr, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, @@ -575,18 +377,24 @@ def _fwd_kernel( history_len = kv_seqlen - q_seqlen block_start_loc = BLOCK_M * start_m + if block_start_loc >= q_seqlen: + return # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < head_size + offs_d = offs_d % head_size mask_dv = offs_dv < head_size_v + offs_d = offs_d % head_size_v offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd) + off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + + offs_n[None, :] * stride_kbs) + off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + + offs_n[:, None] * stride_vbs) q = tl.load(Q + off_q, mask=(offs_m[:, None] < q_seqlen) & mask_d[None, :], @@ -598,10 +406,12 @@ def _fwd_kernel( if BLOCK_DMODEL1 != 0: offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1) mask_d1 = offs_d1 < head_size + offs_d1 = offs_d1 % head_size off_q1 = ((q_start_loc + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=(offs_m[:, None] < q_seqlen) & mask_d1) - off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd) + off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + + offs_n[None, :] * stride_kbs) k1_ptrs = K + off_k1 block_offset_ptrs = Block_offsets + cur_batch * stride_boffb @@ -611,45 +421,28 @@ def _fwd_kernel( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) - block_mask = tl.where(block_start_loc < q_seqlen, 1, 0) - # this is dirty start_block_id = kv_seqlen - kv_seqlen if window_size > 0: start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0) - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_block_id) kv_start_loc = start_block_id * BLOCK_N - for start_n in range(kv_start_loc, block_mask * kv_seqlen, BLOCK_N): + for start_n in range(kv_start_loc, kv_seqlen, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load( - k_ptrs + b_offset[None, :] * stride_kbs, - mask=(start_n + offs_n[None, :] < kv_seqlen) & mask_d[:, None], - other=0.0, - ) + k = tl.load(k_ptrs + b_offset * stride_kp) if BLOCK_DMODEL1 != 0: - k1 = tl.load( - k1_ptrs + b_offset[None, :] * stride_kbs, - mask=(start_n + offs_n[None, :] < kv_seqlen) - & mask_d1[:, None], - other=0.0, - ) + k1 = tl.load(k1_ptrs + b_offset * stride_kp) if shared_kv: v = tl.trans(k) else: - v = tl.load( - v_ptrs + b_offset[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None] < kv_seqlen) - & mask_dv[None, :], - other=0.0, - ) + v = tl.load(v_ptrs + b_offset * stride_vp) if start_n + BLOCK_N < kv_seqlen: start_block_id = start_n // BLOCK_N + 1 - b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, - BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_block_id) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -674,8 +467,8 @@ def _fwd_kernel( # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.exp(qk - m_i_new[:, None]) - alpha = tl.exp(m_i - m_i_new) + p = tl.math.fast_expf(qk - m_i_new[:, None]) + alpha = tl.math.fast_expf(m_i - m_i_new) l_i_new = alpha * l_i + tl.sum(p, 1) # -- update output accumulator -- # scale acc @@ -688,7 +481,7 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new - acc = acc / l_i[:, None] + acc = tl.math.fast_dividef(acc, l_i[:, None]) # initialize pointers to output off_o = ((q_start_loc + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_dv[None, :] * stride_od) @@ -787,9 +580,11 @@ def _get_block_d(Lk): stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1), + stride_kp=k.stride(-4), stride_kbs=k.stride(-3), stride_kh=k.stride(-2), stride_kd=k.stride(-1), + stride_vp=v.stride(-4), stride_vbs=v.stride(-3), stride_vh=v.stride(-2), stride_vd=v.stride(-1), @@ -813,90 +608,49 @@ def _get_block_d(Lk): **kernel_meta) else: SPLIT_K = 4 - block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K) acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32) - if False: - # TODO: make sure if we need to keep this kernel - BLOCK_DMODEL = triton.next_power_of_2(Lk) - if shared_kv: - BLOCK_DV = BLOCK_DMODEL - else: - BLOCK_DV = triton.next_power_of_2(Lv) - grid = (batch, head, SPLIT_K) - _fwd_split_kernel[grid](q, - k, - v, - sm_scale, - kv_seqlens, - block_offsets, - acc, - stride_qbs=q.stride(-3), - stride_qh=q.stride(-2), - stride_qd=q.stride(-1), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), - stride_ok=acc.stride(-2), - stride_obs=acc.stride(-4), - stride_oh=acc.stride(-3), - stride_od=acc.stride(-1), - stride_boffb=block_offsets.stride(0), - kv_group_num=kv_group_num, - block_per_cta=block_per_cta, - window_size=window_size, - head_size=Lk, - head_size_v=Lv, - shared_kv=shared_kv, - logit_softcapping=logit_softcapping, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_DV=BLOCK_DV, - BLOCK_N=BLOCK, - **kernel_meta) - else: - BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lk) - p2_kv_group_num = triton.next_power_of_2(kv_group_num) - BLOCK_H = max(16, min(BLOCK, p2_kv_group_num)) - grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num)) - grid = (batch, grid_1, SPLIT_K) - _fwd_grouped_split_kernel[grid]( - q, - k, - v, - sm_scale, - kv_seqlens, - block_offsets, - acc, - stride_qbs=q.stride(-3), - stride_qh=q.stride(-2), - stride_qd=q.stride(-1), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), - stride_ok=acc.stride(-2), - stride_obs=acc.stride(-4), - stride_oh=acc.stride(-3), - stride_od=acc.stride(-1), - stride_boffb=block_offsets.stride(0), - kv_group_num=kv_group_num, - block_per_cta=block_per_cta, - window_size=window_size, - head_size=Lk, - head_size_v=Lv, - num_heads_q=head, - shared_kv=shared_kv, - logit_softcapping=logit_softcapping, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_DV=BLOCK_DV, - BLOCK_N=BLOCK, - BLOCK_H=BLOCK_H, - BLOCK_DMODEL1=BLOCK_DMODEL1, - **kernel_meta) + BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lk) + p2_kv_group_num = triton.next_power_of_2(kv_group_num) + BLOCK_H = max(16, min(BLOCK, p2_kv_group_num)) + grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num)) + grid = (batch, grid_1, SPLIT_K) + _fwd_grouped_split_kernel[grid](q, + k, + v, + sm_scale, + kv_seqlens, + block_offsets, + acc, + stride_qbs=q.stride(-3), + stride_qh=q.stride(-2), + stride_qd=q.stride(-1), + stride_kp=k.stride(-4), + stride_kbs=k.stride(-3), + stride_kh=k.stride(-2), + stride_kd=k.stride(-1), + stride_vp=v.stride(-4), + stride_vbs=v.stride(-3), + stride_vh=v.stride(-2), + stride_vd=v.stride(-1), + stride_ok=acc.stride(-2), + stride_obs=acc.stride(-4), + stride_oh=acc.stride(-3), + stride_od=acc.stride(-1), + stride_boffb=block_offsets.stride(0), + kv_group_num=kv_group_num, + window_size=window_size, + head_size=Lk, + head_size_v=Lv, + num_heads_q=head, + shared_kv=shared_kv, + logit_softcapping=logit_softcapping, + SPLIT_K=SPLIT_K, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + BLOCK_DMODEL1=BLOCK_DMODEL1, + **kernel_meta) num_warps = 4 grid = (batch, head) From 871b788d96125b043e7e29a8f7f55585d7a77a8f Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 19 Aug 2024 15:27:55 +0800 Subject: [PATCH 28/89] fix graph runner --- .../pytorch/backends/cuda/graph_runner.py | 9 ++-- .../pytorch/kernels/cuda/pagedattention.py | 42 ++++++++++--------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 4c41115e0..80215148b 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -123,9 +123,12 @@ def _fill_inputs(self, input_ids: torch.Tensor, position_ids: torch.Tensor, self.input_buffers['position_ids'][:, :num_tokens] = position_ids self.input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets - self.input_buffers['q_seqlens'].zero_() + if q_seqlens.data_ptr() != self.input_buffers['q_seqlens'].data_ptr(): + self.input_buffers['q_seqlens'].zero_() self.input_buffers['q_seqlens'][:batch_size] = q_seqlens - self.input_buffers['kv_seqlens'].zero_() + if kv_seqlens.data_ptr() != self.input_buffers['kv_seqlens'].data_ptr( + ): + self.input_buffers['kv_seqlens'].zero_() self.input_buffers['kv_seqlens'][:batch_size] = kv_seqlens self.input_buffers['q_start_loc'][:batch_size] = q_start_loc if inputs_embeds is not None: @@ -188,6 +191,7 @@ def capture(self, **kwargs): output = self.model(**padded_kwargs) self.output_buffers['logits'] = output + return output def forward(self, **kwargs): """forward.""" @@ -275,7 +279,6 @@ def __call__(self, **kwargs): self._runner_map[graph_key] = runner else: runner = self._runner_map[graph_key] - output = runner.forward(**kwargs) return output diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 16a83965d..5355adab7 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -111,6 +111,8 @@ def _fwd_grouped_split_kernel( q_seqlen = 1 kv_seqlen = tl.load(KV_seqlens + cur_batch) + if kv_seqlen <= 0: + return history_len = kv_seqlen - q_seqlen # initialize offsets @@ -197,14 +199,15 @@ def _fwd_grouped_split_kernel( qk = tl.math.tanh(qk) qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error - qk_mask = history_len >= (start_n + offs_n) - if window_size > 0: - qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc) - qk = tl.where( - qk_mask[None, :], - qk, - -float('inf'), - ) + if start_n + BLOCK_N > history_len or window_size > 0: + qk_mask = history_len >= (start_n + offs_n) + if window_size > 0: + qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc) + qk = tl.where( + qk_mask[None, :], + qk, + -float('inf'), + ) # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) @@ -387,7 +390,7 @@ def _fwd_kernel( mask_d = offs_d < head_size offs_d = offs_d % head_size mask_dv = offs_dv < head_size_v - offs_d = offs_d % head_size_v + offs_dv = offs_dv % head_size_v offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) @@ -454,16 +457,17 @@ def _fwd_kernel( qk = tl.math.tanh(qk) qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error - qk_mask = (history_len + offs_m[:, None]) >= (start_n + - offs_n[None, :]) - if window_size > 0: - qk_mask = qk_mask and ( - (start_n + offs_n[None, :]) >= kv_min_loc[:, None]) - qk = tl.where( - qk_mask, - qk, - float(-1e30), - ) + if start_n + BLOCK_N > history_len or window_size > 0: + qk_mask = (history_len + offs_m[:, None]) >= (start_n + + offs_n[None, :]) + if window_size > 0: + qk_mask = qk_mask and ( + (start_n + offs_n[None, :]) >= kv_min_loc[:, None]) + qk = tl.where( + qk_mask, + qk, + float(-1e30), + ) # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) From 09149ac53bfc918ad913de072a8b38b45ffac300 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 19 Aug 2024 21:46:41 +0800 Subject: [PATCH 29/89] optimize prefill --- lmdeploy/messages.py | 2 +- .../pytorch/backends/cuda/graph_runner.py | 36 +++++++------------ lmdeploy/pytorch/engine/engine.py | 12 +++++-- lmdeploy/pytorch/engine/engine_instance.py | 18 +++++----- 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 2fdcc35dd..c84e659c6 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -203,7 +203,7 @@ class PytorchEngineConfig: max_batch_size: int = 128 cache_max_entry_count: float = 0.8 eviction_type: str = 'recompute' - prefill_interval: int = 16 + prefill_interval: int = 8 block_size: int = 64 num_cpu_blocks: int = 0 num_gpu_blocks: int = 0 diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 80215148b..131494d35 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -58,24 +58,14 @@ def make_buffers(self): max_tokens = self.max_tokens num_blocks = self.num_blocks device = self.device - is_decoding = self.is_decoding - if is_decoding: - self.input_buffers['input_ids'] = torch.zeros(1, - max_batches, - dtype=torch.int64, - device=device) - self.input_buffers['position_ids'] = torch.zeros((1, max_batches), - dtype=torch.int64, - device=device) - else: - self.input_buffers['input_ids'] = torch.zeros(1, - max_tokens, - dtype=torch.int64, - device=device) - self.input_buffers['position_ids'] = torch.zeros((1, max_tokens), - dtype=torch.int64, - device=device) + self.input_buffers['input_ids'] = torch.zeros(1, + max_tokens, + dtype=torch.int64, + device=device) + self.input_buffers['position_ids'] = torch.zeros((1, max_tokens), + dtype=torch.int64, + device=device) self.input_buffers['block_offsets'] = torch.zeros( (max_batches, num_blocks), dtype=torch.int64, device=device) @@ -254,10 +244,9 @@ def get_graph_key(self, input_ids: torch.Tensor, """get graph key.""" context = self.ctx_mgr.current_context() is_decoding = context.is_decoding - block_offsets = attn_metadata.block_offsets - batch_size, _ = block_offsets.size() - new_batch_size = next_power_of_2(batch_size) - return (new_batch_size, is_decoding) + num_tokens = input_ids.numel() + new_num_tokens = next_power_of_2(num_tokens) + return (new_num_tokens, is_decoding) def __call__(self, **kwargs): """call.""" @@ -265,12 +254,13 @@ def __call__(self, **kwargs): return self.model(**kwargs) graph_key = self.get_graph_key(**kwargs) - max_batches = graph_key[0] + max_tokens = graph_key[0] is_decoding = graph_key[1] if graph_key not in self._runner_map: + max_batches = max_tokens if is_decoding else self.max_batches runner = CUDASingleGraphRunner(self.model, max_batches=max_batches, - max_tokens=self.max_tokens, + max_tokens=max_tokens, num_blocks=self.num_blocks, is_decoding=is_decoding, pool=self.graph_pool_handle, diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 4982ec7a8..04173322c 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -259,31 +259,37 @@ def _on_add_session(self, reqs: Request, **kwargs): """on add session callback.""" for req in reqs: session_id = req.data['session_id'] + resp = req.data.get('response', True) resp_type = ResponseType.SESSION_REPEAT if session_id not in self.scheduler.sessions: self.scheduler.add_session(session_id) resp_type = ResponseType.SUCCESS - self._response(resp_type, req.sender_id, req.req_id) + if resp: + self._response(resp_type, req.sender_id, req.req_id) def _on_stop_session(self, reqs: Request, **kwargs): """on stop session callback.""" for req in reqs: session_id = req.data['session_id'] + resp = req.data.get('response', True) resp_type = ResponseType.SESSION_NOT_EXIST if session_id in self.scheduler.sessions: self.scheduler.stop_session(session_id) resp_type = ResponseType.SUCCESS - self._response(resp_type, req.sender_id, req.req_id) + if resp: + self._response(resp_type, req.sender_id, req.req_id) def _on_end_session(self, reqs: Request, **kwargs): """on end session callback.""" for req in reqs: session_id = req.data['session_id'] + resp = req.data.get('response', True) resp_type = ResponseType.SESSION_NOT_EXIST if session_id in self.scheduler.sessions: self.scheduler.end_session(session_id) resp_type = ResponseType.SUCCESS - self._response(resp_type, req.sender_id, req.req_id) + if resp: + self._response(resp_type, req.sender_id, req.req_id) def _on_add_message(self, reqs: Request, **kwargs): """on add message callback.""" diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 4aa930ee8..8cd0caf39 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -42,10 +42,8 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int): async def async_end(req_sender: RequestSender, session_id: int): """End the given session.""" - resp = await req_sender.async_send(RequestType.END_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to end session: {session_id}. ' - f'Error: {resp.type}.')) + await req_sender.async_send_async(RequestType.END_SESSION, + dict(session_id=session_id, resp=False)) async def async_cancel(req_sender: RequestSender, session_id: int): @@ -71,10 +69,8 @@ def try_add_session(req_sender: RequestSender, session_id: int): def end(req_sender: RequestSender, session_id: int): """End the given session.""" - resp = req_sender.send(RequestType.END_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to end session: {session_id}. ' - f'Error: {resp.type}.')) + req_sender.send_async(RequestType.END_SESSION, + dict(session_id=session_id, resp=False)) def cancel(req_sender: RequestSender, session_id: int): @@ -156,7 +152,8 @@ async def async_stream_infer( return gen_config = gen_config or EngineGenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - await async_try_add_session(self.req_sender, session_id) + await self.req_sender.async_send_async( + RequestType.ADD_SESSION, dict(session_id=session_id, resp=False)) input_embeddings_new: List[InputEmbeddings] = None if input_embeddings is not None and len(input_embeddings) > 0: assert len(input_embeddings) == len(input_embedding_ranges) @@ -272,7 +269,8 @@ def __call_async(): gen_config = gen_config or EngineGenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - try_add_session(self.req_sender, session_id) + self.req_sender.send_async(RequestType.ADD_SESSION, + dict(session_id=session_id, resp=False)) input_embeddings_new: List[InputEmbeddings] = None if input_embeddings is not None and len(input_embeddings) > 0: assert len(input_embeddings) == len(input_embedding_ranges) From c3638329f116df130eb9ffc5fd7ea3256145b417 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 20 Aug 2024 20:25:55 +0800 Subject: [PATCH 30/89] fix response --- lmdeploy/pytorch/engine/engine_instance.py | 11 +++++----- lmdeploy/pytorch/kernels/cuda/fused_moe.py | 25 ++++++++++++---------- lmdeploy/pytorch/tools/make_inputs.py | 7 ++---- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 8cd0caf39..9d9ebf919 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -42,8 +42,8 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int): async def async_end(req_sender: RequestSender, session_id: int): """End the given session.""" - await req_sender.async_send_async(RequestType.END_SESSION, - dict(session_id=session_id, resp=False)) + await req_sender.async_send_async( + RequestType.END_SESSION, dict(session_id=session_id, response=False)) async def async_cancel(req_sender: RequestSender, session_id: int): @@ -70,7 +70,7 @@ def try_add_session(req_sender: RequestSender, session_id: int): def end(req_sender: RequestSender, session_id: int): """End the given session.""" req_sender.send_async(RequestType.END_SESSION, - dict(session_id=session_id, resp=False)) + dict(session_id=session_id, response=False)) def cancel(req_sender: RequestSender, session_id: int): @@ -153,7 +153,8 @@ async def async_stream_infer( gen_config = gen_config or EngineGenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) await self.req_sender.async_send_async( - RequestType.ADD_SESSION, dict(session_id=session_id, resp=False)) + RequestType.ADD_SESSION, dict(session_id=session_id, + response=False)) input_embeddings_new: List[InputEmbeddings] = None if input_embeddings is not None and len(input_embeddings) > 0: assert len(input_embeddings) == len(input_embedding_ranges) @@ -270,7 +271,7 @@ def __call_async(): gen_config = gen_config or EngineGenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) self.req_sender.send_async(RequestType.ADD_SESSION, - dict(session_id=session_id, resp=False)) + dict(session_id=session_id, response=False)) input_embeddings_new: List[InputEmbeddings] = None if input_embeddings is not None and len(input_embeddings) > 0: assert len(input_embeddings) == len(input_embedding_ranges) diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index 37a4ddca9..852799d2c 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -15,6 +15,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1, }, num_stages=3, num_warps=8), @@ -23,6 +24,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=4, num_warps=4), @@ -31,6 +33,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=4, num_warps=4), @@ -39,6 +42,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=4, num_warps=4), @@ -47,6 +51,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=4, num_warps=4), @@ -55,6 +60,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=4, num_warps=4), @@ -63,6 +69,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=5, num_warps=2), @@ -71,6 +78,7 @@ def get_cuda_autotune_config(): 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8, }, num_stages=5, num_warps=2), @@ -119,12 +127,12 @@ def fused_moe_kernel( Weights, N: tl.constexpr, K: tl.constexpr, - stride_am: int, + stride_am: tl.constexpr, stride_ak: tl.constexpr, stride_be: tl.constexpr, stride_bn: tl.constexpr, stride_bk: tl.constexpr, - stride_cm: int, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -138,8 +146,8 @@ def fused_moe_kernel( reindex_c: tl.constexpr, ): """fused moe kernel.""" - exp_id = tl.program_id(0) - pid = tl.program_id(1) + exp_id = tl.program_id(1) + pid = tl.program_id(0) exp_start = tl.load(ExpStart + exp_id + expert_offset) exp_end = tl.load(ExpEnd + exp_id + expert_offset) @@ -229,14 +237,10 @@ def fused_moe_kernel_launcher( E, N, K = B.shape def _grid_fn(META): - grid = ( - E, - triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) * - triton.cdiv(N, META['BLOCK_SIZE_N']), - ) + grid = (triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), E) return grid - GROUP_SIZE_M = 1 A = A.flatten(0, -2) C = C.flatten(0, -2) @@ -264,7 +268,6 @@ def _grid_fn(META): expert_offset=expert_offset, reindex_a=reindex_a, reindex_c=reindex_c, - GROUP_SIZE_M=GROUP_SIZE_M, M_NP2=M_NP2, **kernel_meta, ) diff --git a/lmdeploy/pytorch/tools/make_inputs.py b/lmdeploy/pytorch/tools/make_inputs.py index 82fa3b735..f2d23830b 100644 --- a/lmdeploy/pytorch/tools/make_inputs.py +++ b/lmdeploy/pytorch/tools/make_inputs.py @@ -24,16 +24,12 @@ def make_model_inputs(input_ids: torch.Tensor, else: assert len(history_length) == len(seq_length) is_decoding = max_seq_len == 1 - max_q_seq_length = seq_length.max().item() - max_history_length = history_length.max().item() num_ignored_history = torch.zeros_like(seq_length) return ModelInputs(input_ids=input_ids, seq_length=seq_length, history_lengths=history_length, block_offsets=block_offsets, - max_q_seq_length=max_q_seq_length, - max_history_length=max_history_length, is_decoding=is_decoding, num_ignored_history=num_ignored_history) @@ -53,7 +49,7 @@ def make_step_context( from torch.nn.utils.rnn import pad_sequence from lmdeploy.pytorch.engine.cache_engine import CacheEngine - from lmdeploy.pytorch.engine.model_agent import StepContext + from lmdeploy.pytorch.model_inputs import StepContext if model_config is None: model_config = ModelConfig(hidden_size=4096, @@ -82,6 +78,7 @@ def make_step_context( num_blocks_per_seq = (total_length + block_size - 1) // block_size num_blocks = sum(num_blocks_per_seq).item() cache_config = CacheConfig( + max_batches=128, block_size=block_size, num_cpu_blocks=0, num_gpu_blocks=num_blocks, From b5bb49f41f855a94d7465a77123d129f01e4e235 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 21 Aug 2024 14:07:12 +0800 Subject: [PATCH 31/89] optmize prefill --- lmdeploy/pytorch/engine/engine.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index c9e8b4c12..bd9155a58 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -88,20 +88,19 @@ def _get_adapter_ids(seqs: SeqList, adapters: AdapterList): return adapter_ids -def _dynamic_prefill_interval(scheduler: Scheduler): +def _check_finish(scheduler: Scheduler, current_iter: int): """dynamic prefill interval.""" + if not scheduler.has_waiting(): + return False scheduler_config = scheduler.scheduler_config max_prefill_interval = scheduler_config.prefill_interval max_batches = scheduler_config.max_batches num_batches = len(scheduler.running) - num_waiting = len(scheduler.waiting) - if num_waiting == 0: - return max_prefill_interval ratio = num_batches / max_batches - prefill_interval = ratio * (max_prefill_interval + 1) - prefill_interval = np.clip(prefill_interval, 2, max_prefill_interval) - prefill_interval = int(prefill_interval) - return prefill_interval + min_iter = max_prefill_interval * ratio + if current_iter >= min_iter: + return True + return False class Engine: @@ -723,6 +722,7 @@ def __update_inputs(next_token_ids): # send output stopped = stopped.cpu() finish = stopped.all().item() or (idx == loop_count - 1) + finish = finish or _check_finish(self.scheduler, idx) output = (next_token_ids.cpu(), logits, stopped) output_que.put_nowait((finish, output)) @@ -776,7 +776,7 @@ def __get_num_ignore_eos(seqs: SeqList): while True: is_prefill = await in_que.get() try: - prefill_interval = _dynamic_prefill_interval(self.scheduler) + prefill_interval = self.scheduler_config.prefill_interval schedule_output = self.scheduler.schedule( is_prefill=is_prefill, prealloc_size=prefill_interval) running: SeqList = schedule_output.running @@ -845,6 +845,8 @@ async def __step(prefill: bool): in_que.put_nowait(prefill) finish = False while not finish: + if self.req_manager.has_requests(): + self.req_manager.step() finish, out = await out_que.get() try: if isinstance(out, Exception): From 8cf2ab2427bea952d769b6d66df6a36a4ff43842 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 21 Aug 2024 14:34:46 +0800 Subject: [PATCH 32/89] adjust grid of paged attention --- lmdeploy/pytorch/kernels/cuda/pagedattention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 5355adab7..1ea59c9d2 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -368,9 +368,9 @@ def _fwd_kernel( BLOCK_DMODEL1: tl.constexpr, ): """paged attention kernel.""" - cur_batch = tl.program_id(0) + cur_batch = tl.program_id(2) cur_head = tl.program_id(1) - start_m = tl.program_id(2) + start_m = tl.program_id(0) cur_kv_head = cur_head // kv_group_num @@ -571,7 +571,7 @@ def _get_block_d(Lk): BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL)) num_warps = 4 num_stages = 1 - grid = (batch, head, triton.cdiv(max_seqlen, BLOCK_M)) + grid = (triton.cdiv(max_seqlen, BLOCK_M), head, batch) _fwd_kernel[grid](q, k, v, From 3ce4e2df7be58053767189a117d6fe1759602fff Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 21 Aug 2024 18:21:48 +0800 Subject: [PATCH 33/89] add attention stages --- .../pytorch/kernels/cuda/pagedattention.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 1ea59c9d2..ebfb65cf7 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -18,9 +18,9 @@ @triton.autotune(configs=[ - triton.Config({}, num_stages=1, num_warps=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=1, num_warps=4), + triton.Config({}, num_stages=2, num_warps=16), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=2, num_warps=4), ], key=['BLOCK_H', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV']) @wrap_jit_func(type_hint=dict( @@ -168,11 +168,11 @@ def _fwd_grouped_split_kernel( start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N kv_min_loc = tl.maximum(history_len - window_size, 0) - b_offset = tl.load(block_offset_ptrs + start_block_id) loop_start = start_block_id * BLOCK_N for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + b_offset * stride_kp) @@ -184,11 +184,6 @@ def _fwd_grouped_split_kernel( else: v = tl.load(v_ptrs + b_offset * stride_vp) - # prefetch b_offset - if start_n + BLOCK_N < loop_end: - start_block_id += 1 - b_offset = tl.load(block_offset_ptrs + start_block_id) - qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) if BLOCK_DMODEL1 != 0: @@ -429,10 +424,10 @@ def _fwd_kernel( if window_size > 0: start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0) - b_offset = tl.load(block_offset_ptrs + start_block_id) kv_start_loc = start_block_id * BLOCK_N for start_n in range(kv_start_loc, kv_seqlen, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + b_offset * stride_kp) @@ -443,9 +438,6 @@ def _fwd_kernel( v = tl.trans(k) else: v = tl.load(v_ptrs + b_offset * stride_vp) - if start_n + BLOCK_N < kv_seqlen: - start_block_id = start_n // BLOCK_N + 1 - b_offset = tl.load(block_offset_ptrs + start_block_id) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -570,7 +562,7 @@ def _get_block_d(Lk): BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lk) BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL)) num_warps = 4 - num_stages = 1 + num_stages = 2 grid = (triton.cdiv(max_seqlen, BLOCK_M), head, batch) _fwd_kernel[grid](q, k, From 30c20663ef890b6cd048c04a7c704e713718f153 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 21 Aug 2024 19:57:23 +0800 Subject: [PATCH 34/89] support llama3 --- .../backends/default/rotary_embedding.py | 45 ++++++++++++ lmdeploy/pytorch/backends/rotary_embedding.py | 3 + lmdeploy/pytorch/layers/rotary_embedding.py | 11 ++- lmdeploy/pytorch/models/llama.py | 70 ++++++++++++++++--- 4 files changed, 116 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index dd4278ef8..6625ff2b8 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math + import torch from torch import nn @@ -100,6 +102,43 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): return cos, sin +class Llama3RotaryEmbeddingImpl(RotaryEmbeddingImpl): + """llama3 rotary embedding implementation.""" + + def __init__( + self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + original_max_position_embeddings: int = 8194, + ): + super().__init__(dim, base, scaling_factor) + old_context_len = original_max_position_embeddings + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + inv_freq = self.inv_freq + factor = self.scaling_factor + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, + inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > + low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, + inv_freq_llama) + self.register_buffer('inv_freq_llama', inv_freq_llama) + + class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): """rotary embedding builder.""" @@ -109,6 +148,8 @@ def build( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" @@ -117,6 +158,10 @@ def build( elif emb_type == EmbeddingType.DynamicNTKScaling: return LlamaDynamicNTKScalingRotaryEmbedding( dim, base, scaling_factor, max_position_embeddings) + elif emb_type == EmbeddingType.Llama3: + return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, + low_freq_factor, high_freq_factor, + max_position_embeddings) else: raise NotImplementedError( f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 46b095685..1a2cc919f 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -8,6 +8,7 @@ class EmbeddingType(Enum): Default = auto() LinearScaling = auto() DynamicNTKScaling = auto() + Llama3 = auto() class RotaryEmbeddingImpl(ABC): @@ -29,6 +30,8 @@ def build( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" diff --git a/lmdeploy/pytorch/layers/rotary_embedding.py b/lmdeploy/pytorch/layers/rotary_embedding.py index d3614e6a5..66126cb99 100644 --- a/lmdeploy/pytorch/layers/rotary_embedding.py +++ b/lmdeploy/pytorch/layers/rotary_embedding.py @@ -10,13 +10,20 @@ def build_rotary_embedding( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: """build rotary embedding op.""" backend = get_backend() builder = backend.get_layer_impl_builder(LayerType.RotaryEmbedding) - return builder.build(dim, max_position_embeddings, base, scaling_factor, - emb_type) + return builder.build(dim, + max_position_embeddings, + base, + scaling_factor, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + emb_type=emb_type) class ApplyRotaryEmb(nn.Module): diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index cc33d5255..4531aacd0 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -31,7 +31,7 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): self.num_kv_heads = origin.num_key_value_heads // world_size self.head_dim = origin.head_dim - # qkv + # packed qkv self.qkv_proj = build_merged_colwise_linear( origin.q_proj, origin.k_proj, @@ -39,8 +39,10 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): ctx_mgr=ctx_mgr, is_tp=is_tp, ) + # free old weight del origin.q_proj, origin.k_proj, origin.v_proj + # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() # attention @@ -51,6 +53,7 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): v_head_size=self.head_dim, ) + # o_proj self.o_proj = build_rowwise_linear( origin.o_proj, ctx_mgr=ctx_mgr, @@ -60,13 +63,16 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): @staticmethod def _load_weights(mod, loader, rank: int, world_size: int, device: torch.device): - """load weights.""" + """load weights, support TP.""" + # split weight of qkv proj. for mod_name in ['q_proj', 'k_proj', 'v_proj']: colwise_parallelize_linear(getattr(mod, mod_name), loader, rank=rank, world_size=world_size, prefix=mod_name) + + # split weight of o_proj rowwise_parallelize_linear(mod.o_proj, loader, rank=rank, @@ -82,6 +88,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Rewrite of LlamaAttention.forward.""" + # qkv proj qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) qkv_states = qkv_states.flatten(0, -2) @@ -95,6 +102,7 @@ def forward( dim=1, ) + # apply rotary embedding cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( query_states, @@ -103,6 +111,8 @@ def forward( sin, inplace=True, ) + + # attention attn_output = self.attn_fwd( query_states, key_states, @@ -114,6 +124,7 @@ def forward( ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) return attn_output @@ -134,6 +145,7 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): ctx_mgr=ctx_mgr, is_tp=is_tp, ) + # free old weight del origin.gate_proj, origin.up_proj # silu and mul @@ -147,13 +159,17 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): @staticmethod def _load_weights(mod: nn.Module, loader, rank: int, world_size: int, device: torch.device): - """load weights.""" + """load weights, support TP.""" + + # split weight of gate_proj and up_proj for mod_name in ['gate_proj', 'up_proj']: colwise_parallelize_linear(getattr(mod, mod_name), loader, rank=rank, world_size=world_size, prefix=mod_name) + + # split weight of down_proj rowwise_parallelize_linear(mod.down_proj, loader, rank=rank, @@ -161,6 +177,7 @@ def _load_weights(mod: nn.Module, loader, rank: int, world_size: int, prefix='down_proj') def forward(self, x): + """forward.""" gate_up = self.gate_up_proj(x) act = self.act_fn(gate_up) return self.down_proj(act) @@ -172,10 +189,14 @@ def __init__(self, origin: nn.Module, layer_idx: int, ctx_mgr: StepContextManager): super().__init__() self.layer_idx = layer_idx + + # build attention layer self.self_attn = LlamaAttention(origin.self_attn, ctx_mgr) + + # builf MLP self.mlp = LlamaMLP(origin.mlp, ctx_mgr) - # norm + # build input layer norm input_layernorm = origin.input_layernorm is_w8a8 = hasattr(input_layernorm, 'from_float') self.input_layernorm = RMSNorm( @@ -183,6 +204,8 @@ def __init__(self, origin: nn.Module, layer_idx: int, input_layernorm.variance_epsilon, is_w8a8=is_w8a8, ) + + # build attention layer norm post_attention_layernorm = origin.post_attention_layernorm is_w8a8 = hasattr(post_attention_layernorm, 'from_float') self.post_attention_layernorm = RMSNorm( @@ -230,16 +253,21 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): super().__init__() self.ctx_mgr = ctx_mgr self.embed_tokens = origin.embed_tokens + + # build all decode layers self.layers = nn.ModuleList([ LlamaDecoderLayer(layer, idx, ctx_mgr) for idx, layer in enumerate(origin.layers) ]) + + # build norm norm = origin.norm is_w8a8 = hasattr(norm, 'from_float') self.norm = RMSNorm(norm.weight, norm.variance_epsilon, is_w8a8=is_w8a8) + # build rotary embedding in LlamaModel rotary_emb = origin.layers[0].self_attn.rotary_emb rotary_name = type(rotary_emb).__name__ if rotary_name in [ @@ -253,15 +281,23 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta scaling_factor = 1.0 + low_freq_factor = 1.0 + high_freq_factor = 4.0 if config.rope_scaling is not None: - scaling_factor = config.rope_scaling.get('scaling_factor', - scaling_factor) + rope_scaling = config.rope_scaling + scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) + if rope_scaling['rope_type'] == 'llama3': + emb_type = EmbeddingType.Llama3 + low_freq_factor = rope_scaling.get('low_freq_factor', 1.0) + high_freq_factor = rope_scaling.get('high_freq_factor', 1.0) self.rotary_emb = build_rotary_embedding( rope_dim, rope_max_pos_emb, rope_base, scaling_factor, - emb_type, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + emb_type=emb_type, ) def forward( @@ -273,16 +309,21 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """Rewrite of LlamaModel.forward.""" + + # token embedding if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - residual = None + + # rotary embedding cos, sin = self.rotary_emb(hidden_states, position_ids) cos, sin = cos[0], sin[0] rotary_pos_emb = (cos, sin) - for idx, decoder_layer in enumerate(self.layers): + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] hidden_states, residual = decoder_layer( hidden_states, @@ -292,6 +333,7 @@ def forward( attn_metadata=attn_metadata, ) + # norm hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -302,13 +344,16 @@ def get_input_embeddings(self): class LlamaForCausalLM(nn.Module): + """rewrote model of LlamaForCausalLM.""" support_cuda_graph = True def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): super().__init__() self.ctx_mgr = ctx_mgr + # build LLamaModel self.model = LlamaModel(origin.model, ctx_mgr) + # build lm_head self.lm_head = build_rowwise_linear(origin.lm_head) def forward( @@ -320,6 +365,7 @@ def forward( inputs_embeds: torch.Tensor = None, **kwargs, ): + """model forward, return logits.""" hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, @@ -343,13 +389,14 @@ def prepare_inputs_for_generation( context: StepContext = None, ): """prepare input.""" + # get input_ids, position_ids and attention metadatas input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - # get inputs from context + + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing - if vision_embeddings is not None and len(vision_embeddings) > 0: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -357,6 +404,7 @@ def prepare_inputs_for_generation( vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) + # inputs of forward return dict( input_ids=input_ids, position_ids=position_ids, From 047e58e4de00fd342c411015c5782caaff242e45 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 22 Aug 2024 13:23:54 +0800 Subject: [PATCH 35/89] optimize apply rotary --- .../kernels/cuda/apply_rotary_pos_emb.py | 81 ++++++++++--------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index e8141a45f..2cb9953b4 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -65,61 +65,62 @@ def apply_rotary_pos_emb_qk_kernel( seq_block_id = tl.program_id(0) pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK) - # pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1) - pos_ids = pos_offset + pos_mask = pos_offset < seq_len + pos_offset = tl.max_contiguous(tl.multiple_of(pos_offset % seq_len, BLOCK), + BLOCK) feat_size = half_size * 2 feat_offset_l = tl.arange(0, BLOCK_N) + feat_mask = feat_offset_l < half_size + feat_offset_l = feat_offset_l % half_size feat_offset_h = half_size + feat_offset_l - seq_mask = (pos_offset < seq_len)[:, None] & (feat_offset_l < - half_size)[None, :] - cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :] - cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :] - # pos_ids_mask = pos_ids[:, None] >= 0 - pos_ids_mask = pos_ids[:, None] < seq_len + seq_mask = pos_mask[:, None] and feat_mask[None, :] + cs_offset_l = pos_offset[:, None] * feat_size + feat_offset_l[None, :] + cs_offset_h = pos_offset[:, None] * feat_size + feat_offset_h[None, :] q_elem_type = Q.dtype.element_ty - cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask).to(q_elem_type) - cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask).to(q_elem_type) - sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask).to(q_elem_type) - sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask).to(q_elem_type) + cos_l = tl.load(COS + cs_offset_l).to(q_elem_type) + cos_h = tl.load(COS + cs_offset_h).to(q_elem_type) + sin_l = tl.load(SIN + cs_offset_l).to(q_elem_type) + sin_h = tl.load(SIN + cs_offset_h).to(q_elem_type) q_ptr = Q + pos_offset * stride_qs qe_ptr = Q_EMB + pos_offset * stride_qes - for hidx in range(BLOCK_QH): - qh_ptr = q_ptr[:, None] + hidx * stride_qh - q_l = tl.load(qh_ptr + feat_offset_l[None, :] * stride_qd, - mask=seq_mask) - q_h = tl.load(qh_ptr + feat_offset_h[None, :] * stride_qd, - mask=seq_mask) + ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd + qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd + qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed + qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed + for _ in range(BLOCK_QH): + q_l = tl.load(ql_ptrs) + q_h = tl.load(qh_ptrs) qe_l = q_l * cos_l - q_h * sin_l qe_h = q_h * cos_h + q_l * sin_h - qeh_ptr = qe_ptr[:, None] + hidx * stride_qeh - tl.store(qeh_ptr + feat_offset_l[None, :] * stride_qed, - qe_l, - mask=seq_mask) - tl.store(qeh_ptr + feat_offset_h[None, :] * stride_qed, - qe_h, - mask=seq_mask) + tl.store(qel_ptrs, qe_l, mask=seq_mask) + tl.store(qeh_ptrs, qe_h, mask=seq_mask) + + ql_ptrs += stride_qh + qh_ptrs += stride_qh + qel_ptrs += stride_qeh + qeh_ptrs += stride_qeh k_ptr = K + pos_offset * stride_ks ke_ptr = K_EMB + pos_offset * stride_kes - for hidx in range(BLOCK_KH): - kh_ptr = k_ptr[:, None] + hidx * stride_kh - k_l = tl.load(kh_ptr + feat_offset_l[None, :] * stride_kd, - mask=seq_mask) - k_h = tl.load(kh_ptr + feat_offset_h[None, :] * stride_kd, - mask=seq_mask) + kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd + kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd + kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked + keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked + for _ in range(BLOCK_KH): + k_l = tl.load(kl_ptrs) + k_h = tl.load(kh_ptrs) ke_l = k_l * cos_l - k_h * sin_l ke_h = k_h * cos_h + k_l * sin_h - keh_ptr = ke_ptr[:, None] + hidx * stride_keh - tl.store(keh_ptr + feat_offset_l[None, :] * stride_ked, - ke_l, - mask=seq_mask) - tl.store(keh_ptr + feat_offset_h[None, :] * stride_ked, - ke_h, - mask=seq_mask) + tl.store(kel_ptrs, ke_l, mask=seq_mask) + tl.store(keh_ptrs, ke_h, mask=seq_mask) + kl_ptrs += stride_kh + kh_ptrs += stride_kh + kel_ptrs += stride_keh + keh_ptrs += stride_keh def apply_rotary_pos_emb(q: Tensor, @@ -152,13 +153,13 @@ def apply_rotary_pos_emb(q: Tensor, k_embed = torch.empty_like(k) seq_len = cos.numel() // cos.size(-1) - BLOCK = 32 + BLOCK = 16 half_size = q.size(-1) // 2 BLOCK_N = triton.next_power_of_2(half_size) num_heads_q = q.size(-2) num_heads_k = k.size(-2) num_warps = 4 - num_stages = 2 + num_stages = 4 kernel_meta = get_kernel_meta(q) grid = [triton.cdiv(seq_len, BLOCK)] From 6ef049d04972a4b328b3ff43a570037994ba5499 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 22 Aug 2024 14:04:49 +0800 Subject: [PATCH 36/89] rename --- lmdeploy/pytorch/models/llama.py | 9 ++++----- lmdeploy/pytorch/models/mixtral.py | 12 ++++++------ lmdeploy/pytorch/{layers => nn}/__init__.py | 0 lmdeploy/pytorch/{layers => nn}/activation.py | 0 lmdeploy/pytorch/{layers => nn}/attention.py | 0 lmdeploy/pytorch/{layers => nn}/linear.py | 0 lmdeploy/pytorch/{layers => nn}/moe.py | 0 .../pytorch/{layers => nn}/multinomial_sampling.py | 0 lmdeploy/pytorch/{layers => nn}/norm.py | 0 lmdeploy/pytorch/{layers => nn}/rotary_embedding.py | 0 10 files changed, 10 insertions(+), 11 deletions(-) rename lmdeploy/pytorch/{layers => nn}/__init__.py (100%) rename lmdeploy/pytorch/{layers => nn}/activation.py (100%) rename lmdeploy/pytorch/{layers => nn}/attention.py (100%) rename lmdeploy/pytorch/{layers => nn}/linear.py (100%) rename lmdeploy/pytorch/{layers => nn}/moe.py (100%) rename lmdeploy/pytorch/{layers => nn}/multinomial_sampling.py (100%) rename lmdeploy/pytorch/{layers => nn}/norm.py (100%) rename lmdeploy/pytorch/{layers => nn}/rotary_embedding.py (100%) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 4531aacd0..730fd6348 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -6,12 +6,11 @@ from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, - build_rotary_embedding) -from lmdeploy.pytorch.layers.linear import (build_merged_colwise_linear, - build_rowwise_linear) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_rowwise_linear) from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 84807d543..32e55f789 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -6,13 +6,13 @@ from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from lmdeploy.pytorch.layers import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, build_rotary_embedding) -from lmdeploy.pytorch.layers.linear import (build_colwise_linear, - build_merged_colwise_linear, - build_rowwise_linear) -from lmdeploy.pytorch.layers.moe import SoftmaxTopK, build_moe_from_mlp from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_moe_from_mlp from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) diff --git a/lmdeploy/pytorch/layers/__init__.py b/lmdeploy/pytorch/nn/__init__.py similarity index 100% rename from lmdeploy/pytorch/layers/__init__.py rename to lmdeploy/pytorch/nn/__init__.py diff --git a/lmdeploy/pytorch/layers/activation.py b/lmdeploy/pytorch/nn/activation.py similarity index 100% rename from lmdeploy/pytorch/layers/activation.py rename to lmdeploy/pytorch/nn/activation.py diff --git a/lmdeploy/pytorch/layers/attention.py b/lmdeploy/pytorch/nn/attention.py similarity index 100% rename from lmdeploy/pytorch/layers/attention.py rename to lmdeploy/pytorch/nn/attention.py diff --git a/lmdeploy/pytorch/layers/linear.py b/lmdeploy/pytorch/nn/linear.py similarity index 100% rename from lmdeploy/pytorch/layers/linear.py rename to lmdeploy/pytorch/nn/linear.py diff --git a/lmdeploy/pytorch/layers/moe.py b/lmdeploy/pytorch/nn/moe.py similarity index 100% rename from lmdeploy/pytorch/layers/moe.py rename to lmdeploy/pytorch/nn/moe.py diff --git a/lmdeploy/pytorch/layers/multinomial_sampling.py b/lmdeploy/pytorch/nn/multinomial_sampling.py similarity index 100% rename from lmdeploy/pytorch/layers/multinomial_sampling.py rename to lmdeploy/pytorch/nn/multinomial_sampling.py diff --git a/lmdeploy/pytorch/layers/norm.py b/lmdeploy/pytorch/nn/norm.py similarity index 100% rename from lmdeploy/pytorch/layers/norm.py rename to lmdeploy/pytorch/nn/norm.py diff --git a/lmdeploy/pytorch/layers/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py similarity index 100% rename from lmdeploy/pytorch/layers/rotary_embedding.py rename to lmdeploy/pytorch/nn/rotary_embedding.py From 7b75a657ec9e05ecafe6679567f9677716515973 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 22 Aug 2024 14:25:25 +0800 Subject: [PATCH 37/89] fix sampling --- lmdeploy/pytorch/engine/logits_process.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index fb948bd8a..658249bae 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -89,8 +89,7 @@ def _multinomial_sampling(scores: torch.Tensor, offsets: torch.LongTensor, indices: torch.LongTensor = None): """sampling.""" - from lmdeploy.pytorch.layers.multinomial_sampling import \ - multinomial_sampling + from lmdeploy.pytorch.nn.multinomial_sampling import multinomial_sampling return multinomial_sampling(scores, seeds, offsets, indices) From b8f7f54b602b4eb7073f66f031830e467c61a880 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 22 Aug 2024 15:22:55 +0800 Subject: [PATCH 38/89] remove print --- lmdeploy/pytorch/models/q_modules.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py index ea4bd1eba..001fab7a6 100644 --- a/lmdeploy/pytorch/models/q_modules.py +++ b/lmdeploy/pytorch/models/q_modules.py @@ -138,9 +138,6 @@ def forward(self, input): if isinstance(input, torch.Tensor): input_quant, input_scale = per_token_quant_int8(input, 1e-7) - print(input_quant) - print(input_scale) - exit() else: assert isinstance(input, QTensor) input_quant, input_scale = input.tensor, input.scale From 40fc417ff957d1ceef6126bf129ad971a65491ef Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 24 Aug 2024 14:47:56 +0800 Subject: [PATCH 39/89] prepare for new weight loader --- lmdeploy/pytorch/backends/ascend/__init__.py | 2 +- lmdeploy/pytorch/backends/awq_modules.py | 28 +++++-- lmdeploy/pytorch/backends/cuda/awq_modules.py | 47 ++++++----- lmdeploy/pytorch/backends/cuda/moe.py | 72 ++-------------- lmdeploy/pytorch/backends/cuda/norm.py | 19 +++-- lmdeploy/pytorch/backends/cuda/qmodules.py | 48 ++++++----- .../pytorch/backends/default/activation.py | 3 +- .../pytorch/backends/default/awq_modules.py | 48 ++++++----- lmdeploy/pytorch/backends/default/linear.py | 26 +++--- lmdeploy/pytorch/backends/default/moe.py | 4 +- lmdeploy/pytorch/backends/default/norm.py | 19 +++-- lmdeploy/pytorch/backends/linear.py | 24 ++++-- lmdeploy/pytorch/backends/moe.py | 18 ++-- lmdeploy/pytorch/backends/norm.py | 7 +- lmdeploy/pytorch/backends/qmodules.py | 49 +++++++++-- lmdeploy/pytorch/models/llama.py | 17 +++- lmdeploy/pytorch/models/mixtral.py | 17 +++- lmdeploy/pytorch/nn/linear.py | 61 ++++++++++++-- lmdeploy/pytorch/nn/moe.py | 82 +++++++++++++++++-- lmdeploy/pytorch/nn/norm.py | 25 +++++- .../weight_loader/model_weight_loader.py | 22 +++++ 21 files changed, 422 insertions(+), 216 deletions(-) diff --git a/lmdeploy/pytorch/backends/ascend/__init__.py b/lmdeploy/pytorch/backends/ascend/__init__.py index 72e28d247..73444e906 100644 --- a/lmdeploy/pytorch/backends/ascend/__init__.py +++ b/lmdeploy/pytorch/backends/ascend/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .layer_backend import AscendLayerBackend # noqa: F401 +from .layer_backend import AscendLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/awq_modules.py b/lmdeploy/pytorch/backends/awq_modules.py index c62a72d20..51cb6cc00 100644 --- a/lmdeploy/pytorch/backends/awq_modules.py +++ b/lmdeploy/pytorch/backends/awq_modules.py @@ -1,16 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import Optional -from torch import nn +import torch -from lmdeploy.pytorch.model_inputs import StepContextManager - -class LinearW4A16Impl(ABC, nn.Module): +class LinearW4A16Impl(ABC): """w4a16 linear implementation.""" + def update_weights(self, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + return qweight, scales, qzeros, bias + @abstractmethod - def forward(self, x, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" raise NotImplementedError @@ -20,6 +31,11 @@ class LinearW4A16Builder(ABC): @staticmethod @abstractmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + w_bit: int, + group_size: int, + bias: bool = False, + dtype: torch.dtype = None): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index 3198590a2..f3cbf8bee 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch from torch import distributed as dist -from torch import nn - -from lmdeploy.pytorch.model_inputs import StepContextManager from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl @@ -52,25 +51,26 @@ def wq_gemm_forward( class AwqLinearW4A16Impl(LinearW4A16Impl): """awq kernel linear.""" - def __init__(self, mod: nn.Module): - super().__init__() + def __init__(self, in_features: int, out_features: int, w_bit: int, + group_size: int): from awq.modules.linear.gemm import AWQ_INSTALLED assert AWQ_INSTALLED - self.qweight = mod.qweight - self.qzeros = mod.qzeros - self.scales = mod.scales - self.w_bit = mod.w_bit - self.group_size = mod.group_size - self.bias = mod.bias - self.in_features = mod.in_features - self.out_features = mod.out_features - - def forward(self, x, all_reduce: bool = False): + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size + + def forward(self, + x, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" - out_features = self.scales.size(1) - out = wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, - self.w_bit, self.group_size, self.bias, - out_features) + out_features = scales.size(1) + out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, + self.group_size, bias, out_features) if all_reduce: dist.all_reduce(out) return out @@ -80,6 +80,11 @@ class AwqLinearW4A16Builder(LinearW4A16Builder): """awq linear builder.""" @staticmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + w_bit: int, + group_size: int, + bias: bool = False, + dtype: torch.dtype = None): """build.""" - return AwqLinearW4A16Impl(mod) + return AwqLinearW4A16Impl(in_features, out_features, w_bit, group_size) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index e0b7771da..da06e6725 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List import torch @@ -11,84 +10,27 @@ class TritonFusedMoEImpl(FusedMoEImpl): """triton fused moe implementation.""" - def __init__(self, - gate_up_weights: torch.Tensor, - down_weights: torch.Tensor, - top_k: int, - renormalize: bool = False): - super().__init__() + def __init__(self, top_k: int, renormalize: bool = False): self.top_k = top_k self.renormalize = renormalize - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.LongTensor): + topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor): """forward.""" return fused_moe(hidden_states, - self.gate_up_weights, - self.down_weights, + gate_up_weights, + down_weights, topk_weights=topk_weights, topk_ids=topk_ids, topk=self.top_k, renormalize=self.renormalize) -def _merge_mlp(gates: List[torch.Tensor], ups: List[torch.Tensor], - downs: List[torch.Tensor]): - """merge experts.""" - num_experts = len(gates) - - def __get_meta(): - gate = gates[0] - down = downs[0] - ffn_dim = gate.weight.size(0) - hidden_dim = down.weight.size(0) - dtype = gate.weight.dtype - device = gate.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - for exp_id in range(num_experts): - gate = gates[exp_id] - up = ups[exp_id] - down = downs[exp_id] - __copy_assign_param(gate.weight, gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(up.weight, gate_up_weights[exp_id, ffn_dim:]) - __copy_assign_param(down.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - return gate_up_weights, down_weights - - class TritonFusedMoEBuilder(FusedMoEBuilder): """triton fused moe builder.""" @staticmethod - def build_from_mlp(gates: List[torch.Tensor], - ups: List[torch.Tensor], - downs: List[torch.Tensor], - top_k: int, - renormalize: bool = False): + def build(top_k: int, renormalize: bool = False): """build from mlp.""" - gate_up_weights, down_weights = _merge_mlp(gates, ups, downs) - return TritonFusedMoEImpl(gate_up_weights, - down_weights, - top_k=top_k, - renormalize=renormalize) + return TritonFusedMoEImpl(top_k=top_k, renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/cuda/norm.py b/lmdeploy/pytorch/backends/cuda/norm.py index b35a2bef5..9d5100978 100644 --- a/lmdeploy/pytorch/backends/cuda/norm.py +++ b/lmdeploy/pytorch/backends/cuda/norm.py @@ -1,27 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch import nn from lmdeploy.pytorch.kernels.cuda import rms_norm from ..norm import RMSNormBuilder, RMSNormImpl -class TritonRMSNormImpl(RMSNormImpl, nn.Module): +class TritonRMSNormImpl(RMSNormImpl): """triton RMS norm implementation.""" - def __init__(self, weight: torch.Tensor, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(weight.clone()) + def __init__(self, hidden_size: int, eps: float = 1e-6): + self.hidden_size = hidden_size self.eps = eps - def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): """forward.""" if residual is None: - x = rms_norm(x, self.weight, self.eps) + x = rms_norm(x, weight, self.eps) return x else: - x, residual = rms_norm(x, self.weight, self.eps, residual=residual) + x, residual = rms_norm(x, weight, self.eps, residual=residual) return x, residual @@ -29,6 +30,6 @@ class TritonRMSNormBuilder(RMSNormBuilder): """triton RMS norm implementation builder.""" @staticmethod - def build(weight: torch.Tensor, eps: float = 1e-6): + def build(weight: torch.Tensor, eps: float = 1e-6, inplace: bool = False): """build.""" return TritonRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index 2dfb2bc2b..1f1aff55a 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -1,32 +1,36 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.distributed as dist from torch import nn from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import ( matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant) -from lmdeploy.pytorch.model_inputs import StepContextManager from lmdeploy.pytorch.models.q_modules import QTensor from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl) -class TritonRMSNormW8A8Impl(RMSNormW8A8Impl, nn.Module): +class TritonRMSNormW8A8Impl(RMSNormW8A8Impl): """triton RMS norm w8a8 implementation api.""" - def __init__(self, weight: torch.Tensor, eps: float = 1e-6): + def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() - self.weight = nn.Parameter(weight.clone()) + self.hidden_size = hidden_size self.eps = eps - def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): """forward.""" if residual is not None: x = x + residual residual = x hidden_states_quant, rms_scale = rms_norm_dynamic_quant( - x, self.weight, self.eps) + x, weight, self.eps) x = QTensor(hidden_states_quant, rms_scale) if residual is None: return x @@ -37,21 +41,24 @@ class TritonRMSNormBuilder(RMSNormW8A8Builder): """triton RMS norm w8a8 implementation builder.""" @staticmethod - def build(weight: torch.Tensor, eps: float = 1e-6): + def build(hidden_size: int, eps: float = 1e-6): """build.""" - return TritonRMSNormW8A8Impl(weight, eps) + return TritonRMSNormW8A8Impl(hidden_size, eps) class TritonLinearW8A8Impl(LinearW8A8Impl, nn.Module): """triton linear w8a8 implementation.""" - def __init__(self, mod: nn.Module): - super().__init__() - self.weight = mod.weight - self.scale = mod.scale - self.bias = mod.bias + def __init__(self, in_features: int, out_features: int): + self.in_features = in_features + self.out_features = out_features - def forward(self, x, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" if isinstance(x, torch.Tensor): x = x.contiguous() @@ -61,11 +68,11 @@ def forward(self, x, all_reduce: bool = False): input_quant, input_scale = x.tensor, x.scale out = matmul_kernel_dynamic_quant(input_quant, - self.weight, + weight, input_scale, - self.scale, + scale, output_dtype=torch.float16, - bias=self.bias) + bias=bias) if all_reduce: dist.all_reduce(out) @@ -76,6 +83,9 @@ class TritonLinearW8A8Builder(LinearW8A8Builder): """triton linear w8a8 implementation builder.""" @staticmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): """build.""" - return TritonLinearW8A8Impl(mod) + return TritonLinearW8A8Impl(in_features, out_features) diff --git a/lmdeploy/pytorch/backends/default/activation.py b/lmdeploy/pytorch/backends/default/activation.py index 72c71f280..e5f03000e 100644 --- a/lmdeploy/pytorch/backends/default/activation.py +++ b/lmdeploy/pytorch/backends/default/activation.py @@ -4,11 +4,10 @@ from ..activation import SiluAndMulBuilder, SiluAndMulImpl -class DefaultSiluAndMulImpl(SiluAndMulImpl, nn.Module): +class DefaultSiluAndMulImpl(SiluAndMulImpl): """silu + multiple residual fused implementation.""" def __init__(self, inplace: bool): - super().__init__() self.inplace = inplace self.silu = nn.SiLU(inplace) diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index 5709fe816..c0a961aad 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch from torch import distributed as dist -from torch import nn - -from lmdeploy.pytorch.model_inputs import StepContextManager from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl @@ -11,30 +10,31 @@ class DefaultLinearW4A16Impl(LinearW4A16Impl): """w4a16 linear implementation.""" - def __init__(self, mod: nn.Module): - super().__init__() - self.qweight = mod.qweight - self.qzeros = mod.qzeros - self.scales = mod.scales - self.w_bit = mod.w_bit - self.group_size = mod.group_size - self.bias = mod.bias - self.in_features = mod.in_features - self.out_features = mod.out_features - - def forward(self, x, all_reduce: bool = False): + def __init__(self, in_features: int, out_features: int, w_bit: int, + group_size: int): + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size + + def forward(self, + x, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" from awq.utils.packing_utils import dequantize_gemm out_shape = x.shape[:-1] + (self.out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: x = x.half() - - out = dequantize_gemm(self.qweight, self.qzeros, self.scales, - self.w_bit, self.group_size) + out = dequantize_gemm(qweight, qzeros, scales, self.w_bit, + self.group_size) out = torch.matmul(x, out) - out = out + self.bias if self.bias is not None else out + out = out + bias if bias is not None else out out = out.reshape(out_shape) if input_dtype != torch.float16: @@ -48,6 +48,12 @@ class DefaultLinearW4A16Builder(LinearW4A16Builder): """w4a16 linear implementation builder.""" @staticmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + w_bit: int, + group_size: int, + bias: bool = False, + dtype: torch.dtype = None): """build.""" - return DefaultLinearW4A16Impl(mod) + return DefaultLinearW4A16Impl(in_features, out_features, w_bit, + group_size) diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index 9d69c96d3..fe23107ab 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch import distributed as dist -from torch import nn +from typing import Optional -from lmdeploy.pytorch.model_inputs import StepContextManager +import torch +import torch.nn.functional as F +from torch import distributed as dist from ..linear import LinearBuilder, LinearImpl @@ -10,13 +11,13 @@ class DefaultLinearImpl(LinearImpl): """Linear implementation api.""" - def __init__(self, mod: nn.Module): - super().__init__() - self.mod = mod - - def forward(self, x, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" - out = self.mod(x) + out = F.linear(x, weight, bias) if all_reduce: dist.all_reduce(out) return out @@ -26,6 +27,9 @@ class DefaultLinearBuilder(LinearBuilder): """linear implementation builder.""" @staticmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): """build.""" - return DefaultLinearImpl(mod) + return DefaultLinearImpl() diff --git a/lmdeploy/pytorch/backends/default/moe.py b/lmdeploy/pytorch/backends/default/moe.py index d67626006..a58a47ce6 100644 --- a/lmdeploy/pytorch/backends/default/moe.py +++ b/lmdeploy/pytorch/backends/default/moe.py @@ -1,15 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch import nn from ..moe import SoftmaxTopKBuilder, SoftmaxTopKImpl -class DefaultSoftmaxTopKImpl(SoftmaxTopKImpl, nn.Module): +class DefaultSoftmaxTopKImpl(SoftmaxTopKImpl): """RMS norm implementation api.""" def __init__(self, top_k: int, dim: int = -1): - super().__init__() self.top_k = top_k self.dim = dim diff --git a/lmdeploy/pytorch/backends/default/norm.py b/lmdeploy/pytorch/backends/default/norm.py index 5b57b526f..0870f2943 100644 --- a/lmdeploy/pytorch/backends/default/norm.py +++ b/lmdeploy/pytorch/backends/default/norm.py @@ -1,19 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch import nn from ..norm import RMSNormBuilder, RMSNormImpl -class DefaultRMSNormImpl(RMSNormImpl, nn.Module): +class DefaultRMSNormImpl(RMSNormImpl): """RMS norm implementation api.""" - def __init__(self, weight: torch.Tensor, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(weight.clone()) + def __init__(self, hidden_size: int, eps: float = 1e-6): + self.hidden_size = hidden_size self.eps = eps - def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): """forward.""" input_dtype = x.dtype if residual is not None: @@ -22,7 +23,7 @@ def forward(self, x: torch.Tensor, residual: torch.Tensor = None): x = x.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) - x = self.weight * x.to(input_dtype) + x = weight * x.to(input_dtype) if residual is None: return x return x, residual @@ -32,6 +33,6 @@ class DefaultRMSNormBuilder(RMSNormBuilder): """RMS norm implementation builder.""" @staticmethod - def build(weight: torch.Tensor, eps: float = 1e-6): + def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): """build.""" - return DefaultRMSNormImpl(weight, eps) + return DefaultRMSNormImpl(hidden_size, eps) diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index 6d5336815..2b189a1db 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -1,16 +1,25 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import Optional -from torch import nn +import torch -from lmdeploy.pytorch.model_inputs import StepContextManager - -class LinearImpl(ABC, nn.Module): +class LinearImpl(ABC): """Linear implementation api.""" + def update_weights(self, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + return weight, bias + @abstractmethod - def forward(self, x, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" raise NotImplementedError @@ -20,6 +29,9 @@ class LinearBuilder(ABC): @staticmethod @abstractmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index 269a3810d..4a1d5b73d 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ b/lmdeploy/pytorch/backends/moe.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod -from typing import List import torch -from torch import nn class SoftmaxTopKImpl(ABC): @@ -25,12 +23,18 @@ def build(top_k: int, dim: int = -1): raise NotImplementedError -class FusedMoEImpl(ABC, nn.Module): +class FusedMoEImpl(ABC): """fused moe implementation.""" + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor): + """update weights.""" + return gate_up_weights, down_weights + @abstractmethod def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.LongTensor): + topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor): """forward.""" raise NotImplementedError @@ -40,10 +44,6 @@ class FusedMoEBuilder(ABC): @staticmethod @abstractmethod - def build_from_mlp(gates: List[torch.Tensor], - ups: List[torch.Tensor], - downs: List[torch.Tensor], - top_k: int, - renormalize: bool = False): + def build(top_k: int, renormalize: bool = False): """build from mlp.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py index 7635fdc8c..de555c402 100644 --- a/lmdeploy/pytorch/backends/norm.py +++ b/lmdeploy/pytorch/backends/norm.py @@ -8,7 +8,10 @@ class RMSNormImpl(ABC): """RMS norm implementation api.""" @abstractmethod - def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): """forward.""" raise NotImplementedError @@ -18,6 +21,6 @@ class RMSNormBuilder(ABC): @staticmethod @abstractmethod - def build(weight: torch.Tensor, eps: float = 1e-6, inplace: bool = False): + def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py index 8183a9142..a61941b37 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -1,17 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import Optional import torch -from torch import nn -from lmdeploy.pytorch.model_inputs import StepContextManager - -class RMSNormW8A8Impl(ABC, nn.Module): +class RMSNormW8A8Impl(ABC): """RMS norm w8a8 implementation api.""" + @staticmethod + def create_weight(hidden_size: int, + dtype: torch.dtype = None, + device: torch.device = None): + """create weight.""" + if dtype is None: + dtype = torch.float16 + if device is None: + device = 'cuda' + weight = torch.nn.Parameter(torch.ones(hidden_size, + dtype=dtype, + device=device), + requires_grad=False) + return weight + @abstractmethod - def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor = None): """forward.""" raise NotImplementedError @@ -21,16 +37,28 @@ class RMSNormW8A8Builder(ABC): @staticmethod @abstractmethod - def build(weight: torch.Tensor, eps: float = 1e-6): + def build(hidden_size: int, eps: float = 1e-6): """build.""" raise NotImplementedError -class LinearW8A8Impl(ABC, nn.Module): +class LinearW8A8Impl(ABC): """linear w8a8 implementation api.""" + def update_weights(self, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + return weight, scale, bias + @abstractmethod - def forward(self, x, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): """forward.""" raise NotImplementedError @@ -40,6 +68,9 @@ class LinearW8A8Builder(ABC): @staticmethod @abstractmethod - def build(mod: nn.Module, ctx_mgr: StepContextManager = None): + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 730fd6348..2c8383ec6 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -11,6 +11,7 @@ RMSNorm, SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -199,19 +200,26 @@ def __init__(self, origin: nn.Module, layer_idx: int, input_layernorm = origin.input_layernorm is_w8a8 = hasattr(input_layernorm, 'from_float') self.input_layernorm = RMSNorm( - input_layernorm.weight, + input_layernorm.weight.size(0), input_layernorm.variance_epsilon, + dtype=input_layernorm.weight.dtype, + device=input_layernorm.weight.device, is_w8a8=is_w8a8, ) + load_weight(self.input_layernorm.weight, input_layernorm.weight) # build attention layer norm post_attention_layernorm = origin.post_attention_layernorm is_w8a8 = hasattr(post_attention_layernorm, 'from_float') self.post_attention_layernorm = RMSNorm( - post_attention_layernorm.weight, + post_attention_layernorm.weight.size(0), post_attention_layernorm.variance_epsilon, + dtype=post_attention_layernorm.weight.dtype, + device=post_attention_layernorm.weight.device, is_w8a8=is_w8a8, ) + load_weight(self.post_attention_layernorm.weight, + post_attention_layernorm.weight) def forward( self, @@ -262,9 +270,12 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): # build norm norm = origin.norm is_w8a8 = hasattr(norm, 'from_float') - self.norm = RMSNorm(norm.weight, + self.norm = RMSNorm(norm.weight.size(0), norm.variance_epsilon, + dtype=norm.weight.dtype, + device=norm.weight.device, is_w8a8=is_w8a8) + load_weight(self.norm.weight, norm.weight) # build rotary embedding in LlamaModel rotary_emb = origin.layers[0].self_attn.rotary_emb diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 32e55f789..ece6411b7 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -13,6 +13,7 @@ build_merged_colwise_linear, build_rowwise_linear) from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_moe_from_mlp +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -202,17 +203,24 @@ def __init__(self, origin: nn.Module, layer_idx: int, input_layernorm = origin.input_layernorm is_w8a8 = hasattr(input_layernorm, 'from_float') self.input_layernorm = RMSNorm( - input_layernorm.weight, + input_layernorm.weight.size(0), input_layernorm.variance_epsilon, + dtype=input_layernorm.weight.dtype, + device=input_layernorm.weight.device, is_w8a8=is_w8a8, ) + load_weight(self.input_layernorm.weight, input_layernorm.weight) post_attention_layernorm = origin.post_attention_layernorm is_w8a8 = hasattr(post_attention_layernorm, 'from_float') self.post_attention_layernorm = RMSNorm( - post_attention_layernorm.weight, + post_attention_layernorm.weight.size(0), post_attention_layernorm.variance_epsilon, + dtype=post_attention_layernorm.weight.dtype, + device=post_attention_layernorm.weight.device, is_w8a8=is_w8a8, ) + load_weight(self.post_attention_layernorm.weight, + post_attention_layernorm.weight) def forward( self, @@ -260,9 +268,12 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): ]) norm = origin.norm is_w8a8 = hasattr(norm, 'from_float') - self.norm = RMSNorm(norm.weight, + self.norm = RMSNorm(norm.weight.size(0), norm.variance_epsilon, + dtype=norm.weight.dtype, + device=norm.weight.device, is_w8a8=is_w8a8) + load_weight(self.norm.weight, norm.weight) emb_type = EmbeddingType.LinearScaling config = origin.config diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 093ead71d..abd041166 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -105,7 +105,27 @@ def __init__(self, super().__init__() impl_builder = get_backend().get_layer_impl_builder( LayerType.LinearW4A16) - self.impl = impl_builder.build(mod, ctx_mgr) + qweight = mod.qweight + scales = mod.scales + qzeros = mod.qzeros + bias = mod.bias + self.impl = impl_builder.build(mod.in_features, + mod.out_features, + mod.w_bit, + mod.group_size, + bias is not None, + dtype=scales.dtype) + qweight, scales, qzeros, bias = self.impl.update_weights( + qweight, scales, qzeros, bias) + qweight = torch.nn.Parameter(qweight, requires_grad=False) + scales = torch.nn.Parameter(scales, requires_grad=False) + qzeros = torch.nn.Parameter(qzeros, requires_grad=False) + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + self.register_parameter('qweight', qweight) + self.register_parameter('scales', scales) + self.register_parameter('qzeros', qzeros) + self.register_parameter('bias', bias) adapter_infos = adapter_infos if adapter_infos is not None else [] self.lora_adapters = None @@ -120,9 +140,11 @@ def forward(self, x): """w4a16 forward.""" is_tp = False if self.colwise else self.is_tp if self.lora_adapters is None: - return self.impl.forward(x, is_tp) + return self.impl.forward(x, self.qweight, self.scales, self.qzeros, + self.bias, is_tp) - out = self.impl.forward(x, False) + out = self.impl.forward(x, self.qweight, self.scales, self.qzeros, + self.bias, False) if self.lora_adapters is not None: for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) @@ -142,14 +164,29 @@ def __init__(self, super().__init__() impl_builder = get_backend().get_layer_impl_builder( LayerType.LinearW8A8) - self.impl = impl_builder.build(mod, ctx_mgr) + weight = mod.weight + scale = mod.scale + bias = mod.bias + + self.impl = impl_builder.build(mod.in_features, + mod.out_features, + bias is not None, + dtype=torch.float16) + weight, scale, bias = self.impl.update_weights(weight, scale, bias) + weight = torch.nn.Parameter(weight, requires_grad=False) + scale = torch.nn.Parameter(scale, requires_grad=False) + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) self.is_tp = is_tp self.colwise = colwise def forward(self, x): """forward of w8a8.""" is_tp = False if self.colwise else self.is_tp - return self.impl.forward(x, is_tp) + return self.impl.forward(x, self.weight, self.scale, self.bias, is_tp) class BaseLinear(nn.Module): @@ -163,7 +200,15 @@ def __init__(self, is_tp: bool = False): super().__init__() impl_builder = get_backend().get_layer_impl_builder(LayerType.Linear) - self.impl = impl_builder.build(mod, ctx_mgr) + weight = mod.weight + bias = mod.bias + self.impl = impl_builder.build(mod.in_features, + mod.out_features, + bias is not None, + dtype=weight.dtype) + weight, bias = self.impl.update_weights(weight, bias) + self.register_parameter('weight', weight) + self.register_parameter('bias', bias) adapter_infos = adapter_infos if adapter_infos is not None else [] self.lora_adapters = None @@ -178,9 +223,9 @@ def forward(self, x): """forward of linear layer.""" is_tp = False if self.colwise else self.is_tp if self.lora_adapters is None: - return self.impl.forward(x, is_tp) + return self.impl.forward(x, self.weight, self.bias, is_tp) - out = self.impl.forward(x, False) + out = self.impl.forward(x, self.weight, self.bias, False) if self.lora_adapters is not None: for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index dfdf49d4d..53f0053e3 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -7,6 +7,49 @@ from ..backends import LayerType, get_backend +def _merge_mlp(gates: List[torch.Tensor], ups: List[torch.Tensor], + downs: List[torch.Tensor]): + """merge experts.""" + num_experts = len(gates) + + def __get_meta(): + gate = gates[0] + down = downs[0] + ffn_dim = gate.weight.size(0) + hidden_dim = down.weight.size(0) + dtype = gate.weight.dtype + device = gate.weight.device + return ffn_dim, hidden_dim, dtype, device + + def __copy_assign_param(param, weight): + """copy assign.""" + weight.copy_(param.data) + param.data = weight + + ffn_dim, hidden_dim, dtype, device = __get_meta() + + gate_up_weights = torch.empty(num_experts, + ffn_dim * 2, + hidden_dim, + device=device, + dtype=dtype) + down_weights = torch.empty(num_experts, + hidden_dim, + ffn_dim, + device=device, + dtype=dtype) + for exp_id in range(num_experts): + gate = gates[exp_id] + up = ups[exp_id] + down = downs[exp_id] + __copy_assign_param(gate.weight, gate_up_weights[exp_id, :ffn_dim]) + __copy_assign_param(up.weight, gate_up_weights[exp_id, ffn_dim:]) + __copy_assign_param(down.weight, down_weights[exp_id]) + + torch.cuda.empty_cache() + return gate_up_weights, down_weights + + class SoftmaxTopK(nn.Module): """softmax topk.""" @@ -22,6 +65,33 @@ def forward(self, x: torch.Tensor): return self.impl.forward(x) +class FusedMoE(nn.Module): + """fused moe.""" + + def __init__( + self, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + top_k: int, + renormalize: bool = False, + ): + super().__init__() + impl_builder = get_backend().get_layer_impl_builder(LayerType.FusedMoE) + self.impl = impl_builder.build(top_k, renormalize) + gate_up_weights, down_weights = self.impl.update_weights( + gate_up_weights, down_weights) + gate_up_weights = torch.nn.Parameter(gate_up_weights, + requires_grad=False) + down_weights = torch.nn.Parameter(down_weights, requires_grad=False) + self.register_parameter('gate_up_weights', gate_up_weights) + self.register_parameter('down_weights', down_weights) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + return self.impl.forward(hidden_states, topk_weights, topk_ids, + self.gate_up_weights, self.down_weights) + + def build_moe_from_mlp( gates: List[nn.Linear], ups: List[nn.Linear], @@ -30,9 +100,9 @@ def build_moe_from_mlp( renormalize: bool = False, ): """build moe from mlp.""" - impl_builder = get_backend().get_layer_impl_builder(LayerType.FusedMoE) - return impl_builder.build_from_mlp(gates, - ups, - downs, - top_k=top_k, - renormalize=renormalize) + + gate_up_weights, down_weights = _merge_mlp(gates, ups, downs) + return FusedMoE(gate_up_weights, + down_weights, + top_k=top_k, + renormalize=renormalize) diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index 7df0a810a..c9137c129 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -9,8 +9,10 @@ class RMSNorm(nn.Module): """RMS Norm with add residual.""" def __init__(self, - weight: torch.Tensor, + hidden_size: int, eps: float = 1e-6, + dtype: torch.dtype = None, + device: torch.device = None, is_w8a8: bool = False): super().__init__() backend = get_backend() @@ -18,8 +20,25 @@ def __init__(self, builder = backend.get_layer_impl_builder(LayerType.RMSNormW8A8) else: builder = backend.get_layer_impl_builder(LayerType.RMSNorm) - self.impl = builder.build(weight, eps) + self.register_parameter('weight', + self.create_weight(hidden_size, dtype, device)) + self.impl = builder.build(hidden_size, eps) + + @staticmethod + def create_weight(hidden_size: int, + dtype: torch.dtype = None, + device: torch.device = None): + """create weight.""" + if dtype is None: + dtype = torch.float16 + if device is None: + device = 'cuda' + weight = torch.nn.Parameter(torch.ones(hidden_size, + dtype=dtype, + device=device), + requires_grad=False) + return weight def forward(self, x: torch.Tensor, residual: torch.Tensor = None): """forward.""" - return self.impl.forward(x, residual) + return self.impl.forward(x, self.weight, residual) diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index 9ddbb1169..36da1cb6e 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -16,6 +16,28 @@ logger = get_logger('lmdeploy') +def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, + **kwargs): + """load weight.""" + if hasattr(param, 'weight_loader'): + param.weight_loader(param, loaded_weight, **kwargs) + else: + assert len(kwargs) == 0 + default_weight_loader(param, loaded_weight) + + +def default_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """default weight loader.""" + if param.numel() == 1 and loaded_weight.numel() == 1: + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f'Attempted to load weight ({loaded_weight.size()}) ' + f'into parameter ({param.size()})') + param.data.copy_(loaded_weight) + + def _get_weight_type(model_path: str, use_safetensors: bool = None): """get weight type.""" weight_type = None From 9df516158be58dadeab6d94ca8a296042a6616c9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 27 Aug 2024 20:15:18 +0800 Subject: [PATCH 40/89] refactor add model --- lmdeploy/pytorch/adapter/adapter.py | 470 +++--- lmdeploy/pytorch/backends/cuda/qmodules.py | 3 +- lmdeploy/pytorch/backends/cuda/slora.py | 153 +- lmdeploy/pytorch/backends/slora.py | 50 +- lmdeploy/pytorch/engine/engine.py | 6 +- lmdeploy/pytorch/engine/model_agent.py | 257 +--- lmdeploy/pytorch/models/__init__.py | 3 +- lmdeploy/pytorch/models/llama.py | 287 ++-- lmdeploy/pytorch/models/mixtral.py | 298 ++-- lmdeploy/pytorch/models/module_map.py | 27 +- lmdeploy/pytorch/models/patch.py | 222 +-- lmdeploy/pytorch/nn/linear.py | 1323 +++++++++++++---- lmdeploy/pytorch/nn/moe.py | 159 +- lmdeploy/pytorch/nn/norm.py | 11 +- .../weight_loader/adapter_weight_loader.py | 35 - lmdeploy/pytorch/weight_loader/dist_utils.py | 526 ------- .../weight_loader/model_weight_loader.py | 166 +-- 17 files changed, 1999 insertions(+), 1997 deletions(-) delete mode 100644 lmdeploy/pytorch/weight_loader/adapter_weight_loader.py delete mode 100644 lmdeploy/pytorch/weight_loader/dist_utils.py diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index 815412d65..70815aba7 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -2,10 +2,11 @@ import re from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, Dict, Iterable, List import numpy as np import torch +import torch.distributed as dist from torch import Tensor from ..block import LogicalTokenBlocks @@ -16,34 +17,89 @@ def _div_up(a, b): return (a + b - 1) // b -def _cache_weight(cache: Tensor, weight: Tensor, rank_offset: Tensor): - """cache weight.""" - assert weight.dim() == 2 - assert rank_offset.dim() == 1 +def get_ranks_and_scalings(target_name: str, + cfgs: Iterable, + device: torch.device = None): + """get ranks and scalings.""" + ranks = [] + scalings = [] + for cfg in cfgs: + if target_name not in cfg.target_modules: + ranks.append(0) + scalings.append(1) + continue + ranks.append(cfg.r) + scalings.append(float(cfg.lora_alpha / cfg.r)) + ranks = torch.tensor(ranks, device=device) + scalings = torch.tensor(scalings, device=device) + return ranks, scalings + + +def find_all_target(model: torch.nn.Module, target_name: str): + """find all targets.""" + # find packed name + packed_name = target_name + pack_idx = None + packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict()) + for name, sub_names in packed_modules_mapping.items(): + if target_name in sub_names: + pack_idx = sub_names.index(target_name) + packed_name = name + break + + found_mods = [] + name_postfix = f'.{packed_name}' + for name, mod in model.named_modules(): + if not name.endswith(name_postfix): + continue + found_mods.append((name, mod)) - cache = cache.view(-1) - rank, feat_size = weight.size() - assert cache.size(-1) >= feat_size, ('cache.size(-1) >= feat_size failed.') - assert rank <= rank_offset.size(0), ('rank <= rank_offset.size(0) failed.') - for r in range(rank): - r_off = rank_offset[r] - cache[r_off:r_off + feat_size] = weight[r] + return found_mods, pack_idx -def _get_named_loralinears(model: torch.nn.Module): - """get all named loralinear.""" - named_loralinear: Dict[str, torch.nn.Module] = dict() - for name, module in model.named_modules(): - if getattr(module, 'lora_adapters', None) is None: - continue - for idx, adapter in enumerate(module.lora_adapters): - named_loralinear[f'{name}_{idx}'] = adapter - return named_loralinear +def get_max_ranks_per_block(block_numel: int, rank_stride: int): + assert block_numel >= rank_stride, ( + 'LoRA Adapter raquires larger block_size.') + return block_numel // rank_stride + + +def get_ranks_per_block(block_numel: int, rank_stride: int, rank: int): + """ranks per blocks.""" + max_ranks_per_block = get_max_ranks_per_block(block_numel, rank_stride) + return min(rank, max_ranks_per_block) + + +def get_num_required_blocks(block_numel: int, rank_stride: int, rank: int): + """get num required blocks.""" + ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) + if rank == 0: + return 0 + return _div_up(rank, ranks_per_block) + + +def get_inblock_offset(block_numel: int, rank_stride: int, rank: int): + """in block offset.""" + ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) + num_required_blocks = get_num_required_blocks(block_numel, rank_stride, + rank) + ret = np.arange(ranks_per_block) * rank_stride + ret = ret.repeat(num_required_blocks)[:rank] + return ret -def _get_layer_index(key: str, config: Any): +def get_block_idx_per_rank(block_numel: int, rank_stride: int, rank: int): + """out block idx.""" + ranks_per_block = get_ranks_per_block(block_numel, rank_stride, rank) + num_required_blocks = get_num_required_blocks(block_numel, rank_stride, + rank) + ret = np.arange(num_required_blocks) + ret = ret[:, None].repeat(ranks_per_block, 1) + ret = ret.flatten()[:rank] + return ret + + +def get_layer_index(key: str, layers_pattern: str = None): """get layer index of the lora linear.""" - layers_pattern = getattr(config, 'layers_pattern', None) if isinstance(layers_pattern, str): layers_pattern = [layers_pattern] if layers_pattern is None or len(layers_pattern) == 0: @@ -57,209 +113,117 @@ def _get_layer_index(key: str, config: Any): return int(layer_index[1]) -def get_indexed_lora_linears(model: torch.nn.Module): - """get indexed lora linear.""" - named_linears = _get_named_loralinears(model) - - config = None - peft_config = getattr(model, 'peft_config', dict()) - if len(peft_config) > 0: - config = next(iter(peft_config.values())) - - indexed_linears = dict() - for name, layer in named_linears.items(): - index = _get_layer_index(name, config) - target = name.split('.')[-1] - indexed_linears.setdefault(index, dict()) - indexed_linears[index][target] = layer - return indexed_linears - - -def update_lora_linears(lora_linears: Dict, - weight_maps: List['AdapterWeightMap'], - gpu_caches: List[List], - device: str = 'cuda'): - """update lora linears.""" - - def __update_linear(linear, idx, target_name, adapter_names): - """update linear.""" - linear.layer_idx = idx - linear.target_name = target_name - linear.adapter_info.lora_A = None - linear.adapter_info.lora_B = None - - num_adapters = len(weight_maps) - target_modules = weight_maps[0].target_modules - num_offs = len(weight_maps[0].rank_offset) - max_rank = weight_maps[0].max_rank - target_map = dict((name, idx) for idx, name in enumerate(target_modules)) - all_ranks = [weight_maps[idx].rank for idx in range(num_adapters)] - all_ranks = torch.tensor(all_ranks, device=device).t().contiguous() - all_scaling = [weight_maps[idx].scaling for idx in range(num_adapters)] - all_scaling = torch.tensor(all_scaling, device=device).t().contiguous() - all_boffs = torch.zeros(num_adapters, - num_offs, - dtype=torch.int64, - device=device) - all_boffs = all_boffs.unflatten(-1, (-1, max_rank)) - - adapter_names = [weight_map.adapter_name for weight_map in weight_maps] - - for idx, lora_linear in lora_linears.items(): - a_cache, b_cache = gpu_caches[idx] - for target, linear in lora_linear.items(): - target_id = target_map[target] - ranks = all_ranks[target_id] - scalings = all_scaling[target_id] - boffs = all_boffs[:, target_id] - linear.post_init( - ranks, - scalings, - boffs, - a_cache, - b_cache, - max_rank=max_rank, - ) - __update_linear(linear, - idx, - target_name=target, - adapter_names=adapter_names) - return all_boffs.flatten(1, 2) - - @dataclass -class LoRALinearInfo: +class LoRATargetInfo: """lora linear info.""" - ranks: Dict[str, int] - scalings: Dict[str, int] - adapter_names: List[str] in_features: int out_features: int + colwise: bool rank_stride: int = field(default=0, init=False) def __post_init__(self): """post init.""" self.rank_stride = max(self.in_features, self.out_features) - @classmethod - def from_loralinear(cls, linear: torch.nn.Module): - """create from lora linear.""" - adapter_info = linear.adapter_info - ranks = adapter_info.r - scalings = adapter_info.scaling - out_features = adapter_info.out_features - in_features = adapter_info.in_features - return cls( - ranks=ranks, - scalings=scalings, - adapter_names=list(ranks.keys()), - in_features=in_features, - out_features=out_features, - ) - - def max_ranks_per_block(self, block_numel: int): - assert block_numel >= self.rank_stride, ( - 'LoRA Adapter raquires larger block_size.') - return block_numel // self.rank_stride - - def ranks_per_block(self, block_numel: int, adapter_name: str): - """ranks per blocks.""" - max_ranks_per_block = self.max_ranks_per_block(block_numel) - rank = self.ranks.get(adapter_name, 0) - return min(rank, max_ranks_per_block) - - def num_required_blocks(self, block_numel: int, adapter_name: str): - """get num required blocks.""" - ranks_per_block = self.ranks_per_block(block_numel, adapter_name) - rank = self.ranks.get(adapter_name, 0) - if rank == 0: - return 0 - return _div_up(rank, ranks_per_block) - - def inblock_offset(self, block_numel: int, adapter_name: str): - """in block offset.""" - rank = self.ranks.get(adapter_name, 0) - ranks_per_block = self.ranks_per_block(block_numel, adapter_name) - num_required_blocks = self.num_required_blocks(block_numel, - adapter_name) - ret = np.arange(ranks_per_block) * self.rank_stride - ret = ret.repeat(num_required_blocks)[:rank] - return ret - - def block_idx_per_rank(self, block_numel: int, adapter_name: str): - """out block idx.""" - rank = self.ranks.get(adapter_name, 0) - ranks_per_block = self.ranks_per_block(block_numel, adapter_name) - num_required_blocks = self.num_required_blocks(block_numel, - adapter_name) - ret = np.arange(num_required_blocks) - ret = ret[:, None].repeat(ranks_per_block, 1) - ret = ret.flatten()[:rank] - return ret +def _get_rank_and_world(): + """get rank and world size.""" + rank = 0 + world_size = 1 + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() -def get_loralinear_info(model: torch.nn.Module): - """get loralinear info.""" - indexed_lora_linears = get_indexed_lora_linears(model) - if len(indexed_lora_linears) == 0: - return dict() - lora_linears = indexed_lora_linears[0] - infos = dict() - for target_name, linear in lora_linears.items(): - infos[target_name] = LoRALinearInfo.from_loralinear(linear) - return infos + return rank, world_size @dataclass class AdapterWeightMap: - adapter_id: int adapter_name: str + path: str rank: List[int] - scaling: List[int] rank_offset: np.ndarray max_rank: int target_modules: List[str] + colwise: List[bool] + + @staticmethod + def _get_weight(weight: torch.Tensor, is_lora_a: bool, is_col: bool, + rank: int, world_size: int): + """get sliced weight.""" + if world_size == 1: + return weight + + if not is_col and is_lora_a: + # rowwise + weight = weight.chunk(world_size, dim=1)[rank] + else: + # colwise + weight = weight.chunk(world_size, dim=0)[rank] + return weight + + @staticmethod + def _fill_a_cache(weight: torch.Tensor, cache: torch.Tensor, + rank_off: torch.Tensor): + """fill a cache.""" + num_ranks, feat_size = weight.shape + + for rank in range(num_ranks): + off = rank_off[rank] + cache[off:off + feat_size].copy_(weight[rank]) + + @staticmethod + def _fill_b_cache(weight: torch.Tensor, cache: torch.Tensor, + rank_off: torch.Tensor): + """fill a cache.""" + feat_size, num_ranks = weight.shape + + for rank in range(num_ranks): + off = rank_off[rank] + cache[off:off + feat_size].copy_(weight[:, rank]) + + def cache_adapter(self, caches: List[List[Tensor]]): + """cache all linear.""" + if self.path is None: + return + checkpoint_path = f'{self.path}/adapter_model.bin' + state_dict = torch.load(checkpoint_path, map_location='cpu') - @classmethod - def cache_lora_a(cls, cache: Tensor, weight: Tensor, rank_offset: Tensor): - """cache lora a weight.""" - return _cache_weight(cache, weight, rank_offset) + dist_rank, world_size = _get_rank_and_world() - @classmethod - def cache_lora_b(cls, cache: Tensor, weight: Tensor, rank_offset: Tensor): - """cache lora b weight.""" - return _cache_weight(cache, weight.t(), rank_offset) - - def cache_lora_linear(self, lora_linear: Dict[str, torch.nn.Module], - cache_a: Tensor, cache_b: Tensor): - """cache lora linear.""" - name = self.adapter_name target_modules = self.target_modules - rank_offset = self.rank_offset.reshape(-1, self.max_rank) - for tidx, target in enumerate(target_modules): - linear = lora_linear[target] - adapter_info = linear.adapter_info - if not (name in adapter_info.lora_A - and name in adapter_info.lora_B): - continue - linear_a = adapter_info.lora_A[name] - linear_b = adapter_info.lora_B[name] - weight_a = linear_a.weight - weight_b = linear_b.weight - assert weight_a is not None - assert weight_b is not None - r_offset = rank_offset[tidx] - self.cache_lora_a(cache_a, weight_a, r_offset) - self.cache_lora_b(cache_b, weight_b, r_offset) - - def cache_adapter(self, lora_linears: Dict, caches: List[List[Tensor]]): - """cache all linear.""" - assert len(lora_linears) == len(caches), ( - 'len(lora_linears) == len(caches)') - - for idx, lora_linear in lora_linears.items(): - cache_a, cache_b = caches[idx] - self.cache_lora_linear(lora_linear, cache_a, cache_b) + target_map = dict( + (name, idx) for idx, name in enumerate(target_modules)) + num_targets = len(target_modules) + rank_offset = self.rank_offset.view(num_targets, -1) + for key, weight in state_dict.items(): + layer_idx = get_layer_index(key, None) + a_cache, b_cache = caches[layer_idx] + a_cache = a_cache.view(-1) + b_cache = b_cache.view(-1) + + split_key = key.split('.') + assert split_key[-1] == 'weight' + target_name = split_key[-3] + if split_key[-2] == 'lora_A': + is_lora_a = True + elif split_key[-2] == 'lora_B': + is_lora_a = False + else: + raise RuntimeError(f'Unexpected key: {key}') + + target_id = target_map[target_name] + rank_off = rank_offset[target_id] + is_col = self.colwise[target_id] + weight = self._get_weight(weight, + is_lora_a, + is_col, + rank=dist_rank, + world_size=world_size) + if is_lora_a: + self._fill_a_cache(weight, a_cache, rank_off) + else: + self._fill_b_cache(weight, b_cache, rank_off) @dataclass @@ -271,9 +235,11 @@ class SchedulerAdapter: rank: List[int] scaling: List[int] target_modules: List[str] + target_infos: List[LoRATargetInfo] logical_blocks: LogicalTokenBlocks inblock_offset: np.ndarray block_idx_per_rank: np.ndarray + adapter_path: str = None block_stride: int = 0 max_rank: int = 0 num_required_blocks: int = 0 @@ -281,36 +247,43 @@ class SchedulerAdapter: _active: bool = field(default=False, init=False) @classmethod - def new(cls, adapter_id: int, adapter_name: str, - linear_infos: Dict[str, LoRALinearInfo], block_numel: int, - max_rank: int): + def new(cls, adapter_id: int, adapter_name: str, adapter_path: str, + adapter_cfg: Any, target_infos: Dict[str, LoRATargetInfo], + block_numel: int, max_rank: int): """new.""" - target_modules = list(linear_infos.keys()) + target_modules = list(target_infos.keys()) rank = [] scaling = [] - for linear in linear_infos.values(): - ranks = linear.ranks - rank.append(ranks.get(adapter_name, 0)) - scaling.append(linear.scalings.get(adapter_name, 1.0)) - inblock_offset = [np.empty((0, ), dtype=np.int64)] block_idx_per_rank = [np.empty((0, ), dtype=np.int64)] num_required_blocks = 0 for target_name in target_modules: - linear = linear_infos[target_name] - ib_offset = linear.inblock_offset(block_numel, adapter_name) + + # get rank and scaling + r = 0 + s = 1.0 + if target_name in adapter_cfg.target_modules: + r = adapter_cfg.r + if r != 0: + s = adapter_cfg.lora_alpha / r + rank.append(r) + scaling.append(s) + + info = target_infos[target_name] + rank_stride = info.rank_stride + ib_offset = get_inblock_offset(block_numel, rank_stride, r) pad_ib_offset = np.zeros((max_rank, ), dtype=np.int64) pad_ib_offset[:ib_offset.shape[0]] = ib_offset inblock_offset.append(pad_ib_offset) - bidx_p_rank = linear.block_idx_per_rank( - block_numel, adapter_name) + num_required_blocks + bidx_p_rank = get_block_idx_per_rank(block_numel, rank_stride, + r) + num_required_blocks pad_bidx_p_rank = np.zeros((max_rank, ), dtype=np.int64) pad_bidx_p_rank[:bidx_p_rank.shape[0]] = bidx_p_rank block_idx_per_rank.append(pad_bidx_p_rank) - num_required_blocks += linear.num_required_blocks( - block_numel, adapter_name) + num_required_blocks += get_num_required_blocks( + block_numel, rank_stride, r) inblock_offset = np.concatenate(inblock_offset) block_idx_per_rank = np.concatenate(block_idx_per_rank) @@ -320,9 +293,11 @@ def new(cls, adapter_id: int, adapter_name: str, rank=rank, scaling=scaling, target_modules=target_modules, + target_infos=target_infos, logical_blocks=LogicalTokenBlocks(), inblock_offset=inblock_offset, block_idx_per_rank=block_idx_per_rank, + adapter_path=adapter_path, block_stride=block_numel, max_rank=max_rank, num_required_blocks=num_required_blocks, @@ -356,49 +331,70 @@ def name(self): def build_weight_map(self): """build weight map.""" assert self.rank_offset is not None + colwise = [ + self.target_infos[name].colwise for name in self.target_modules + ] return AdapterWeightMap( - adapter_id=self.adapter_id, adapter_name=self.name, + path=self.adapter_path, rank=self.rank, - scaling=self.scaling, rank_offset=self.rank_offset, max_rank=self.max_rank, target_modules=self.target_modules, + colwise=colwise, ) -def _get_adapter_names(linear_infos: Dict[str, LoRALinearInfo]): - """get adapter names.""" - adapter_names = set() - for info in linear_infos.values(): - adapter_names.update(info.adapter_names) - return list(adapter_names) +class NoneLoraConfig: + + def __init__(self): + self.r = 0 + self.lora_alpha = 8 + self.target_modules = [] class AdapterManager: """adapter manager.""" - def __init__(self, linear_infos: Dict[str, LoRALinearInfo], - block_numel: int): - self.linear_infos = linear_infos + def __init__(self, adapters: Dict[str, str], + target_infos: Dict[str, LoRATargetInfo], block_numel: int): + + self.target_infos = target_infos self.block_numel = block_numel - self._adapters: Dict[str, SchedulerAdapter] = dict() + if adapters is None: + adapters = dict() - adapter_names = _get_adapter_names(linear_infos) + self.adapter_paths = dict( + (name, path) for name, path in adapters.items()) + self.adapter_paths[None] = None + + self.adapter_cfgs = self._get_adapter_cfgs(adapters) + + adapter_names = list(adapters.keys()) self.adapter_id_map = dict( (name, idx + 1) for idx, name in enumerate(adapter_names)) self.adapter_id_map[None] = 0 + self._adapters: Dict[str, SchedulerAdapter] = dict() self.max_rank = self._get_max_rank() self._add_non_adapter() + @staticmethod + def _get_adapter_cfgs(adapters: Dict[str, str]): + """get adapter cfgs.""" + if len(adapters) == 0: + return {None: NoneLoraConfig()} + from peft import PeftConfig + adapter_cfgs = dict((name, PeftConfig.from_pretrained(path)) + for name, path in adapters.items()) + adapter_cfgs[None] = NoneLoraConfig() + return adapter_cfgs + def _get_max_rank(self): """get max rank.""" max_rank = 0 - for linear in self.linear_infos.values(): - ranks = linear.ranks - if len(ranks) > 0: - max_rank = max(max_rank, max(ranks.values())) + for cfg in self.adapter_cfgs.values(): + max_rank = max(max_rank, cfg.r) return max_rank def _add_non_adapter(self): @@ -424,10 +420,14 @@ def num_adapters(self): def add_adapter(self, adapter_name: str): """add adapter.""" adapter_id = self.adapter_id_map[adapter_name] + adapter_cfg = self.adapter_cfgs[adapter_name] + adapter_path = self.adapter_paths[adapter_name] adapter = SchedulerAdapter.new( adapter_id, adapter_name, - self.linear_infos, + adapter_path, + adapter_cfg, + self.target_infos, self.block_numel, max_rank=self.max_rank, ) diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index 1f1aff55a..30f729a63 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist -from torch import nn from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import ( matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant) @@ -46,7 +45,7 @@ def build(hidden_size: int, eps: float = 1e-6): return TritonRMSNormW8A8Impl(hidden_size, eps) -class TritonLinearW8A8Impl(LinearW8A8Impl, nn.Module): +class TritonLinearW8A8Impl(LinearW8A8Impl): """triton linear w8a8 implementation.""" def __init__(self, in_features: int, out_features: int): diff --git a/lmdeploy/pytorch/backends/cuda/slora.py b/lmdeploy/pytorch/backends/cuda/slora.py index c57bca29f..84319446d 100644 --- a/lmdeploy/pytorch/backends/cuda/slora.py +++ b/lmdeploy/pytorch/backends/cuda/slora.py @@ -26,35 +26,10 @@ class PackedLoRAInput: class TritonSLoRAImpl(SLoRAImpl): """triton slora implementation.""" - def __init__(self, - adapter_info: AdapterInfo, - ctx_mgr: StepContextManager, - colwise: bool = True): - super().__init__() - self.base_slice = adapter_info.base_slice - self.ctx_mgr = ctx_mgr - self.colwise = colwise - - def post_init( - self, - ranks: torch.Tensor, - scalings: torch.Tensor, - rank_offsets: torch.Tensor, - a_cache: torch.Tensor, - b_cache: torch.Tensor, - max_rank: int, - ): - """post init.""" - self.ranks = ranks - self.scalings = scalings - self.rank_offsets = rank_offsets - self.a_cache = a_cache - self.b_cache = b_cache - self.max_rank = max_rank - - def _make_packed_lora_input(self, x): + @staticmethod + def _make_packed_lora_input(x, ctx_mgr): """make PackedLoRAInput.""" - context = self.ctx_mgr.current_context() + context = ctx_mgr.current_context() # adapter cache max_q_seq_length = x.numel() // x.size(-1) @@ -69,9 +44,10 @@ def _make_packed_lora_input(self, x): def _forward_rowwise(self, lora_input: PackedLoRAInput, base_output: torch.Tensor, + adapter_info: AdapterInfo, is_tp: bool = True): """forward_rowwise.""" - sliced_base = base_output[..., self.base_slice] + sliced_base = base_output[..., adapter_info.base_slice] out_size = sliced_base.size(-1) if is_tp: rank = dist.get_rank() @@ -79,45 +55,45 @@ def _forward_rowwise(self, out_size //= world_size if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, - self.a_cache, + adapter_info.a_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - rank_offset=self.rank_offsets, - ranks=self.ranks, + rank_offset=adapter_info.rank_offsets, + ranks=adapter_info.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=self.max_rank) + max_rank=adapter_info.max_rank) lora_out = mbgmm_b(xa, - self.b_cache, + adapter_info.b_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - scaling=self.scalings, - rank_offset=self.rank_offsets, - ranks=self.ranks, + scaling=adapter_info.scalings, + rank_offset=adapter_info.rank_offsets, + ranks=adapter_info.ranks, max_seq_len=lora_input.max_seq_len, - max_rank=self.max_rank, + max_rank=adapter_info.max_rank, out_size=out_size) else: xa = mbgmv_a(lora_input.x, - self.a_cache, + adapter_info.a_cache, adapter_ids=lora_input.adapter_ids, - rank_offset=self.rank_offsets, - ranks=self.ranks, - max_rank=self.max_rank) + rank_offset=adapter_info.rank_offsets, + ranks=adapter_info.ranks, + max_rank=adapter_info.max_rank) lora_out = mbgmv_b(xa, - self.b_cache, + adapter_info.b_cache, adapter_ids=lora_input.adapter_ids, - scaling=self.scalings, - rank_offset=self.rank_offsets, - ranks=self.ranks, - max_rank=self.max_rank, + scaling=adapter_info.scalings, + rank_offset=adapter_info.rank_offsets, + ranks=adapter_info.ranks, + max_rank=adapter_info.max_rank, out_size=out_size) if is_tp: out_shape = base_output.shape out = base_output.flatten(0, -2) - slice_off = self.base_slice.start + slice_off = adapter_info.base_slice.start slice_off = 0 if slice_off is None else slice_off slice_start = slice_off + rank * out_size slice_end = slice_start + out_size @@ -130,8 +106,12 @@ def _forward_rowwise(self, return out - def _forward_colwise(self, lora_input: PackedLoRAInput, - base_output: torch.Tensor): + def _forward_colwise( + self, + lora_input: PackedLoRAInput, + base_output: torch.Tensor, + adapter_info: AdapterInfo, + ): """forward_colwise.""" def __gather_xa(xa): @@ -142,69 +122,76 @@ def __gather_xa(xa): gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1) return gathered_xa - sliced_base = base_output[..., self.base_slice] + base_slice = adapter_info.base_slice + a_cache = adapter_info.a_cache + b_cache = adapter_info.b_cache + rank_offsets = adapter_info.rank_offsets + ranks = adapter_info.ranks + max_rank = adapter_info.max_rank + scalings = adapter_info.scalings + sliced_base = base_output[..., base_slice] out_size = sliced_base.size(-1) world_size = dist.get_world_size() if not lora_input.is_decoding: xa = mbgmm_a(lora_input.x, - self.a_cache, + a_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - rank_offset=self.rank_offsets, - ranks=self.ranks, + rank_offset=rank_offsets, + ranks=ranks, max_seq_len=lora_input.max_seq_len, - max_rank=self.max_rank, + max_rank=max_rank, rank_step=world_size) gathered_xa = __gather_xa(xa) - if len(self.ranks) > 1: + if len(ranks) > 1: gathered_xa = rearange_all_gather( gathered_xa, b_start_loc=lora_input.q_start_loc, b_seq_lens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - ranks=self.ranks, + ranks=ranks, world_size=world_size, max_seq_len=lora_input.max_seq_len, output=gathered_xa) lora_out = mbgmm_b(gathered_xa, - self.b_cache, + b_cache, q_start_loc=lora_input.q_start_loc, q_seqlens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - scaling=self.scalings, - rank_offset=self.rank_offsets, - ranks=self.ranks, + scaling=scalings, + rank_offset=rank_offsets, + ranks=ranks, max_seq_len=lora_input.max_seq_len, - max_rank=self.max_rank, + max_rank=max_rank, out_size=out_size) else: xa = mbgmv_a(lora_input.x, - self.a_cache, + a_cache, adapter_ids=lora_input.adapter_ids, - rank_offset=self.rank_offsets, - ranks=self.ranks, - max_rank=self.max_rank, + rank_offset=rank_offsets, + ranks=ranks, + max_rank=max_rank, rank_step=world_size) gathered_xa = __gather_xa(xa) - if len(self.ranks) > 1: + if len(ranks) > 1: gathered_xa = rearange_all_gather( gathered_xa, b_start_loc=lora_input.q_start_loc, b_seq_lens=lora_input.q_seqlens, adapter_ids=lora_input.adapter_ids, - ranks=self.ranks, + ranks=ranks, world_size=world_size, max_seq_len=lora_input.max_seq_len, output=gathered_xa) lora_out = mbgmv_b(gathered_xa, - self.b_cache, + b_cache, adapter_ids=lora_input.adapter_ids, - scaling=self.scalings, - rank_offset=self.rank_offsets, - ranks=self.ranks, - max_rank=self.max_rank, + scaling=scalings, + rank_offset=rank_offsets, + ranks=ranks, + max_rank=max_rank, out_size=out_size) lora_out = lora_out.reshape(sliced_base.shape) @@ -215,23 +202,23 @@ def __gather_xa(xa): def forward(self, x: torch.Tensor, base_output: torch.Tensor, - target_name: str, - layer_idx: int, + adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool, is_tp: bool = True): """forward.""" - lora_input = self._make_packed_lora_input(x) - if self.colwise and is_tp: - return self._forward_colwise(lora_input, base_output) + lora_input = self._make_packed_lora_input(x, ctx_mgr) + if colwise and is_tp: + return self._forward_colwise(lora_input, base_output, adapter_info) else: - return self._forward_rowwise(lora_input, base_output, is_tp) + return self._forward_rowwise(lora_input, base_output, adapter_info, + is_tp) class TritonSLoRABuilder(SLoRABuilder): """triton slora layer builder.""" @staticmethod - def build(adapter_info: AdapterInfo, - ctx_mgr: StepContextManager, - colwise: bool = True): + def build(): """build.""" - return TritonSLoRAImpl(adapter_info, ctx_mgr, colwise) + return TritonSLoRAImpl() diff --git a/lmdeploy/pytorch/backends/slora.py b/lmdeploy/pytorch/backends/slora.py index 18f8c9106..6fc606cbb 100644 --- a/lmdeploy/pytorch/backends/slora.py +++ b/lmdeploy/pytorch/backends/slora.py @@ -3,7 +3,6 @@ from dataclasses import dataclass import torch -from torch import nn from lmdeploy.pytorch.model_inputs import StepContextManager @@ -11,50 +10,27 @@ @dataclass class AdapterInfo: """Adapter information.""" - r: dict - lora_A: nn.ModuleDict - lora_B: nn.ModuleDict - scaling: dict - base_slice: slice in_features: int out_features: int - - @staticmethod - def from_lora_linear(mod: nn.Module, base_slice: slice = None): - if base_slice is None: - base_slice = slice(None) - return AdapterInfo( - r=mod.r, - lora_A=mod.lora_A, - lora_B=mod.lora_B, - scaling=mod.scaling, - base_slice=base_slice, - in_features=mod.in_features, - out_features=mod.out_features, - ) + ranks: torch.Tensor + scalings: torch.Tensor + rank_offsets: torch.Tensor + a_cache: torch.Tensor + b_cache: torch.Tensor + base_slice: slice + max_rank: int -class SLoRAImpl(ABC, nn.Module): +class SLoRAImpl(ABC): """slora implementation api.""" - def post_init( - self, - ranks: torch.Tensor, - scalings: torch.Tensor, - rank_offsets: torch.Tensor, - a_cache: torch.Tensor, - b_cache: torch.Tensor, - max_rank: int, - ): - """post init.""" - raise NotImplementedError - @abstractmethod def forward(self, x: torch.Tensor, base_output: torch.Tensor, - target_name: str, - layer_idx: int, + adapter_info: AdapterInfo, + ctx_mgr: StepContextManager, + colwise: bool, is_tp: bool = True): """forward.""" raise NotImplementedError @@ -65,8 +41,6 @@ class SLoRABuilder(ABC): @staticmethod @abstractmethod - def build(adapter_info: AdapterInfo, - ctx_mgr: StepContextManager, - colwise: bool = True): + def build(): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 7c03aa562..da695c8cb 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -234,11 +234,11 @@ def _create_buffers(self): def _build_adapter_manager(self, adapters): if adapters is not None and len(adapters) > 0: - linear_info = self.model_agent.get_loralinear_info() + linear_infos = self.model_agent.get_lora_target_info() else: - linear_info = dict() + linear_infos = dict() block_numel = self.model_agent.get_block_numel() - return AdapterManager(linear_info, block_numel) + return AdapterManager(adapters, linear_infos, block_numel) def _bind_request_manager(self): """bind request manager.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index af901a21c..266a532b7 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -2,7 +2,6 @@ import asyncio import atexit import os -import warnings from datetime import timedelta from typing import Any, Callable, Dict, List @@ -10,24 +9,21 @@ import torch.distributed as dist from torch import multiprocessing as mp -from lmdeploy.pytorch.accel import LoadNoInit from lmdeploy.utils import get_logger -from ..adapter.adapter import (AdapterWeightMap, get_indexed_lora_linears, - get_loralinear_info, update_lora_linears) +from ..adapter.adapter import AdapterWeightMap from ..backends import get_backend from ..config import BackendConfig, CacheConfig, ModelConfig from ..devices import DeviceContext, get_device_manager from ..model_inputs import ModelInputs -from ..models.patch import patch, update_custom_module_map, update_model +from ..models.patch import (add_adapters, build_patched_model, + update_custom_module_map) from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine logger = get_logger('lmdeploy') -_PATCH_ARG_NAMES = ['context', 'use_origin'] - def _update_cache_config(model_config: ModelConfig, cache_config: CacheConfig, @@ -161,56 +157,18 @@ def model_forward( def _get_indexed_lora_linears(model): """get indexed lora linears.""" + from ..adapter.adapter import get_indexed_lora_linears if hasattr(model, 'get_model'): model = model.get_model() return get_indexed_lora_linears(model) -def _get_loralinear_info(model): +def _get_lora_target_info(model, adapters: Dict[str, str]): """get lora linear info.""" + from ..adapter.adapter import get_lora_target_info if hasattr(model, 'get_model'): model = model.get_model() - return get_loralinear_info(model) - - -def _load_adapters(hf_model: torch.nn.Module, - adapters: Dict[str, str], - device_map: str = 'cpu'): - """load adapters.""" - if not adapters: - return - for name, path in adapters.items(): - logger.info(f'load adapter <{name}> from "{path}".') - hf_model.load_adapter(path, name, device_map=device_map) - - -def _add_adapters(hf_model: torch.nn.Module, adapters: Dict[str, str]): - """add adapters.""" - if not adapters: - return - from peft import PeftConfig, inject_adapter_in_model - for name, path in adapters.items(): - config = PeftConfig.from_pretrained(path) - inject_adapter_in_model(config, model=hf_model, adapter_name=name) - - -def _remove_unused_modules(hf_model: torch.nn.Module, model_cfg: ModelConfig): - """remove unused modules.""" - if model_cfg.unused_modules is not None and len( - model_cfg.unused_modules) > 0: - for mod in model_cfg.unused_modules: - has_mod = True - parts = mod.split('.') - mod_path = 'hf_model' - for p in parts: - if eval(f'hasattr({mod_path}, "{p}")'): - mod_path = f'{mod_path}.{p}' - else: - has_mod = False - break - if has_mod: - exec(f'del {mod_path}') - return hf_model + return get_lora_target_info(model, adapters) SwapMap = Dict[int, int] @@ -223,7 +181,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_loralinear_info(self): + def get_lora_target_info(self): """get lora linear info.""" raise NotImplementedError('Not implemented') @@ -278,16 +236,11 @@ def __init__(self, adapters: Dict[str, str] = None, trust_remote_code: bool = True): super().__init__(model_config=model_config, cache_config=cache_config) - torch_dtype = model_config.dtype device = 'cuda' self.backend_config = backend_config + self._adapters = adapters - self.patched_model = self._build_model( - model_path, - torch_dtype=torch_dtype, - adapters=adapters, - device=device, - trust_remote_code=trust_remote_code) + self.patched_model = self._build_model(model_path, device=device) _update_cache_config(model_config, cache_config) @@ -300,41 +253,29 @@ def __init__(self, device=device) self.cache_engine = CacheEngine(cache_config, model_config) + + self._target_infos = None + if adapters is not None: + self._target_infos = add_adapters(self.patched_model, + self.cache_engine.gpu_cache, + adapters=adapters) + self.stream = torch.cuda.Stream() - def _build_model(self, - model_path: str, - torch_dtype: torch.dtype, - adapters: Dict[str, str] = None, - device: torch.device = 'cuda', - trust_remote_code: bool = True): + def _build_model(self, model_path: str, device: torch.device = 'cuda'): """build patched model.""" - with LoadNoInit(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - hf_model = self.model_config.auto_model_cls.from_pretrained( - model_path, - torch_dtype=torch_dtype, - device_map=device, - trust_remote_code=trust_remote_code, - **self.model_config.init_kwargs) - hf_model.eval() - hf_model.config.use_cache = True - # build for vlm model - _remove_unused_modules(hf_model, self.model_config) - - if adapters: - _load_adapters(hf_model, adapters) - custom_module_map = self.model_config.custom_module_map if custom_module_map is not None: update_custom_module_map(custom_module_map) - patched_model = update_model(hf_model) - + logger.info('build model.') + patched_model = build_patched_model(self.model_config, device=device) + logger.info('loading weights.') + load_model_weights(patched_model, model_path, device=device) return patched_model - def get_loralinear_info(self): + def get_lora_target_info(self): """get lora linear info.""" - return _get_loralinear_info(self.patched_model) + return self._target_infos def get_block_numel(self): """get block nelement.""" @@ -344,24 +285,11 @@ def get_block_numel(self): def paging_adapters(self, weight_maps: List[AdapterWeightMap]): """paging adapter.""" logger.info('paging adapters.') - lora_linears = _get_indexed_lora_linears(self.patched_model) cpu_caches = self.cache_engine.cpu_cache - num_blocks = self.cache_engine.num_cpu_blocks - cpu_caches = [(kcache.view(num_blocks, - -1), vcache.view(num_blocks, -1)) + cpu_caches = [(kcache.flatten(1, -1), vcache.flatten(1, -1)) for kcache, vcache in cpu_caches] - gpu_caches = self.cache_engine.gpu_cache - num_gpu_blocks = self.cache_engine.num_gpu_blocks - gpu_caches = [(kcache.view(num_gpu_blocks, - -1), vcache.view(num_gpu_blocks, -1)) - for kcache, vcache in gpu_caches] for weight_map in weight_maps: - weight_map.cache_adapter(lora_linears, cpu_caches) - rank_offsets = update_lora_linears(lora_linears, - weight_maps, - gpu_caches, - device='cuda') - self.patched_model.rank_offsets = rank_offsets + weight_map.cache_adapter(cpu_caches) def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): @@ -409,25 +337,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, return output -def _create_device_map(model: torch.nn.Module, - world_size: int, - device_map: dict = None): - """Distribute params to each devices.""" - free_mems = [get_gpu_memory(gpu_id)[0] for gpu_id in range(world_size)] - free_mems = torch.tensor(free_mems) - if device_map is None: - device_map = dict() - for name, param in model.named_parameters(): - device_id = free_mems.argmax().item() - device_map[name] = device_id - free_mems[device_id] -= param.numel() * param.element_size() - for name, param in model.named_buffers(): - device_id = free_mems.argmax().item() - device_map[name] = device_id - free_mems[device_id] -= param.numel() * param.element_size() - return device_map - - @torch.inference_mode() def _tp_build_model( rank: int, @@ -437,10 +346,8 @@ def _tp_build_model( backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, - trust_remote_code: bool = True, ): """build tensor parallel model.""" - from accelerate import init_empty_weights patched_model = None cache_engine = None @@ -469,40 +376,17 @@ def _broadcast_config(cache_config): return config_list[0] try: - torch_dtype = model_config.dtype - device_map = None - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter('ignore') - model = model_config.auto_model_cls.from_pretrained( - model_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - **model_config.init_kwargs) - # build for vlm model - _remove_unused_modules(model, model_config) - if rank == 0: - device_map = _create_device_map(model, world_size) - _add_adapters(model, adapters) - if rank == 0: - # adapter would remove weight of linear. - device_map = _create_device_map(model, world_size, device_map) - - model.eval() - model.config.use_cache = True + device_map = torch.device('cuda') custom_module_map = model_config.custom_module_map if custom_module_map is not None: update_custom_module_map(custom_module_map) - patched_model = patch(model) - load_model_weights(patched_model, - model_path, - adapters, - rank=rank, - world_size=world_size, - device='cuda') if rank == 0: - logger.debug('Updating model.') - patched_model = update_model(patched_model) + logger.info('build model.') + patched_model = build_patched_model(model_config, device=device_map) + if rank == 0: + logger.info('loading weights.') + load_model_weights(patched_model, model_path, device=device_map) _update_cache_config(model_config, cache_config, @@ -522,10 +406,16 @@ def _broadcast_config(cache_config): model_config, rank=rank, world_size=world_size) + target_infos = None + if adapters is not None: + target_infos = add_adapters(patched_model, + cache_engine.gpu_cache, + adapters=adapters) + except Exception as e: raise e - return patched_model, cache_engine, cache_config + return patched_model, cache_engine, cache_config, target_infos def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): @@ -542,17 +432,16 @@ def _broadcast_inputs(rank: int, inputs: Any, stream: torch.cuda.Stream): @torch.inference_mode() def _tp_paging_adapters( rank: int, - patched_model: torch.nn.Module, cache_engine: CacheEngine, - weight_map: AdapterWeightMap = None, + weight_maps: AdapterWeightMap = None, ): """tp paging adapters.""" - def __get_weight_map(): + def __get_weight_map(weight_maps): """get weight map.""" if rank == 0: - assert weight_map is not None - dist_obj = [weight_map] + assert weight_maps is not None + dist_obj = [weight_maps] else: dist_obj = [None] dist.broadcast_object_list(dist_obj) @@ -560,29 +449,14 @@ def __get_weight_map(): def __paging(weight_maps): """paging.""" - lora_linears = _get_indexed_lora_linears(patched_model) cpu_caches = cache_engine.cpu_cache - num_blocks = cache_engine.num_cpu_blocks - cpu_caches = [(kcache.view(num_blocks, - -1), vcache.view(num_blocks, -1)) + cpu_caches = [(kcache.flatten(1, -1), vcache.flatten(1, -1)) for kcache, vcache in cpu_caches] - gpu_caches = cache_engine.gpu_cache - num_gpu_blocks = cache_engine.num_gpu_blocks - gpu_caches = [(kcache.view(num_gpu_blocks, - -1), vcache.view(num_gpu_blocks, -1)) - for kcache, vcache in gpu_caches] for weight_map in weight_maps: - weight_map.cache_adapter(lora_linears, cpu_caches) - rank_offsets = update_lora_linears(lora_linears, - weight_maps, - gpu_caches, - device='cuda') - patched_model.rank_offsets = rank_offsets + weight_map.cache_adapter(cpu_caches) - weight_maps = __get_weight_map() + weight_maps = __get_weight_map(weight_maps) - if rank == 0: - logger.info('tp paging adapters.') if len(weight_maps) > 0: __paging(weight_maps) @@ -610,21 +484,16 @@ def _tp_model_loop( world_size (int): The distribution world size. """ stream = torch.cuda.Stream() - patched_model, cache_engine, _ = _tp_build_model( - rank, - model_path, - model_config, - cache_config, - backend_config, - adapters, - world_size=world_size, - trust_remote_code=trust_remote_code) + patched_model, cache_engine, _, _ = _tp_build_model(rank, + model_path, + model_config, + cache_config, + backend_config, + adapters=adapters, + world_size=world_size) if adapters: - _tp_paging_adapters(rank, - patched_model, - cache_engine=cache_engine, - weight_map=None) + _tp_paging_adapters(rank, cache_engine=cache_engine, weight_maps=None) while True: inputs, swap_in_map, swap_out_map, exit_flag = _broadcast_inputs( @@ -758,18 +627,18 @@ def __signal_term_handler(sig, frame): world_size=world_size, trust_remote_code=trust_remote_code) - model, cache_engine, cache_config = self._build_model( + model, cache_engine, cache_config, target_infos = self._build_model( model_path=model_path, model_config=model_config, cache_config=cache_config, backend_config=backend_config, adapters=adapters, world_size=world_size, - trust_remote_code=trust_remote_code, ) self.patched_model = model self.cache_config = cache_config self.cache_engine = cache_engine + self._target_infos = target_infos self.stream = torch.cuda.Stream() def _start_sub_process(self, model_path: str, model_config: ModelConfig, @@ -831,12 +700,11 @@ def _build_model( backend_config: BackendConfig, adapters: Dict[str, str], world_size: int, - trust_remote_code=True, ): """build model.""" _check_context_alive(self.mp_context) rank = 0 - model, cache_engine, cache_config = _tp_build_model( + model, cache_engine, cache_config, target_infos = _tp_build_model( rank, model_path=model_path, model_config=model_config, @@ -844,14 +712,13 @@ def _build_model( backend_config=backend_config, adapters=adapters, world_size=world_size, - trust_remote_code=trust_remote_code, ) - return model, cache_engine, cache_config + return model, cache_engine, cache_config, target_infos - def get_loralinear_info(self): + def get_lora_target_info(self): """get lora linear info.""" - return _get_loralinear_info(self.patched_model) + return self._target_infos def get_block_numel(self): """get block nelement.""" @@ -864,8 +731,8 @@ def paging_adapters(self, weight_maps: List[AdapterWeightMap]): return _check_context_alive(self.mp_context) rank = 0 - _tp_paging_adapters(rank, self.patched_model, self.cache_engine, - weight_maps) + logger.info('paging adapters.') + _tp_paging_adapters(rank, self.cache_engine, weight_maps) def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): diff --git a/lmdeploy/pytorch/models/__init__.py b/lmdeploy/pytorch/models/__init__.py index fd697634a..2a69e4e71 100644 --- a/lmdeploy/pytorch/models/__init__.py +++ b/lmdeploy/pytorch/models/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .patch import patch from .q_modules import QLinear, QRMSNorm -__all__ = ['patch', 'QLinear', 'QRMSNorm'] +__all__ = ['QLinear', 'QRMSNorm'] diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 2c8383ec6..b7e4c7065 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -1,46 +1,51 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama import LlamaConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, RMSNorm, SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, - build_rowwise_linear) + build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) - class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: LlamaConfig, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() world_size = 1 if dist.is_initialized(): world_size = dist.get_world_size() - is_tp = world_size > 1 - self.ctx_mgr = ctx_mgr - self.num_heads = origin.num_heads // world_size - self.num_kv_heads = origin.num_key_value_heads // world_size - self.head_dim = origin.head_dim + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + self.num_heads = num_heads // world_size + self.num_kv_heads = num_key_value_heads // world_size + self.head_dim = head_dim # packed qkv - self.qkv_proj = build_merged_colwise_linear( - origin.q_proj, - origin.k_proj, - origin.v_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp, + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, ) - # free old weight - del origin.q_proj, origin.k_proj, origin.v_proj # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() @@ -54,30 +59,13 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): ) # o_proj - self.o_proj = build_rowwise_linear( - origin.o_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp, - ) - - @staticmethod - def _load_weights(mod, loader, rank: int, world_size: int, - device: torch.device): - """load weights, support TP.""" - # split weight of qkv proj. - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(mod, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - # split weight of o_proj - rowwise_parallelize_linear(mod.o_proj, - loader, - rank=rank, - world_size=world_size, - prefix='o_proj') + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) def forward( self, @@ -126,55 +114,39 @@ def forward( # o proj attn_output = self.o_proj(attn_output) - return attn_output class LlamaMLP(nn.Module): - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: LlamaConfig, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - is_tp = world_size > 1 + quantization_config = getattr(config, 'quantization_config', None) # gate up self.gate_up_proj = build_merged_colwise_linear( - origin.gate_proj, - origin.up_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp, + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=config.mlp_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, ) - # free old weight - del origin.gate_proj, origin.up_proj # silu and mul self.act_fn = SiluAndMul(inplace=True) # down - self.down_proj = build_rowwise_linear(origin.down_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp) - - @staticmethod - def _load_weights(mod: nn.Module, loader, rank: int, world_size: int, - device: torch.device): - """load weights, support TP.""" - - # split weight of gate_proj and up_proj - for mod_name in ['gate_proj', 'up_proj']: - colwise_parallelize_linear(getattr(mod, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - # split weight of down_proj - rowwise_parallelize_linear(mod.down_proj, - loader, - rank=rank, - world_size=world_size, - prefix='down_proj') + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=config.mlp_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) def forward(self, x): """forward.""" @@ -185,41 +157,35 @@ def forward(self, x): class LlamaDecoderLayer(nn.Module): - def __init__(self, origin: nn.Module, layer_idx: int, - ctx_mgr: StepContextManager): + def __init__(self, + config: LlamaConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) # build attention layer - self.self_attn = LlamaAttention(origin.self_attn, ctx_mgr) + self.self_attn = LlamaAttention(config, dtype=dtype, device=device) # builf MLP - self.mlp = LlamaMLP(origin.mlp, ctx_mgr) + self.mlp = LlamaMLP(config, dtype=dtype, device=device) # build input layer norm - input_layernorm = origin.input_layernorm - is_w8a8 = hasattr(input_layernorm, 'from_float') - self.input_layernorm = RMSNorm( - input_layernorm.weight.size(0), - input_layernorm.variance_epsilon, - dtype=input_layernorm.weight.dtype, - device=input_layernorm.weight.device, - is_w8a8=is_w8a8, - ) - load_weight(self.input_layernorm.weight, input_layernorm.weight) + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) # build attention layer norm - post_attention_layernorm = origin.post_attention_layernorm - is_w8a8 = hasattr(post_attention_layernorm, 'from_float') self.post_attention_layernorm = RMSNorm( - post_attention_layernorm.weight.size(0), - post_attention_layernorm.variance_epsilon, - dtype=post_attention_layernorm.weight.dtype, - device=post_attention_layernorm.weight.device, - is_w8a8=is_w8a8, - ) - load_weight(self.post_attention_layernorm.weight, - post_attention_layernorm.weight) + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, @@ -256,48 +222,56 @@ def forward( class LlamaModel(nn.Module): - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: LlamaConfig, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() - self.ctx_mgr = ctx_mgr - self.embed_tokens = origin.embed_tokens + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) # build all decode layers self.layers = nn.ModuleList([ - LlamaDecoderLayer(layer, idx, ctx_mgr) - for idx, layer in enumerate(origin.layers) + LlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) ]) # build norm - norm = origin.norm - is_w8a8 = hasattr(norm, 'from_float') - self.norm = RMSNorm(norm.weight.size(0), - norm.variance_epsilon, - dtype=norm.weight.dtype, - device=norm.weight.device, - is_w8a8=is_w8a8) - load_weight(self.norm.weight, norm.weight) + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) # build rotary embedding in LlamaModel - rotary_emb = origin.layers[0].self_attn.rotary_emb - rotary_name = type(rotary_emb).__name__ - if rotary_name in [ - 'LlamaRotaryEmbedding', 'LlamaLinearScalingRotaryEmbedding' - ]: + rope_scaling = config.rope_scaling + if rope_scaling is None: emb_type = EmbeddingType.LinearScaling - elif rotary_name == 'LlamaDynamicNTKScalingRotaryEmbedding': - emb_type = EmbeddingType.DynamicNTKScaling - config = origin.config + else: + rope_type = rope_scaling['rope_type'] + if rope_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + elif rope_type == 'llama3': + emb_type = EmbeddingType.Llama3 + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta scaling_factor = 1.0 low_freq_factor = 1.0 high_freq_factor = 4.0 - if config.rope_scaling is not None: - rope_scaling = config.rope_scaling + if rope_scaling is not None: scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) - if rope_scaling['rope_type'] == 'llama3': - emb_type = EmbeddingType.Llama3 + if emb_type == EmbeddingType.Llama3: low_freq_factor = rope_scaling.get('low_freq_factor', 1.0) high_freq_factor = rope_scaling.get('high_freq_factor', 1.0) self.rotary_emb = build_rotary_embedding( @@ -358,13 +332,34 @@ class LlamaForCausalLM(nn.Module): support_cuda_graph = True - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: LlamaConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() + self.config = config self.ctx_mgr = ctx_mgr # build LLamaModel - self.model = LlamaModel(origin.model, ctx_mgr) + self.model = LlamaModel(config, dtype=dtype, device=device) # build lm_head - self.lm_head = build_rowwise_linear(origin.lm_head) + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, @@ -422,3 +417,35 @@ def prepare_inputs_for_generation( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index ece6411b7..78a776cdd 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -9,44 +9,48 @@ from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, RMSNorm, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_colwise_linear, - build_merged_colwise_linear, - build_rowwise_linear) -from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_moe_from_mlp +from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) - class MixtralAttention(nn.Module): """mixtral attention.""" - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() world_size = 1 if dist.is_initialized(): world_size = dist.get_world_size() - is_tp = world_size > 1 - self.ctx_mgr = ctx_mgr - self.num_heads = origin.num_heads // world_size - self.num_kv_heads = origin.num_key_value_heads // world_size - self.head_dim = origin.head_dim + quantization_config = None + + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = hidden_size // num_heads + self.num_heads = num_heads // world_size + self.num_kv_heads = num_key_value_heads // world_size + self.head_dim = head_dim # qkv - self.qkv_proj = build_merged_colwise_linear( - origin.q_proj, - origin.k_proj, - origin.v_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp, + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, ) - del origin.q_proj, origin.k_proj, origin.v_proj self.apply_rotary_pos_emb = ApplyRotaryEmb() # attention - self.window_size = origin.config.sliding_window or -1 + self.window_size = config.sliding_window or -1 self.attn_fwd = Attention( self.num_heads, self.head_dim, @@ -55,27 +59,14 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): sliding_window=self.window_size, ) - self.o_proj = build_rowwise_linear( - origin.o_proj, - ctx_mgr=ctx_mgr, - is_tp=is_tp, - ) - - @staticmethod - def _load_weights(mod, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(mod, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(mod.o_proj, - loader, - rank=rank, - world_size=world_size, - prefix='o_proj') + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) def forward( self, @@ -123,50 +114,46 @@ def forward( return attn_output -class MixtralBLockSparseTop2MLP(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['w1', 'w3']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(self.w2, - loader, - rank=rank, - world_size=world_size, - prefix='w2') - - class MixtralSparseMoeBlock(nn.Module): """mixtral sparse moe block.""" - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() world_size = 1 if dist.is_initialized(): world_size = dist.get_world_size() is_tp = world_size > 1 self.is_tp = is_tp - self.top_k = origin.top_k - self.gate = build_colwise_linear( - origin.gate, - ctx_mgr=ctx_mgr, - is_tp=is_tp, + + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, ) + self.softmax_topk = SoftmaxTopK(self.top_k) - gates = [exp.w1 for exp in origin.experts] - ups = [exp.w3 for exp in origin.experts] - downs = [exp.w2 for exp in origin.experts] - self.fused_moe = build_moe_from_mlp(gates, - ups, - downs, - top_k=self.top_k, - renormalize=True) + self.experts = FusedMoE( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=True, + dtype=dtype, + device=device, + is_tp=False, + ) def forward(self, hidden_states: torch.Tensor): """forward.""" @@ -175,7 +162,7 @@ def forward(self, hidden_states: torch.Tensor): router_logits = self.gate(hidden_states) topk_weights, topk_ids = self.softmax_topk(router_logits) - out_states = self.fused_moe( + out_states = self.experts( hidden_states, topk_weights, topk_ids, @@ -191,36 +178,35 @@ def forward(self, hidden_states: torch.Tensor): class MixtralDecoderLayer(nn.Module): """mixtral decoder layer.""" - def __init__(self, origin: nn.Module, layer_idx: int, - ctx_mgr: StepContextManager): + def __init__(self, + config: Any, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() self.layer_idx = layer_idx - self.self_attn = MixtralAttention(origin.self_attn, ctx_mgr) - self.block_sparse_moe = MixtralSparseMoeBlock(origin.block_sparse_moe, - ctx_mgr) - - # norm - input_layernorm = origin.input_layernorm - is_w8a8 = hasattr(input_layernorm, 'from_float') - self.input_layernorm = RMSNorm( - input_layernorm.weight.size(0), - input_layernorm.variance_epsilon, - dtype=input_layernorm.weight.dtype, - device=input_layernorm.weight.device, - is_w8a8=is_w8a8, - ) - load_weight(self.input_layernorm.weight, input_layernorm.weight) - post_attention_layernorm = origin.post_attention_layernorm - is_w8a8 = hasattr(post_attention_layernorm, 'from_float') + quantization_config = None + + # build attention layer + self.self_attn = MixtralAttention(config, dtype=dtype, device=device) + self.block_sparse_moe = MixtralSparseMoeBlock(config, + dtype=dtype, + device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm self.post_attention_layernorm = RMSNorm( - post_attention_layernorm.weight.size(0), - post_attention_layernorm.variance_epsilon, - dtype=post_attention_layernorm.weight.dtype, - device=post_attention_layernorm.weight.device, - is_w8a8=is_w8a8, - ) - load_weight(self.post_attention_layernorm.weight, - post_attention_layernorm.weight) + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, @@ -258,25 +244,31 @@ def forward( class MixtralModel(nn.Module): """mixtral model.""" - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() - self.ctx_mgr = ctx_mgr - self.embed_tokens = origin.embed_tokens + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) self.layers = nn.ModuleList([ - MixtralDecoderLayer(layer, idx, ctx_mgr) - for idx, layer in enumerate(origin.layers) + MixtralDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) ]) - norm = origin.norm - is_w8a8 = hasattr(norm, 'from_float') - self.norm = RMSNorm(norm.weight.size(0), - norm.variance_epsilon, - dtype=norm.weight.dtype, - device=norm.weight.device, - is_w8a8=is_w8a8) - load_weight(self.norm.weight, norm.weight) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=None, + dtype=dtype, + device=device) emb_type = EmbeddingType.LinearScaling - config = origin.config rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta @@ -286,7 +278,7 @@ def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): rope_max_pos_emb, rope_base, scaling_factor, - emb_type, + emb_type=emb_type, ) def forward( @@ -331,11 +323,21 @@ class MixtralForCausalLM(nn.Module): support_cuda_graph = True - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: Any, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() + self.config = config self.ctx_mgr = ctx_mgr - self.model = MixtralModel(origin.model, ctx_mgr) - self.lm_head = build_rowwise_linear(origin.lm_head) + self.model = MixtralModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, @@ -380,3 +382,57 @@ def prepare_inputs_for_generation( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + num_experts = self.config.num_local_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.w1.weight', exp_id, 'gate') + up_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.w3.weight', exp_id, 'up') + down_param = ('.experts.down_weights', + f'.experts.{exp_id}.w2.weight', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, + loaded_weight, + expert_id=expert_id, + shard_id=shard_id) + break + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index ef1f85420..169974b7e 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -4,20 +4,7 @@ # llama MODULE_MAP = { - 'transformers.models.llama.modeling_llama.LlamaFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - # support modeling rewritten in lmdeploy - 'modeling_llama.LlamaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttention', - 'modeling_llama.LlamaMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'modeling_llama.LlamaForCausalLM': + 'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', } @@ -224,17 +211,7 @@ # mixtral MODULE_MAP.update({ - 'transformers.models.mixtral.modeling_mixtral.MixtralAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', - 'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', - 'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralAttention', - 'transformers.models.mixtral.modeling_mixtral.MixtralBLockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralBLockSparseTop2MLP', - 'transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralBLockSparseTop2MLP', - 'transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM': + 'MixtralForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM', }) diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index ed0284c66..dfe3660c9 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -4,13 +4,13 @@ import os.path as osp import re import sys -from copy import copy -from typing import Any, Dict +from typing import Any, Dict, List import torch from lmdeploy.utils import get_logger +from ..config import ModelConfig from ..devices import get_device_manager from .module_map import (CUSTOM_MODULE_MAP, DEVICE_SPECIAL_MODULE_MAP, MODULE_MAP) @@ -91,29 +91,6 @@ def _find_submodulename(): return rewrite_qualname -def _update_module_type(model: Any, cls_type: type, custom_attrs: dict = None): - """Update class type of model.""" - # directly return origin model is not cool - # origin model would be registered as a submodule - old_type = type(model) - - @property - def get_origin_mod(self): - origin_mod = copy(self) - origin_mod.__class__ = old_type - return origin_mod - - attrs = dict(cls_type.__dict__) - custom_attrs = custom_attrs or dict() - custom_attrs['origin_mod'] = get_origin_mod - attrs.update(custom_attrs) - new_type = type(cls_type.__name__, (type(model), ), attrs) - model = copy(model) - model.__class__ = new_type - - return model - - def get_rewrite_cls(model: torch.nn.Module, module_map: Dict[str, str] = None): """get rewrite cls.""" if module_map is None: @@ -125,34 +102,6 @@ def get_rewrite_cls(model: torch.nn.Module, module_map: Dict[str, str] = None): return _class_from_qualname(rewrite_qualname) -def _patch(model: torch.nn.Module, module_map: Dict[str, - str]) -> torch.nn.Module: - """patch the model with rewrite module. - - Args: - model (Module): model to be patched. - - Returns: - Module: The patched model - """ - - def _recursive_children(named_children): - """recursive children.""" - for name, child in named_children: - _patch(child, module_map=module_map) - - _recursive_children(model.named_children()) - rewrite_qualname = _find_rewrite_module_qualname(model, - module_map=module_map) - - if rewrite_qualname is not None: - cls_type = _class_from_qualname(rewrite_qualname) - if hasattr(cls_type, '_load_weights'): - setattr(model, '_load_weights', cls_type._load_weights) - - return model - - def _get_module_map(): """get module map.""" module_map = MODULE_MAP.copy() @@ -165,39 +114,6 @@ def _get_module_map(): return module_map -@torch.inference_mode() -def patch(model: torch.nn.Module, ): - """Patch the model with rewrite modules. - - Extra arguments will be patched in forward of model, weights on each rank - will be partitioned. - - Args: - model (Module): Model to be patched. - - Returns: - Module: The patched model. - """ - module_map = _get_module_map() - model = _patch(model, module_map=module_map) - return model - - -def update_model(model: torch.nn.Module): - """build model.""" - from lmdeploy.pytorch.model_inputs import StepContextManager - ctx_mgr = StepContextManager() - module_map = _get_module_map() - - rewrite_qualname = _find_rewrite_module_qualname(model, - module_map=module_map) - - if rewrite_qualname is not None: - model_cls = _class_from_qualname(rewrite_qualname) - - return model_cls(model, ctx_mgr) - - def update_custom_module_map(module_map_path: str): """moad custom module map from file.""" from importlib.machinery import SourceFileLoader @@ -235,3 +151,137 @@ def update_custom_module_map(module_map_path: str): new_mod_map[k] = v CUSTOM_MODULE_MAP.update(new_mod_map) + + +def _get_model_class(config, module_map): + architectures = getattr(config, 'architectures', []) + for arch in architectures: + if arch in module_map: + qualname = module_map[arch] + module_cls = _class_from_qualname(qualname) + return module_cls + + raise RuntimeError( + f'Can not found rewrite for architectures: {architectures}') + + +@torch.inference_mode() +def build_patched_model(config: ModelConfig, device: torch.device = None): + """build patched model.""" + from lmdeploy.pytorch.model_inputs import StepContextManager + ctx_mgr = StepContextManager() + module_map = _get_module_map() + model_config = config.hf_config + if device is None: + device = torch.device('cuda') + model_cls = _get_model_class(model_config, module_map) + model = model_cls(model_config, ctx_mgr, dtype=config.dtype, device=device) + return model.eval() + + +@torch.inference_mode() +def add_adapters(model: torch.nn.Module, + kv_caches: List[List[torch.Tensor]], + adapters: Dict[str, str], + device: torch.device = None): + """add adapters.""" + from peft import PeftConfig + from peft.tuners.lora import LoraConfig + + from lmdeploy.pytorch.adapter.adapter import (LoRATargetInfo, + find_all_target, + get_layer_index, + get_ranks_and_scalings) + from lmdeploy.pytorch.nn.linear import SLoRA + num_adapters = len(adapters) + if num_adapters == 0: + return + + if device is None: + device = torch.device('cuda') + + # model could be graph runner + origin_model = model + if hasattr(model, 'get_model'): + model = model.get_model() + ctx_mgr = model.ctx_mgr + + adapter_cfgs = [ + PeftConfig.from_pretrained(path) for path in adapters.values() + ] + # get layer pattern (should be same between different adapter) + config = next(iter(adapter_cfgs)) + layers_pattern = getattr(config, 'layers_pattern', None) + + # insert one for no adapter + adapter_cfgs = [LoraConfig(r=0, target_modules=[])] + adapter_cfgs + + # target layer name to add adapter + target_names = set() + max_rank = 0 + for cfg in adapter_cfgs: + target_names = target_names.union(cfg.target_modules) + max_rank = max(max_rank, cfg.r) + target_names = list(target_names) + target_names = sorted(target_names) + num_targets = len(target_names) + + # get rank offsets + # add 1 for none adapter + rank_offsets = torch.zeros(num_adapters + 1, + num_targets * max_rank, + dtype=torch.int64, + device=device) + + target_infos = dict() + for target_idx, target_name in enumerate(target_names): + # get ranks and scalings + ranks, scalings = get_ranks_and_scalings(target_name, + adapter_cfgs, + device=device) + found_mods, pack_idx = find_all_target(model, target_name) + r_start = target_idx * max_rank + r_end = r_start + max_rank + r_offs = rank_offsets[:, r_start:r_end] + + in_features = 0 + out_features = 0 + colwise = True + for name, mod in found_mods: + assert hasattr(mod, 'lora_adapters') + layer_idx = get_layer_index(name, layers_pattern) + k_cache, v_cache = kv_caches[layer_idx] + in_features = mod.in_features + colwise = mod.colwise + if pack_idx is None: + base_slice = slice(0, mod.out_features) + out_features = mod.out_features + else: + prev_feats = sum(mod.all_out_features[:pack_idx]) + out_features = mod.all_out_features[pack_idx] + base_slice = slice(prev_feats, prev_feats + out_features) + + slora = SLoRA( + in_features, + out_features, + ranks=ranks, + scalings=scalings, + rank_offsets=r_offs, + a_cache=k_cache, + b_cache=v_cache, + base_slice=base_slice, + max_rank=max_rank, + ctx_mgr=ctx_mgr, + colwise=colwise, + is_tp=mod.is_tp, + ) + mod.lora_adapters.append(slora) + + target_info = LoRATargetInfo(in_features=in_features, + out_features=out_features, + colwise=colwise) + target_infos[target_name] = target_info + + # add rank_offsets + setattr(origin_model, 'rank_offsets', rank_offsets) + return target_infos diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index abd041166..c7df46eb2 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List +from typing import Any, List, Optional import torch import torch.distributed as dist from torch import nn -from lmdeploy.pytorch.models.q_modules import QLinear +from lmdeploy.pytorch.weight_loader.model_weight_loader import \ + default_weight_loader from lmdeploy.utils import get_logger from ..backends import LayerType, get_backend @@ -13,32 +14,6 @@ logger = get_logger('lmdeploy') -try: - from peft.tuners.lora import Linear as LoRALinear -except ImportError: - logger.debug('load peft.tuners.lora.Linear failed.') - - class LoRALinear: - pass - - -try: - from peft.tuners.lora.awq import AwqLoraLinear -except ImportError: - logger.debug('load peft.tuners.lora.awq.AwqLoraLinear failed.') - - class AwqLoraLinear: - pass - - -try: - from awq.modules.linear.gemm import WQLinear_GEMM -except ImportError: - logger.debug('load awq.modules.linear.gemm.WQLinearGEMM failed.') - - class WQLinear_GEMM: - pass - def _get_world_rank(): """get current world size and rank.""" @@ -52,90 +27,254 @@ def _get_world_rank(): return world_size, rank +def _div_up(a: int, b: int): + """div up.""" + return (a + b - 1) // b + + +def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): + """chunk aligned.""" + if align == 1: + return weight.chunk(chunks, dim=dim) + size = weight.size(dim) + assert size % align == 0 + aligned_size = size // align + align_per_chunk = _div_up(aligned_size, chunks) + sections = [align_per_chunk] * (chunks - 1) + sections += [aligned_size - align_per_chunk * (chunks - 1)] + sections = [sec * align for sec in sections] + return weight.split(sections, dim=dim) + + +def _update_feature_size(feature_size: int, + world_size: int, + rank: int, + align: int = 1): + """update feature size.""" + assert feature_size % align == 0 + aligned_size = feature_size // align + align_per_rank = _div_up(aligned_size, world_size) + prev_feats = align_per_rank * rank + updated_aligned_size = min(align_per_rank, aligned_size - prev_feats) + return updated_aligned_size * align + + class SLoRA(nn.Module): """SLoRA layer.""" def __init__(self, - adapter_info: AdapterInfo, + in_features: int, + out_features: int, + ranks: torch.Tensor, + scalings: torch.Tensor, + rank_offsets: torch.Tensor, + a_cache: torch.Tensor, + b_cache: torch.Tensor, + base_slice: slice, + max_rank: int, ctx_mgr: Any = None, colwise: bool = True, is_tp: bool = True): super().__init__() - self.adapter_info = adapter_info + self.adapter_info = AdapterInfo( + in_features=in_features, + out_features=out_features, + ranks=ranks, + scalings=scalings, + rank_offsets=rank_offsets, + a_cache=a_cache, + b_cache=b_cache, + base_slice=base_slice, + max_rank=max_rank, + ) impl_builder = get_backend().get_layer_impl_builder(LayerType.SLoRA) - self.impl = impl_builder.build(adapter_info, ctx_mgr, colwise=colwise) - self.target_name = None - self.layer_idx = None + self.impl = impl_builder.build() self.is_tp = is_tp - - def post_init( - self, - ranks: torch.Tensor, - scalings: torch.Tensor, - rank_offsets: torch.Tensor, - a_cache: torch.Tensor, - b_cache: torch.Tensor, - max_rank: int, - ): - """post init.""" - self.impl.post_init( - ranks, - scalings, - rank_offsets, - a_cache, - b_cache, - max_rank, - ) + self.ctx_mgr = ctx_mgr + self.colwise = colwise def forward(self, x, base_output=None): """forward of loraA@loraB.""" - return self.impl.forward(x, base_output, self.target_name, - self.layer_idx, self.is_tp) + return self.impl.forward(x, + base_output, + self.adapter_info, + ctx_mgr=self.ctx_mgr, + colwise=self.colwise, + is_tp=self.is_tp) class AwqLinear(nn.Module): """w4a16 linear.""" def __init__(self, - mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, + in_features: int, + out_features: int, + w_bit: int, + group_size: int, + bias: bool, + device: Optional[torch.device] = None, colwise: bool = True, is_tp: bool = False): super().__init__() + if device is None: + device = torch.device('cpu') + dtype = torch.float16 + if is_tp: + in_features, out_features = self._get_io_features( + in_features, out_features, w_bit, group_size, colwise) + qweight, scales, qzeros, bias = self.create_weights( + in_features, out_features, w_bit, group_size, bias, dtype, device) impl_builder = get_backend().get_layer_impl_builder( LayerType.LinearW4A16) - qweight = mod.qweight - scales = mod.scales - qzeros = mod.qzeros - bias = mod.bias - self.impl = impl_builder.build(mod.in_features, - mod.out_features, - mod.w_bit, - mod.group_size, + self.impl = impl_builder.build(in_features, + out_features, + w_bit, + group_size, bias is not None, dtype=scales.dtype) - qweight, scales, qzeros, bias = self.impl.update_weights( - qweight, scales, qzeros, bias) qweight = torch.nn.Parameter(qweight, requires_grad=False) + qweight.weight_loader = self.weight_loader + qweight._weight_type = 'qweight' scales = torch.nn.Parameter(scales, requires_grad=False) + scales.weight_loader = self.weight_loader + scales._weight_type = 'scales' qzeros = torch.nn.Parameter(qzeros, requires_grad=False) + qzeros.weight_loader = self.weight_loader + qzeros._weight_type = 'qzeros' if bias is not None: bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader + bias._weight_type = 'bias' self.register_parameter('qweight', qweight) self.register_parameter('scales', scales) self.register_parameter('qzeros', qzeros) self.register_parameter('bias', bias) - adapter_infos = adapter_infos if adapter_infos is not None else [] - self.lora_adapters = None - if len(adapter_infos) > 0: - self.lora_adapters = nn.ModuleList( - SLoRA(info, ctx_mgr, colwise, is_tp) for info in adapter_infos) - + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size + self.elem_per_int = 32 // self.w_bit + self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + def _get_io_features(self, in_features: int, out_features: int, w_bit: int, + group_size: int, colwise: bool): + """get io features.""" + align = max(32 // w_bit, group_size) + world_size, rank = _get_world_rank() + if colwise: + out_features = _update_feature_size(out_features, + world_size, + rank, + align=align) + else: + in_features = _update_feature_size(in_features, + world_size, + rank, + align=align) + return in_features, out_features + + def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for colwise linear.""" + if loaded_weight.dim() == 1: + # bias + align = max(self.elem_per_int, self.group_size) + weight = _chunk_align(loaded_weight, world_size, 0, align)[rank] + return default_weight_loader(param, weight) + + if loaded_weight.size(1) == self.out_features: + # scaling + align = max(self.elem_per_int, self.group_size) + weight = _chunk_align(loaded_weight, world_size, 1, align)[rank] + return default_weight_loader(param, weight) + + align = max(self.elem_per_int, self.group_size) // self.elem_per_int + weight = _chunk_align(loaded_weight, world_size, 1, align)[rank] + return default_weight_loader(param, weight) + + def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for rowwise linear.""" + if loaded_weight.dim() == 1: + # bias + return default_weight_loader(param, loaded_weight) + + if loaded_weight.size(0) == self.in_features: + # qweight + align = max(self.elem_per_int, self.group_size) + weight = _chunk_align(loaded_weight, world_size, 0, align)[rank] + return default_weight_loader(param, weight) + + align = max(self.elem_per_int, self.group_size) // self.group_size + weight = _chunk_align(loaded_weight, world_size, 0, align)[rank] + return default_weight_loader(param, weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """weight loader.""" + if not self.is_tp: + return default_weight_loader(param, loaded_weight) + + world_size, rank = _get_world_rank() + if self.colwise: + return self._weight_loader_tp_colwise(param, loaded_weight, rank, + world_size) + else: + return self._weight_loader_tp_rowwise(param, loaded_weight, rank, + world_size) + + def create_weights(self, in_features: int, out_features: int, w_bit: int, + group_size: int, bias: bool, dtype: torch.dtype, + device: torch.device): + """create weights.""" + assert in_features % group_size == 0 + elem_per_int = 32 // w_bit + assert out_features % elem_per_int == 0 + + grouped_in_feats = in_features // group_size + quant_out_feats = out_features // elem_per_int + qweight = torch.empty((in_features, quant_out_feats), + dtype=torch.int32, + device=device) + scales = torch.empty((grouped_in_feats, out_features), + dtype=dtype, + device=device) + qzeros = torch.empty((grouped_in_feats, quant_out_feats), + dtype=torch.int32, + device=device) + if bias: + bias = torch.empty((out_features, ), dtype=dtype, device=device) + else: + bias = None + return qweight, scales, qzeros, bias + + def update_weights(self): + """update weights.""" + qweight, scales, qzeros, bias = self.impl.update_weights( + self.qweight, self.scales, self.qzeros, self.bias) + qweight = torch.nn.Parameter(qweight, requires_grad=False) + qweight.weight_loader = self.weight_loader + qweight._weight_type = 'qweight' + scales = torch.nn.Parameter(scales, requires_grad=False) + scales.weight_loader = self.weight_loader + scales._weight_type = 'scales' + qzeros = torch.nn.Parameter(qzeros, requires_grad=False) + qzeros.weight_loader = self.weight_loader + qzeros._weight_type = 'qzeros' + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader + bias._weight_type = 'bias' + self.register_parameter('qweight', qweight) + self.register_parameter('scales', scales) + self.register_parameter('qzeros', qzeros) + self.register_parameter('bias', bias) + def forward(self, x): """w4a16 forward.""" is_tp = False if self.colwise else self.is_tp @@ -153,314 +292,864 @@ def forward(self, x): return out +class MergedAwqLinear(AwqLinear): + """merged awq linear.""" + + def __init__(self, + in_features: int, + all_out_features: List[int], + w_bit: int, + group_size: int, + bias: bool, + replicate: List[bool] = None, + device: Optional[torch.device] = None, + is_tp: bool = True, + out_names: List[int] = None): + if replicate is None: + replicate = tuple(False for _ in all_out_features) + all_out_features = self._update_all_out_features( + all_out_features, w_bit, group_size, replicate) + self.all_out_features = all_out_features + self.replicate = replicate + if out_names is None: + out_names = torch.arange(len(self.all_out_features)).tolist() + assert len(out_names) == len(self.all_out_features) + self.out_names_map = dict( + (name, idx) for idx, name in enumerate(out_names)) + out_features = sum(all_out_features) + super().__init__(in_features, + out_features, + w_bit, + group_size, + bias, + device, + colwise=True, + is_tp=is_tp) + self.qweight.weight_loader = self.weight_loader + self.qweight._weight_type = 'qweight' + self.scales.weight_loader = self.weight_loader + self.scales._weight_type = 'scales' + self.qzeros.weight_loader = self.weight_loader + self.qzeros._weight_type = 'qzeros' + if self.bias is not None: + self.bias.weight_loader = self.weight_loader + self.bias._weight_type = 'bias' + + def _get_io_features(self, in_features: int, out_features: int, w_bit: int, + group_size: int, colwise: bool): + """get io features.""" + return in_features, out_features + + def _update_all_out_features(self, all_out_features: List[int], w_bit: int, + group_size: int, replicate: List[bool]): + """update all out features.""" + world_size, rank = _get_world_rank() + new_all_out_features = [] + align = max(32 // w_bit, group_size) + for out_feat, rep in zip(all_out_features, replicate): + if rep: + new_all_out_features.append(out_feat) + new_out_feat = _update_feature_size(out_feat, world_size, rank, + align) + new_all_out_features.append(new_out_feat) + return new_all_out_features + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = _get_world_rank() + shard_idx = self.out_names_map[shard_id] + + if loaded_weight.dim() == 1: + # bias + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + weight = _chunk_align(loaded_weight, world_size, 0, + align)[rank] + param_w.copy_(weight) + + if param._weight_type == 'scales': + # scales + align = max(self.elem_per_int, self.group_size) + param_w = param.data.split(self.all_out_features, 1)[shard_idx] + else: + # qweight or qzeros + align = max(self.elem_per_int, + self.group_size) // self.elem_per_int + quanted_out_feats = [ + feat // self.elem_per_int for feat in self.all_out_features + ] + param_w = param.data.split(quanted_out_feats, 1)[shard_idx] + + if not self.replicate[shard_idx]: + weight = _chunk_align(loaded_weight, world_size, 1, align)[rank] + param_w.copy_(weight) + + +class QKVAwqLinear(MergedAwqLinear): + """qkv awq linear.""" + + def __init__(self, + in_features: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int, + w_bit: int, + group_size: int, + replicate_kv: bool = False, + bias: bool = False, + device: Optional[torch.device] = None, + is_tp: bool = True): + if device is None: + device = torch.device('cpus') + num_q_heads, num_kv_heads = self._update_num_heads( + num_q_heads, num_kv_heads, replicate_kv) + all_out_features = self._get_qkv_out_features(num_q_heads, + num_kv_heads, head_size, + head_size_v) + replicate = (False, replicate_kv, replicate_kv) + out_names = ('q', 'k', 'v') + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.head_size_v = head_size_v + super().__init__(in_features, + all_out_features, + w_bit=w_bit, + group_size=group_size, + bias=bias, + replicate=replicate, + device=device, + is_tp=is_tp, + out_names=out_names) + + def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, + head_size: int, head_size_v: int): + """get io features.""" + all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, + num_kv_heads * head_size_v) + return all_out_features + + def _update_all_out_features(self, all_out_features: List[int], w_bit: int, + group_size: int, replicate: List[bool]): + """update all out features.""" + return all_out_features + + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, + replicate_kv: bool): + """update num heads.""" + world_size, rank = _get_world_rank() + num_q_heads = _update_feature_size(num_q_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) + + return num_q_heads, num_kv_heads + + class W8A8Linear(nn.Module): """w8a8 linear.""" def __init__(self, - mod: nn.Module, - ctx_mgr: Any = None, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, colwise: bool = True, is_tp: bool = False): super().__init__() + if device is None: + device = torch.device('cpu') + if dtype is None: + dtype = torch.float16 + if is_tp: + in_features, out_features = self._get_io_features( + in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder( LayerType.LinearW8A8) - weight = mod.weight - scale = mod.scale - bias = mod.bias - - self.impl = impl_builder.build(mod.in_features, - mod.out_features, + self.impl = impl_builder.build(in_features, + out_features, bias is not None, - dtype=torch.float16) - weight, scale, bias = self.impl.update_weights(weight, scale, bias) + dtype=dtype) + weight, scale, bias = self.create_weights(in_features, out_features, + bias, dtype, device) weight = torch.nn.Parameter(weight, requires_grad=False) + weight.weight_loader = self.weight_loader scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = self.weight_loader if bias is not None: bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader self.register_parameter('weight', weight) self.register_parameter('scale', scale) self.register_parameter('bias', bias) + + self.in_features = in_features + self.out_features = out_features + self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + world_size, rank = _get_world_rank() + if colwise: + out_features = _update_feature_size(out_features, world_size, rank) + else: + in_features = _update_feature_size(in_features, world_size, rank) + return in_features, out_features + + def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for colwise linear.""" + weight = loaded_weight.chunk(world_size, 0)[rank] + return default_weight_loader(param, weight) + + def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for rowwise linear.""" + if loaded_weight.dim() == 2 and param.dtype == torch.int8: + weight = loaded_weight.chunk(world_size, 1)[rank] + return default_weight_loader(param, weight) + else: + # bias + return default_weight_loader(param, loaded_weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """weight loader.""" + if not self.is_tp: + return default_weight_loader(param, loaded_weight) + + world_size, rank = _get_world_rank() + if self.colwise: + return self._weight_loader_tp_colwise(param, loaded_weight, rank, + world_size) + else: + return self._weight_loader_tp_rowwise(param, loaded_weight, rank, + world_size) + + def create_weights(self, in_features: int, out_features: int, bias: bool, + dtype: torch.dtype, device: torch.device): + """create weights.""" + weight = torch.empty((out_features, in_features), + dtype=torch.int8, + device=device) + scale = torch.empty((out_features, 1), + dtype=torch.float32, + device=device) + if bias: + bias = torch.empty((out_features, ), dtype=dtype, device=device) + else: + bias = None + return weight, scale, bias + + def update_weights(self): + """update weights.""" + weight, scale, bias = self.impl.update_weights(self.weight, self.scale, + self.bias) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight.weight_loader = self.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + self.scale.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + self.bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) + def forward(self, x): - """forward of w8a8.""" + """forward of w8a8 linear.""" is_tp = False if self.colwise else self.is_tp - return self.impl.forward(x, self.weight, self.scale, self.bias, is_tp) + if len(self.lora_adapters) == 0: + return self.impl.forward(x, self.weight, self.scale, self.bias, + is_tp) + + out = self.impl.forward(x, self.weight, self.scale, self.bias, False) + for lora_adapter in self.lora_adapters: + out = lora_adapter(x, out) + if is_tp: + dist.all_reduce(out) + return out + + +class MergedW8A8Linear(W8A8Linear): + """merged w8a8 linear.""" + + def __init__(self, + in_features: int, + all_out_features: List[int], + bias: bool, + replicate: List[bool] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True, + out_names: List[int] = None): + if replicate is None: + replicate = tuple(False for _ in all_out_features) + all_out_features = self._update_all_out_features( + all_out_features, replicate) + self.all_out_features = all_out_features + self.replicate = replicate + if out_names is None: + out_names = torch.arange(len(self.all_out_features)).tolist() + assert len(out_names) == len(self.all_out_features) + self.out_names_map = dict( + (name, idx) for idx, name in enumerate(out_names)) + out_features = sum(all_out_features) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + colwise=True, + is_tp=is_tp) + self.weight.weight_loader = self.weight_loader + self.scale.weight_loader = self.weight_loader + if self.bias is not None: + self.bias.weight_loader = self.weight_loader + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + return in_features, out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: List[bool]): + """update all out features.""" + world_size, rank = _get_world_rank() + new_all_out_features = [] + for out_feat, rep in zip(all_out_features, replicate): + if rep: + new_all_out_features.append(out_feat) + new_out_feat = _update_feature_size(out_feat, world_size, rank) + new_all_out_features.append(new_out_feat) + return new_all_out_features + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = _get_world_rank() + shard_idx = self.out_names_map[shard_id] + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + param_w.copy_(loaded_weight) + + +class QKVW8A8Linear(MergedW8A8Linear): + """qkv w8a8 linear.""" + + def __init__(self, + in_features: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int, + replicate_kv: bool = False, + bias: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True): + if device is None: + device = torch.device('cpus') + if dtype is None: + dtype = torch.float16 + num_q_heads, num_kv_heads = self._update_num_heads( + num_q_heads, num_kv_heads, replicate_kv) + all_out_features = self._get_qkv_out_features(num_q_heads, + num_kv_heads, head_size, + head_size_v) + replicate = (False, replicate_kv, replicate_kv) + out_names = ('q', 'k', 'v') + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.head_size_v = head_size_v + super().__init__(in_features, + all_out_features, + bias=bias, + replicate=replicate, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names) + + def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, + head_size: int, head_size_v: int): + """get io features.""" + all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, + num_kv_heads * head_size_v) + return all_out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: List[bool]): + """update all out features.""" + return all_out_features + + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, + replicate_kv: bool): + """update num heads.""" + world_size, rank = _get_world_rank() + num_q_heads = _update_feature_size(num_q_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) + + return num_q_heads, num_kv_heads class BaseLinear(nn.Module): """linear layer.""" def __init__(self, - mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, colwise: bool = True, is_tp: bool = False): super().__init__() + if device is None: + device = torch.device('cpu') + if dtype is None: + dtype = torch.float16 + if is_tp: + in_features, out_features = self._get_io_features( + in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(LayerType.Linear) - weight = mod.weight - bias = mod.bias - self.impl = impl_builder.build(mod.in_features, - mod.out_features, + self.impl = impl_builder.build(in_features, + out_features, bias is not None, - dtype=weight.dtype) - weight, bias = self.impl.update_weights(weight, bias) + dtype=dtype) + weight, bias = self.create_weights(in_features, out_features, bias, + dtype, device) + weight = torch.nn.Parameter(weight, requires_grad=False) + weight.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader self.register_parameter('weight', weight) self.register_parameter('bias', bias) - adapter_infos = adapter_infos if adapter_infos is not None else [] - self.lora_adapters = None - if len(adapter_infos) > 0: - self.lora_adapters = nn.ModuleList( - SLoRA(info, ctx_mgr, colwise, is_tp) for info in adapter_infos) - + self.in_features = in_features + self.out_features = out_features + self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + world_size, rank = _get_world_rank() + if colwise: + out_features = _update_feature_size(out_features, world_size, rank) + else: + in_features = _update_feature_size(in_features, world_size, rank) + return in_features, out_features + + def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for colwise linear.""" + weight = loaded_weight.chunk(world_size, 0)[rank] + return default_weight_loader(param, weight) + + def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for rowwise linear.""" + if loaded_weight.dim() == 2: + weight = loaded_weight.chunk(world_size, 1)[rank] + return default_weight_loader(param, weight) + else: + # bias + return default_weight_loader(param, loaded_weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """weight loader.""" + if not self.is_tp: + return default_weight_loader(param, loaded_weight) + + world_size, rank = _get_world_rank() + if self.colwise: + return self._weight_loader_tp_colwise(param, loaded_weight, rank, + world_size) + else: + return self._weight_loader_tp_rowwise(param, loaded_weight, rank, + world_size) + + def create_weights(self, in_features: int, out_features: int, bias: bool, + dtype: torch.dtype, device: torch.device): + """create weights.""" + weight = torch.empty((out_features, in_features), + dtype=dtype, + device=device) + if bias: + bias = torch.empty((out_features, ), dtype=dtype, device=device) + else: + bias = None + return weight, bias + + def update_weights(self): + """update weights.""" + weight, bias = self.impl.update_weights(self.weight, self.bias) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + self.bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('bias', bias) + def forward(self, x): """forward of linear layer.""" is_tp = False if self.colwise else self.is_tp - if self.lora_adapters is None: + if len(self.lora_adapters) == 0: return self.impl.forward(x, self.weight, self.bias, is_tp) out = self.impl.forward(x, self.weight, self.bias, False) - if self.lora_adapters is not None: - for lora_adapter in self.lora_adapters: - out = lora_adapter(x, out) + for lora_adapter in self.lora_adapters: + out = lora_adapter(x, out) if is_tp: dist.all_reduce(out) return out -def _merge_base_linear(*linears: List[nn.Module]): - """merge naive linear.""" - weights = [mod.weight for mod in linears] - bias = [mod.bias for mod in linears] - - in_features = weights[0].size(1) - dtype = weights[0].dtype - device = weights[0].device - for w in weights: - assert w.size(1) == in_features - assert w.dtype == dtype - assert w.device == device - out_features = sum(w.size(0) for w in weights) - - new_weight = torch.cat(weights, dim=0) - new_bias = None - if bias[0] is not None: - assert all(b is not None for b in bias) - new_bias = torch.cat(bias) - has_bias = new_bias is not None - merged_linear = nn.Linear(in_features, - out_features, - bias=has_bias, - dtype=dtype, - device=device) - state_dict = dict(weight=new_weight) - if has_bias: - state_dict['bias'] = new_bias - merged_linear.load_state_dict(state_dict) - return merged_linear - - -def _merge_qlinear(*linears: List[nn.Module]): - """merge qlinear.""" - weights = [mod.weight for mod in linears] - scales = [mod.scale for mod in linears] - bias = [mod.bias for mod in linears] - - in_features = weights[0].size(1) - dtype = weights[0].dtype - device = weights[0].device - for w in weights: - assert w.size(1) == in_features - assert w.dtype == dtype - assert w.device == device - out_features = sum(w.size(0) for w in weights) - - new_weight = torch.cat(weights, dim=0) - new_scale = torch.cat(scales, dim=0) - new_bias = None - if bias[0] is not None: - assert all(b is not None for b in bias) - new_bias = torch.cat(bias) - has_bias = new_bias is not None - merged_linear = QLinear(in_features, - out_features, - bias=has_bias, - dtype=dtype, - device=device) - state_dict = dict( - weight=new_weight, - scale=new_scale, - ) - if has_bias: - state_dict['bias'] = new_bias - merged_linear.load_state_dict(state_dict) - return merged_linear - - -def _merge_awqlinear(*linears: List[nn.Module]): - """merge awqlinear.""" - qweights = [mod.qweight for mod in linears] - scales = [mod.scales for mod in linears] - qzeros = [mod.qzeros for mod in linears] - bias = [mod.bias for mod in linears] - w_bits = [mod.w_bit for mod in linears] - group_sizes = [mod.group_size for mod in linears] - - w_bit = w_bits[0] - group_size = group_sizes[0] - assert all(wb == w_bit for wb in w_bits) - assert all(gs == group_size for gs in group_sizes) - in_features = qweights[0].size(0) - device = qweights[0].device - for w in qweights: - assert w.size(0) == in_features - assert w.device == device - out_features = sum(s.size(1) for s in scales) - - new_qweight = torch.cat(qweights, dim=1) - new_scales = torch.cat(scales, dim=1) - new_qzeros = torch.cat(qzeros, dim=1) - new_bias = None - if bias[0] is not None: - assert all(b is not None for b in bias) - new_bias = torch.cat(bias) - has_bias = new_bias is not None - merged_linear = WQLinear_GEMM( - w_bit, - group_size, - in_features, - out_features, - bias=has_bias, - dev=device, - ) - state_dict = dict( - qweight=new_qweight, - scales=new_scales, - qzeros=new_qzeros, - ) - if has_bias: - state_dict['bias'] = new_bias - merged_linear.load_state_dict(state_dict) - return merged_linear - - -def build_linear(mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, +class MergedBaseLinear(BaseLinear): + """merged base linear.""" + + def __init__(self, + in_features: int, + all_out_features: List[int], + bias: bool, + replicate: List[bool] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True, + out_names: List[int] = None): + if replicate is None: + replicate = tuple(False for _ in all_out_features) + all_out_features = self._update_all_out_features( + all_out_features, replicate) + self.all_out_features = all_out_features + self.replicate = replicate + if out_names is None: + out_names = torch.arange(len(self.all_out_features)).tolist() + assert len(out_names) == len(self.all_out_features) + self.out_names_map = dict( + (name, idx) for idx, name in enumerate(out_names)) + out_features = sum(all_out_features) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + colwise=True, + is_tp=is_tp) + self.weight.weight_loader = self.weight_loader + if self.bias is not None: + self.bias.weight_loader = self.weight_loader + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + return in_features, out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: List[bool]): + """update all out features.""" + world_size, rank = _get_world_rank() + new_all_out_features = [] + for out_feat, rep in zip(all_out_features, replicate): + if rep: + new_all_out_features.append(out_feat) + new_out_feat = _update_feature_size(out_feat, world_size, rank) + new_all_out_features.append(new_out_feat) + return new_all_out_features + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = _get_world_rank() + shard_idx = self.out_names_map[shard_id] + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + param_w.copy_(loaded_weight) + + +class QKVBaseLinear(MergedBaseLinear): + """qkv base linear.""" + + def __init__(self, + in_features: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int, + replicate_kv: bool = False, + bias: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True): + if device is None: + device = torch.device('cpus') + if dtype is None: + dtype = torch.float16 + num_q_heads, num_kv_heads = self._update_num_heads( + num_q_heads, num_kv_heads, replicate_kv) + all_out_features = self._get_qkv_out_features(num_q_heads, + num_kv_heads, head_size, + head_size_v) + replicate = (False, replicate_kv, replicate_kv) + out_names = ('q', 'k', 'v') + self.num_q_heads = num_q_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.head_size_v = head_size_v + super().__init__(in_features, + all_out_features, + bias=bias, + replicate=replicate, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names) + + def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, + head_size: int, head_size_v: int): + """get io features.""" + all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, + num_kv_heads * head_size_v) + return all_out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: List[bool]): + """update all out features.""" + return all_out_features + + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, + replicate_kv: bool): + """update num heads.""" + world_size, rank = _get_world_rank() + num_q_heads = _update_feature_size(num_q_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) + + return num_q_heads, num_kv_heads + + +def build_linear(in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, colwise: bool = True, - is_tp: bool = False) -> nn.Module: + is_tp: bool = False, + quant_config: Any = None) -> nn.Module: """build linear.""" if is_tp: - world_size, rank = _get_world_rank() + world_size, _ = _get_world_rank() is_tp = world_size > 1 - if isinstance(mod, nn.Linear): - return BaseLinear(mod, - adapter_infos, - ctx_mgr, - colwise=colwise, - is_tp=is_tp) - elif isinstance(mod, WQLinear_GEMM): - return AwqLinear(mod, - adapter_infos, - ctx_mgr, - colwise=colwise, - is_tp=is_tp) - elif isinstance(mod, QLinear): - return W8A8Linear(mod, ctx_mgr, colwise, is_tp) - elif isinstance(mod, LoRALinear): - base_layer = mod.base_layer - adapter_info = AdapterInfo.from_lora_linear(mod) - return build_linear(base_layer, [adapter_info], - ctx_mgr=ctx_mgr, - colwise=colwise, - is_tp=is_tp) - elif isinstance(mod, AwqLoraLinear): - base_layer = mod.base_layer - adapter_info = AdapterInfo.from_lora_linear(mod) - return build_linear(base_layer, [adapter_info], - ctx_mgr=ctx_mgr, - colwise=colwise, - is_tp=is_tp) + if quant_config is None: + return BaseLinear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device=device, + colwise=colwise, + is_tp=is_tp, + ) + + quant_method = quant_config['quant_method'] + if quant_method == 'awq': + w_bit = quant_config.get('bits', 4) + group_size = quant_config.get('group_size', 128) + return AwqLinear( + in_features, + out_features, + w_bit=w_bit, + group_size=group_size, + bias=bias, + device=device, + colwise=colwise, + is_tp=is_tp, + ) + if quant_method == 'w8a8': + return W8A8Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device=device, + colwise=colwise, + is_tp=is_tp, + ) else: - raise NotImplementedError(f'Unknown linear type: {type(mod)}') + raise RuntimeError(f'Unsupported quant method: {quant_method}') -def build_colwise_linear(mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, - is_tp: bool = False) -> nn.Module: +def build_colwise_linear(in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = False, + quant_config: Any = None) -> nn.Module: """build columnwise parallel linear layer.""" - return build_linear(mod, adapter_infos, ctx_mgr, colwise=True, is_tp=is_tp) - - -def build_rowwise_linear(mod: nn.Module, - adapter_infos: List[AdapterInfo] = None, - ctx_mgr: Any = None, - is_tp: bool = False) -> nn.Module: + return build_linear(in_features=in_features, + out_features=out_features, + bias=bias, + dtype=dtype, + device=device, + colwise=True, + is_tp=is_tp, + quant_config=quant_config) + + +def build_rowwise_linear(in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = False, + quant_config: Any = None) -> nn.Module: """build rowwise parallel linear layer.""" - return build_linear(mod, - adapter_infos, - ctx_mgr, + return build_linear(in_features=in_features, + out_features=out_features, + bias=bias, + dtype=dtype, + device=device, colwise=False, - is_tp=is_tp) + is_tp=is_tp, + quant_config=quant_config) def build_merged_colwise_linear( - *linears: List[nn.Module], - ctx_mgr: Any = None, - is_tp: bool = False, + in_features: int, + all_out_features: List[int], + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + quant_config: Any = None, + is_tp: bool = True, + out_names: List[Any] = None, ): """merge linear.""" - base_layers = [] - out_features = [] - adapter_infos = [] - cum_out_feature = 0 - for mod in linears: - # get base layers - base_layer = getattr(mod, 'base_layer', mod) - base_layers.append(base_layer) - - # get out_feature - if hasattr(base_layer, 'weight'): - weight = base_layer.weight - out_feature = weight.size(0) - else: - scales = base_layer.scales - out_feature = scales.size(1) - slice_start = cum_out_feature - cum_out_feature += out_feature - - # get adapter info - adapter_info = None - if isinstance(mod, (LoRALinear, AwqLoraLinear)): - adapter_slice = slice(slice_start, cum_out_feature) - adapter_info = AdapterInfo.from_lora_linear(mod, adapter_slice) - out_features.append(out_feature) - if adapter_info is not None: - adapter_infos.append(adapter_info) - - # check base layer type - base_type = type(base_layers[0]) - assert all(isinstance(layer, base_type) for layer in base_layers) - - # merge base layer - if base_type == nn.Linear: - base_layer = _merge_base_linear(*base_layers) - elif base_type == WQLinear_GEMM: - base_layer = _merge_awqlinear(*base_layers) - elif base_type == QLinear: - base_layer = _merge_qlinear(*base_layers) + if is_tp: + world_size, _ = _get_world_rank() + is_tp = world_size > 1 + + if quant_config is None: + return MergedBaseLinear( + in_features=in_features, + all_out_features=all_out_features, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names, + ) + + quant_method = quant_config['quant_method'] + if quant_method == 'awq': + w_bit = quant_config.get('bits', 4) + group_size = quant_config.get('group_size', 128) + return MergedAwqLinear( + in_features, + all_out_features=all_out_features, + w_bit=w_bit, + group_size=group_size, + bias=bias, + device=device, + is_tp=is_tp, + ) + if quant_method == 'w8a8': + return MergedW8A8Linear( + in_features=in_features, + all_out_features=all_out_features, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names, + ) + else: + raise RuntimeError(f'Unsupported quant method: {quant_method}') + + +def build_qkv_proj(in_features: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + head_size_v: int = None, + replicate_kv: bool = False, + bias: bool = False, + quant_config: Any = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True): + """build qkv proj.""" + if is_tp: + world_size, _ = _get_world_rank() + is_tp = world_size > 1 + + if head_size_v is None: + head_size_v = head_size + + if quant_config is None: + return QKVBaseLinear( + in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + replicate_kv=replicate_kv, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + ) + + quant_method = quant_config['quant_method'] + if quant_method == 'awq': + w_bit = quant_config.get('bits', 4) + group_size = quant_config.get('group_size', 128) + return QKVAwqLinear( + in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + replicate_kv=replicate_kv, + w_bit=w_bit, + group_size=group_size, + bias=bias, + device=device, + is_tp=is_tp, + ) + if quant_method == 'w8a8': + return QKVW8A8Linear( + in_features=in_features, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + head_size_v=head_size_v, + replicate_kv=replicate_kv, + bias=bias, + dtype=dtype, + device=device, + is_tp=is_tp, + ) else: - raise NotImplementedError(f'Unknown linear type: {type(mod)}') - ret = build_colwise_linear(base_layer, - adapter_infos, - ctx_mgr=ctx_mgr, - is_tp=is_tp) - return ret + raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 53f0053e3..48ecb60c8 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -1,53 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Optional import torch +import torch.distributed as dist from torch import nn from ..backends import LayerType, get_backend -def _merge_mlp(gates: List[torch.Tensor], ups: List[torch.Tensor], - downs: List[torch.Tensor]): - """merge experts.""" - num_experts = len(gates) - - def __get_meta(): - gate = gates[0] - down = downs[0] - ffn_dim = gate.weight.size(0) - hidden_dim = down.weight.size(0) - dtype = gate.weight.dtype - device = gate.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - for exp_id in range(num_experts): - gate = gates[exp_id] - up = ups[exp_id] - down = downs[exp_id] - __copy_assign_param(gate.weight, gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(up.weight, gate_up_weights[exp_id, ffn_dim:]) - __copy_assign_param(down.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - return gate_up_weights, down_weights +def _get_world_rank(): + """get current world size and rank.""" + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank class SoftmaxTopK(nn.Module): @@ -68,41 +38,98 @@ def forward(self, x: torch.Tensor): class FusedMoE(nn.Module): """fused moe.""" - def __init__( - self, - gate_up_weights: torch.Tensor, - down_weights: torch.Tensor, - top_k: int, - renormalize: bool = False, - ): + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = False): super().__init__() + if device is None: + device = torch.device('cpu') + if dtype is None: + dtype = torch.float16 + hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim) + impl_builder = get_backend().get_layer_impl_builder(LayerType.FusedMoE) self.impl = impl_builder.build(top_k, renormalize) + + gate_up_weights, down_weights = self.create_weights(hidden_dim, + ffn_dim, + num_experts, + dtype=dtype, + device=device) + gate_up_weights = torch.nn.Parameter(gate_up_weights, + requires_grad=False) + down_weights = torch.nn.Parameter(down_weights, requires_grad=False) + gate_up_weights.weight_loader = self.weight_loader + down_weights.weight_loader = self.weight_loader + gate_up_weights._weight_type = 'gate_up_weights' + down_weights._weight_type = 'down_weights' + self.register_parameter('gate_up_weights', gate_up_weights) + self.register_parameter('down_weights', down_weights) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + self.is_tp = is_tp + + def _update_args(self, hidden_dim: int, ffn_dim: int): + """update args.""" + world_size, _ = _get_world_rank() + assert ffn_dim % world_size == 0 + ffn_dim = ffn_dim // world_size + return hidden_dim, ffn_dim + + def create_weights(self, hidden_dim: int, ffn_dim: int, num_experts: int, + dtype: torch.dtype, device: torch.device): + """create weights.""" + gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim), + dtype=dtype, + device=device) + down_weights = torch.empty((num_experts, hidden_dim, ffn_dim), + dtype=dtype, + device=device) + return gate_up_weights, down_weights + + def update_weights(self): + """update weights.""" gate_up_weights, down_weights = self.impl.update_weights( - gate_up_weights, down_weights) + self.gate_up_weights, self.down_weights) gate_up_weights = torch.nn.Parameter(gate_up_weights, requires_grad=False) down_weights = torch.nn.Parameter(down_weights, requires_grad=False) + gate_up_weights.weight_loader = self.weight_loader + down_weights.weight_loader = self.weight_loader + gate_up_weights._weight_type = 'gate_up_weights' + down_weights._weight_type = 'down_weights' self.register_parameter('gate_up_weights', gate_up_weights) self.register_parameter('down_weights', down_weights) + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader.""" + world_size, rank = _get_world_rank() + if shard_id == 'gate': + param_data = param.data[expert_id, :self.ffn_dim] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'up': + param_data = param.data[expert_id, self.ffn_dim:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'down': + param_data = param.data[expert_id] + weight = loaded_weight.chunk(world_size, dim=1)[rank] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): return self.impl.forward(hidden_states, topk_weights, topk_ids, self.gate_up_weights, self.down_weights) - - -def build_moe_from_mlp( - gates: List[nn.Linear], - ups: List[nn.Linear], - downs: List[nn.Linear], - top_k: int, - renormalize: bool = False, -): - """build moe from mlp.""" - - gate_up_weights, down_weights = _merge_mlp(gates, ups, downs) - return FusedMoE(gate_up_weights, - down_weights, - top_k=top_k, - renormalize=renormalize) diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index c9137c129..5394e1892 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -1,10 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Any + import torch from torch import nn from ..backends import LayerType, get_backend +def _is_w8a8(quant_config: Any): + """is w8a8.""" + return False + + class RMSNorm(nn.Module): """RMS Norm with add residual.""" @@ -13,10 +20,10 @@ def __init__(self, eps: float = 1e-6, dtype: torch.dtype = None, device: torch.device = None, - is_w8a8: bool = False): + quant_config: Any = None): super().__init__() backend = get_backend() - if is_w8a8: + if _is_w8a8(quant_config): builder = backend.get_layer_impl_builder(LayerType.RMSNormW8A8) else: builder = backend.get_layer_impl_builder(LayerType.RMSNorm) diff --git a/lmdeploy/pytorch/weight_loader/adapter_weight_loader.py b/lmdeploy/pytorch/weight_loader/adapter_weight_loader.py deleted file mode 100644 index db0fabeb9..000000000 --- a/lmdeploy/pytorch/weight_loader/adapter_weight_loader.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from contextlib import contextmanager - - -class AdapterWeightLoader: - """adapter weight loader.""" - - def __init__(self, adapter_name: str, adapter_path: str): - from peft.utils.save_and_load import load_peft_weights - self._adapter_name = adapter_name - self._adapter_path = adapter_path - self._state_dict = load_peft_weights(adapter_path) - self._prefix = 'base_model.model.' - - def pop(self, key: str): - """pop weight.""" - key = self._prefix + key - return self._state_dict.pop(key) - - def get(self, key: str): - """get weight.""" - key = self._prefix + key - return self._state_dict.get(key) - - @contextmanager - def prefix_context(self, mod_name: str): - """update prefix by mod name.""" - old_prefix = self._prefix - if len(old_prefix) == 0: - new_prefix = f'{mod_name}.' - else: - new_prefix = f'{old_prefix}{mod_name}.' - self._prefix = new_prefix - yield new_prefix - self._prefix = old_prefix diff --git a/lmdeploy/pytorch/weight_loader/dist_utils.py b/lmdeploy/pytorch/weight_loader/dist_utils.py deleted file mode 100644 index bad054e00..000000000 --- a/lmdeploy/pytorch/weight_loader/dist_utils.py +++ /dev/null @@ -1,526 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math -from typing import List - -import torch - -from lmdeploy.pytorch.models.q_modules import QLinear -from lmdeploy.utils import get_logger - -from .model_weight_loader import ModelWeightLoader - -logger = get_logger('lmdeploy') - -try: - from peft.tuners.lora import Linear as LoRALinear -except ImportError: - logger.debug('load peft.tuner.lora.Linear failed.') - - class LoRALinear: - pass - - -try: - from peft.tuners.lora.awq import AwqLoraLinear -except ImportError: - logger.debug('load peft.tuners.lora.awq.AwqLoraLinear failed.') - - class AwqLoraLinear: - pass - - -try: - from awq.modules.linear.gemm import WQLinear_GEMM -except ImportError: - logger.debug('load awq.modules.linear.gemm.WQLinearGEMM failed.') - - class WQLinear_GEMM: - pass - - -def _div_up(a, b): - """div up.""" - return (a + b - 1) // b - - -def _math_lcm(*args): - """lcm.""" - return int(math.prod(args) / math.gcd(*args)) - - -def get_prefixed_name(name: str, prefix: str): - """get prefixed name.""" - if len(prefix) == 0: - return name - else: - return f'{prefix}.{name}' - - -def cast_dtype(param: torch.Tensor, dtype: torch.dtype): - """cast dtype.""" - if param.dtype != dtype: - param = param.to(dtype) - return param - - -def colwise_parallelize_linear_naive(mod: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize linear.""" - - def __update_param(name, param): - """update_param.""" - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - param = loader.pop(prefixed_name).chunk(world_size)[rank] - param = cast_dtype(param, dtype) - return param - - for name, param in mod.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - mod.register_parameter(name, param) - for name, param in mod.named_buffers(): - param = __update_param(name, param) - mod.register_buffer(name, param) - - -def colwise_parallelize_loralinear(module: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize loralinear.""" - if isinstance(module.base_layer, WQLinear_GEMM): - parallel_base_func = colwise_parallelize_wqlinear - else: - parallel_base_func = colwise_parallelize_linear_naive - parallel_base_func(module.base_layer, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - for key, mod in module.lora_A.items(): - ada_loader = loader.adapter(key) - colwise_parallelize_linear_naive(mod, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_A', prefix)) - for key, mod in module.lora_B.items(): - ada_loader = loader.adapter(key) - colwise_parallelize_linear_naive(mod, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_B', prefix)) - module._tp_mode = 'colwise' - - -def _get_split_size_with_align(size: int, align: int, num_chunk: int): - """get split size with align.""" - assert size % align == 0 - num_aligned = size // align - split_size = _div_up(num_aligned, num_chunk) * align - return split_size - - -def colwise_parallelize_wqlinear(mod: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize wqlinear.""" - elem_per_word = 32 // mod.w_bit - group_size = mod.group_size - lcm = _math_lcm(elem_per_word, group_size) - num_out = mod.scales.size(1) - - split_size = _get_split_size_with_align(num_out, lcm, world_size) - qsplit_size = split_size // elem_per_word - - def __update_param(name, param): - """update_param.""" - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - if name == 'bias': - ssize = split_size - dim = 0 - elif name == 'scales': - ssize = split_size - dim = 1 - else: - ssize = qsplit_size - dim = 1 - param = loader.pop(prefixed_name) - param = param.split(ssize, dim)[rank] - param = cast_dtype(param, dtype) - return param - - for name, param in mod.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - mod.register_parameter(name, param) - for name, param in mod.named_buffers(): - param = __update_param(name, param) - mod.register_buffer(name, param) - mod.in_features = mod.qweight.size(0) - mod.out_features = mod.scales.size(1) - - -def colwise_parallelize_linear(module: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize linear.""" - if isinstance(module, (torch.nn.Linear, QLinear)): - return colwise_parallelize_linear_naive(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, (LoRALinear, AwqLoraLinear)): - return colwise_parallelize_loralinear(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, WQLinear_GEMM): - return colwise_parallelize_wqlinear(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - else: - raise TypeError(f'Unsupported module: {type(module)}') - - -def rowwise_parallelize_linear_naive(mod: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """rowwise parallelize linear.""" - - def __update_param(name: str, param: torch.Tensor): - """update_param.""" - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - param = loader.pop(prefixed_name) - if name == 'weight': - param = param.chunk(world_size, 1)[rank] - if name == 'bias': - param /= world_size - param = cast_dtype(param, dtype) - return param - - for name, param in mod.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - mod.register_parameter(name, param) - for name, param in mod.named_buffers(): - param = __update_param(name, param) - mod.register_buffer(name, param) - - -def rowwise_parallelize_loralinear(module: LoRALinear, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize loralinear.""" - if isinstance(module.base_layer, WQLinear_GEMM): - parallel_base_func = rowwise_parallelize_wqlinear - else: - parallel_base_func = rowwise_parallelize_linear_naive - parallel_base_func(module.base_layer, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - for key, mod in module.lora_A.items(): - ada_loader = loader.adapter(key) - rowwise_parallelize_linear_naive(mod, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_A', prefix)) - for key, mod in module.lora_B.items(): - ada_loader = loader.adapter(key) - colwise_parallelize_linear_naive(mod, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_B', prefix)) - module._tp_mode = 'colwise' - - -def rowwise_parallelize_wqlinear(mod: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """rowwise parallelize linear.""" - elem_per_word = 32 // mod.w_bit - group_size = mod.group_size - lcm = _math_lcm(elem_per_word, group_size) - num_in = mod.qweight.size(0) - - split_size = _get_split_size_with_align(num_in, lcm, world_size) - qsplit_size = split_size // group_size - - def __update_param(name: str, param: torch.Tensor): - """update_param.""" - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - param = loader.pop(prefixed_name) - if name == 'bias': - param /= world_size - elif name == 'qweight': - param = param.split(split_size)[rank] - else: - param = param.split(qsplit_size)[rank] - param = cast_dtype(param, dtype) - return param - - for name, param in mod.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - mod.register_parameter(name, param) - for name, param in mod.named_buffers(): - param = __update_param(name, param) - mod.register_buffer(name, param) - mod.in_features = mod.qweight.size(0) - mod.out_features = mod.scales.size(1) - - -def rowwise_parallelize_linear(module: torch.nn.Module, - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise parallelize linear.""" - if isinstance(module, (torch.nn.Linear, QLinear)): - return rowwise_parallelize_linear_naive(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, (LoRALinear, AwqLoraLinear)): - return rowwise_parallelize_loralinear(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, WQLinear_GEMM): - return rowwise_parallelize_wqlinear(module, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - else: - raise TypeError(f'Unsupported module: {type(module)}') - - -def colwise_split_parallelize_linear_naive(module: torch.nn.Module, - sections: List[int], - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise split linear naive.""" - - def __update_param(name: str, param: torch.Tensor): - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - param = loader.pop(prefixed_name) - splited_param = param.split(sections, dim=0) - updated_param = [] - for p in splited_param: - p = p.chunk(world_size)[rank] - p = cast_dtype(p, dtype) - updated_param.append(p) - param = torch.cat(updated_param) - return param - - for name, param in module.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - module.register_parameter(name, param) - for name, param in module.named_buffers(): - param = __update_param(name, param) - module.register_buffer(name, param) - - -def colwise_split_parallelize_loralinear(module: LoRALinear, - sections: List[int], - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise split loralinear.""" - if isinstance(module.base_layer, WQLinear_GEMM): - parallel_base_func = colwise_split_parallelize_wqlinear - else: - parallel_base_func = colwise_split_parallelize_linear_naive - parallel_base_func(module.base_layer, - sections, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - for key, mod in module.lora_A.items(): - ada_loader = loader.adapter(key) - colwise_parallelize_linear_naive(mod, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_A', prefix)) - for key, mod in module.lora_B.items(): - ada_loader = loader.adapter(key) - colwise_split_parallelize_linear_naive(mod, - sections, - ada_loader, - rank=rank, - world_size=world_size, - prefix=get_prefixed_name( - 'lora_B', prefix)) - module._tp_mode = 'colwise' - - -def colwise_split_parallelize_wqlinear(module: torch.nn.Module, - sections: List[int], - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise split wqlinear.""" - elem_per_word = 32 // module.w_bit - group_size = module.group_size - lcm = _math_lcm(elem_per_word, group_size) - - for s in sections: - assert s % lcm == 0 - - def __update_param(name: str, param: torch.Tensor): - dtype = param.dtype - prefixed_name = get_prefixed_name(name, prefix) - param = loader.pop(prefixed_name) - if name == 'bias': - dim = 0 - sec = sections - elif name == 'scales': - dim = 1 - sec = sections - else: - dim = 1 - sec = [s // elem_per_word for s in sections] - splited_param = param.split(sec, dim=dim) - updated_param = [] - for p in splited_param: - if name == 'bias': - p = p.chunk(world_size)[rank] - else: - p = p.chunk(world_size, 1)[rank] - p = cast_dtype(p, dtype) - updated_param.append(p) - if name == 'bias': - param = torch.cat(updated_param) - else: - param = torch.cat(updated_param, 1) - return param - - for name, param in module.named_parameters(): - param = __update_param(name, param) - param = torch.nn.Parameter(param, requires_grad=False) - module.register_parameter(name, param) - for name, param in module.named_buffers(): - param = __update_param(name, param) - module.register_buffer(name, param) - module.in_features = module.qweight.size(0) - module.out_features = module.scales.size(1) - - -def colwise_split_parallelize_linear(module: torch.nn.Module, - sections: List[int], - loader: ModelWeightLoader, - rank: int, - world_size: int, - prefix: str = ''): - """colwise split linear.""" - if isinstance(module, (torch.nn.Linear, QLinear)): - return colwise_split_parallelize_linear_naive(module, - sections, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, (LoRALinear, AwqLoraLinear)): - return colwise_split_parallelize_loralinear(module, - sections, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - elif isinstance(module, WQLinear_GEMM): - return colwise_split_parallelize_wqlinear(module, - sections, - loader, - rank=rank, - world_size=world_size, - prefix=prefix) - else: - raise TypeError(f'Unsupported module: {type(module)}') - - -def load_no_recursive(mod: torch.nn.Module, - loader: ModelWeightLoader, - rank: int = 0, - prefix: str = ''): - """default load linear naive.""" - for name, param in mod.named_parameters(recurse=False): - prefixed_name = get_prefixed_name(name, prefix) - dtype = param.dtype - if not loader.has(prefixed_name): - logger.debug(f'rank [{rank}]' - f' failed to find weight: {name}.') - param = torch.empty_like(param, device='cpu') - else: - param = loader.pop(prefixed_name) - if param.dtype != dtype: - param = param.to(dtype) - mod.register_parameter(name, - torch.nn.Parameter(param, requires_grad=False)) - for name, param in mod.named_buffers(recurse=False): - prefixed_name = get_prefixed_name(name, prefix) - dtype = param.dtype - if not loader.has(prefixed_name): - logger.debug(f'rank [{rank}]' - f' failed to find weight: {name}.') - param = torch.empty_like(param, device='cpu') - else: - param = loader.pop(prefixed_name) - if param.dtype != dtype: - param = param.to(dtype) - mod.register_buffer(name, param) - - -def default_load_linear(module: torch.nn.Module, - loader: ModelWeightLoader, - rank: int = 0, - prefix: str = ''): - """default load linear.""" - if isinstance(module, (torch.nn.Linear, QLinear, WQLinear_GEMM)): - load_no_recursive(module, loader, rank=rank, prefix=prefix) - elif isinstance(module, (LoRALinear, AwqLoraLinear)): - raise NotImplementedError('Not implemented, please contact us.') - else: - raise TypeError(f'Unsupported module: {type(module)}') diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index 36da1cb6e..bf48b2b78 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -1,8 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import json import os.path as osp -from contextlib import ExitStack, contextmanager -from typing import Dict import torch from transformers.modeling_utils import load_state_dict @@ -11,8 +9,6 @@ from lmdeploy.utils import get_logger -from .adapter_weight_loader import AdapterWeightLoader - logger = get_logger('lmdeploy') @@ -96,153 +92,61 @@ def _get_weight_path(model_path: str, weight_type: str): class ModelWeightLoader: """model weight loader for sharded weights.""" - def __init__(self, model_path: str, adapters: Dict[str, str] = None): + def __init__(self, model_path: str, prefix: str = None): self.model_path = model_path - self._state_dict = dict() weight_type, is_sharded = _get_weight_type(model_path) self._weight_type = weight_type self._is_sharded = is_sharded - self._prefix = '' + self._prefix = prefix + self._shard_paths = self._get_shard_paths(model_path, is_sharded, + weight_type) + + @staticmethod + def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str): + """get shard paths.""" if is_sharded: - self._weight_map = _get_weight_map(model_path, weight_type) + weight_map = _get_weight_map(model_path, weight_type) + paths = set(weight_map.values()) + paths = tuple(f'{model_path}/{path}' for path in paths) + return paths else: - weight_path, weight_name = _get_weight_path( - model_path, weight_type) - self._load_shard(weight_path) - keys = list(self._state_dict.keys()) - self._weight_map = dict((k, weight_name) for k in keys) - - if adapters is None: - adapters = dict() - - self._adapter_loaders: Dict[str, AdapterWeightLoader] = dict() - for ada_name, ada_path in adapters.items(): - ada_loader = AdapterWeightLoader(ada_name, ada_path) - self._adapter_loaders[ada_name] = ada_loader + path, _ = _get_weight_path(model_path, weight_type) + return (path, ) def _load_shard(self, path: str): """load shards.""" - self._state_dict.update(load_state_dict(path)) - - def _load_shard_for_key(self, key: str): - """load shard for key.""" - if key in self._state_dict: - return - if key not in self._weight_map: - raise RuntimeError(f'Unknown weight: {key}.') - - shard_file = osp.join(self.model_path, self._weight_map[key]) - self._load_shard(shard_file) - if key not in self._state_dict: - raise RuntimeError(f'Can not found "{key}" in "{shard_file}"') - - def pop(self, key: str): - """pop weight.""" - key = self._prefix + key - self._load_shard_for_key(key) - return self._state_dict.pop(key) - - def get(self, key: str): - """get weight.""" - key = self._prefix + key - self._load_shard_for_key(key) - return self._state_dict.get(key) - - def has(self, key: str): - """check has key.""" - key = self._prefix + key - return key in self._weight_map - - def adapter(self, key: str): - """get adapter loader.""" - if key not in self._adapter_loaders: - raise RuntimeError(f'Unknown adapter: {key}') - return self._adapter_loaders[key] - - @contextmanager - def prefix_context(self, mod_name: str): - """update prefix by mod name.""" - old_prefix = self._prefix - if len(old_prefix) == 0: - new_prefix = f'{mod_name}.' - else: - new_prefix = f'{old_prefix}{mod_name}.' - self._prefix = new_prefix - - with ExitStack() as stack: - for ada in self._adapter_loaders.values(): - stack.enter_context(ada.prefix_context(mod_name)) - yield new_prefix - self._prefix = old_prefix + state_dict = load_state_dict(path) + if self._prefix is not None: + state_dict = dict( + (f'{self._prefix}{k}', v) for k, v in state_dict.items()) + return state_dict def load_model_weights( self, model: torch.nn.Module, - rank: int = 0, - world_size: int = 1, - device: torch.device = 'cpu', - load_only: bool = False, + device: torch.device = None, ): """load model weights implementation.""" - loader = self - - def __load_no_recursive(mod: torch.nn.Module): - """load no recursive.""" - for name, param in mod.named_parameters(recurse=False): - dtype = param.dtype - if not loader.has(name): - logger.debug(f'rank [{rank}]' - f' failed to find weight: {name}.') - param = torch.empty_like(param, device='cpu') - else: - param = loader.pop(name) - if param.dtype != dtype: - param = param.to(dtype) - mod.register_parameter( - name, torch.nn.Parameter(param, requires_grad=False)) - for name, param in mod.named_buffers(recurse=False): - dtype = param.dtype - if not loader.has(name): - logger.debug(f'rank [{rank}]' - f' failed to find weight: {name}.') - param = torch.empty_like(param, device='cpu') - else: - param = loader.pop(name) - if param.dtype != dtype: - param = param.to(dtype) - mod.register_buffer(name, param) - - if not load_only and hasattr(model, '_load_weights'): - model._load_weights(model, loader, rank, world_size, device=device) - else: - __load_no_recursive(model) - for name, child in model.named_children(): - with loader.prefix_context(name): - self.load_model_weights(child, - rank=rank, - world_size=world_size, - device=device) - - model.to(device) + assert hasattr(model, 'load_weights') + paths = self._shard_paths + for path in paths: + state_dict = self._load_shard(path) + model.load_weights(state_dict.items()) + if device is not None: + device = model.to(device) @torch.inference_mode() def load_model_weights(model: torch.nn.Module, checkpoint_path: str, - adapters: Dict[str, str] = None, - rank: int = 0, - world_size: int = 1, - device: torch.device = 'cpu'): + prefix: str = None, + device: torch.device = None): """Loading model weights.""" - if rank == 0: - logger.info('Loading model weights, please waiting.') - if adapters is None: - adapters = dict() - loader = ModelWeightLoader(checkpoint_path, adapters=adapters) - loader.load_model_weights(model, - rank=rank, - world_size=world_size, - device=device) - model.tie_weights() + loader = ModelWeightLoader(checkpoint_path, prefix=prefix) + loader.load_model_weights(model, device=device) model.eval() + for _, mod in model.named_modules(): + if not hasattr(mod, 'update_weights'): + continue + mod.update_weights() From 476bce2c714d2e0e11e403dcafe929d9484080ea Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 28 Aug 2024 11:41:41 +0800 Subject: [PATCH 41/89] optimize nn --- lmdeploy/pytorch/models/llama.py | 26 +--- lmdeploy/pytorch/models/mixtral.py | 37 +----- lmdeploy/pytorch/nn/attention.py | 14 ++ lmdeploy/pytorch/nn/linear.py | 204 +++++++++++------------------ lmdeploy/pytorch/nn/moe.py | 31 ++--- lmdeploy/pytorch/nn/utils.py | 33 +++++ 6 files changed, 153 insertions(+), 192 deletions(-) create mode 100644 lmdeploy/pytorch/nn/utils.py diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index b7e4c7065..bbc1c9f4e 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -2,7 +2,6 @@ from typing import Any, Iterable, List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama import LlamaConfig @@ -23,17 +22,11 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - self.num_heads = num_heads // world_size - self.num_kv_heads = num_key_value_heads // world_size - self.head_dim = head_dim # packed qkv self.qkv_proj = build_qkv_proj( @@ -52,10 +45,10 @@ def __init__(self, # attention self.attn_fwd = Attention( - self.num_heads, - self.head_dim, - num_kv_heads=self.num_kv_heads, - v_head_size=self.head_dim, + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, ) # o_proj @@ -80,15 +73,8 @@ def forward( qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) qkv_states = qkv_states.flatten(0, -2) - qkv_states = qkv_states.unflatten(-1, (-1, self.head_dim)) - query_states, key_states, value_states = qkv_states.split( - ( - self.num_heads, - self.num_kv_heads, - self.num_kv_heads, - ), - dim=1, - ) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) # apply rotary embedding cos, sin = rotary_pos_emb diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 78a776cdd..138c42ef6 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -2,7 +2,6 @@ from typing import Any, Iterable, List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast @@ -22,18 +21,12 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() quantization_config = None num_heads = config.num_attention_heads num_key_value_heads = config.num_key_value_heads hidden_size = config.hidden_size head_dim = hidden_size // num_heads - self.num_heads = num_heads // world_size - self.num_kv_heads = num_key_value_heads // world_size - self.head_dim = head_dim # qkv self.qkv_proj = build_qkv_proj( @@ -52,10 +45,10 @@ def __init__(self, # attention self.window_size = config.sliding_window or -1 self.attn_fwd = Attention( - self.num_heads, - self.head_dim, - num_kv_heads=self.num_kv_heads, - v_head_size=self.head_dim, + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, sliding_window=self.window_size, ) @@ -80,15 +73,8 @@ def forward( qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) qkv_states = qkv_states.flatten(0, -2) - qkv_states = qkv_states.unflatten(-1, (-1, self.head_dim)) - query_states, key_states, value_states = qkv_states.split( - ( - self.num_heads, - self.num_kv_heads, - self.num_kv_heads, - ), - dim=1, - ) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) cos, sin = rotary_pos_emb query_states, key_states = self.apply_rotary_pos_emb( @@ -122,12 +108,6 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - is_tp = world_size > 1 - self.is_tp = is_tp - self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -152,7 +132,7 @@ def __init__(self, renormalize=True, dtype=dtype, device=device, - is_tp=False, + all_reduce=True, ) def forward(self, hidden_states: torch.Tensor): @@ -169,9 +149,6 @@ def forward(self, hidden_states: torch.Tensor): ) out_states = out_states.reshape(batch_size, sequence_length, -1) - - if self.is_tp: - dist.all_reduce(out_states) return out_states, router_logits diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index a7a7326bd..a73496824 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -4,6 +4,7 @@ from ..backends import LayerType, get_backend from ..backends.attention import AttentionMetadata +from .utils import get_distribute_size, get_world_rank class Attention(nn.Module): @@ -19,9 +20,13 @@ def __init__( alibi_scale: float = None, sliding_window: int = None, logit_softcapping: float = None, + replicate_kv: bool = False, **kwargs, ): super().__init__() + num_heads, num_kv_heads = self._update_num_heads( + num_heads, num_kv_heads, replicate_kv) + layer_backend = get_backend() impl_builder = layer_backend.get_layer_impl_builder( LayerType.Attention) @@ -38,6 +43,15 @@ def __init__( **kwargs, ) + def _update_num_heads(self, num_heads: int, num_kv_heads: int, + replicate_kv: bool): + """update heads.""" + world_size, rank = get_world_rank() + num_heads = get_distribute_size(num_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + return num_heads, num_kv_heads + def forward( self, query: torch.Tensor, diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index c7df46eb2..232d836b2 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -11,27 +11,11 @@ from ..backends import LayerType, get_backend from ..backends.slora import AdapterInfo +from .utils import div_up, get_distribute_size, get_world_rank logger = get_logger('lmdeploy') -def _get_world_rank(): - """get current world size and rank.""" - world_size = 1 - rank = 0 - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - - return world_size, rank - - -def _div_up(a: int, b: int): - """div up.""" - return (a + b - 1) // b - - def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): """chunk aligned.""" if align == 1: @@ -39,24 +23,45 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): size = weight.size(dim) assert size % align == 0 aligned_size = size // align - align_per_chunk = _div_up(aligned_size, chunks) + align_per_chunk = div_up(aligned_size, chunks) sections = [align_per_chunk] * (chunks - 1) sections += [aligned_size - align_per_chunk * (chunks - 1)] sections = [sec * align for sec in sections] return weight.split(sections, dim=dim) -def _update_feature_size(feature_size: int, - world_size: int, - rank: int, - align: int = 1): - """update feature size.""" - assert feature_size % align == 0 - aligned_size = feature_size // align - align_per_rank = _div_up(aligned_size, world_size) - prev_feats = align_per_rank * rank - updated_aligned_size = min(align_per_rank, aligned_size - prev_feats) - return updated_aligned_size * align +class QKVMixin: + + def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, + head_size: int, head_size_v: int): + """get io features.""" + all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, + num_kv_heads * head_size_v) + return all_out_features + + def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, + replicate_kv: bool): + """update num heads.""" + world_size, rank = get_world_rank() + num_q_heads = get_distribute_size(num_q_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + + return num_q_heads, num_kv_heads + + def split_qkv(self, x: torch.Tensor): + """split query, key and value.""" + num_q_heads = self.num_q_heads + num_kv_heads = self.num_kv_heads + head_size = self.head_size + head_size_v = self.head_size_v + + sections = self.all_out_features + q, k, v = x.split(sections, dim=-1) + q = q.unflatten(-1, (num_q_heads, head_size)) + k = k.unflatten(-1, (num_kv_heads, head_size)) + v = v.unflatten(-1, (num_kv_heads, head_size_v)) + return q, k, v class SLoRA(nn.Module): @@ -163,17 +168,17 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool): """get io features.""" align = max(32 // w_bit, group_size) - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if colwise: - out_features = _update_feature_size(out_features, - world_size, - rank, - align=align) - else: - in_features = _update_feature_size(in_features, + out_features = get_distribute_size(out_features, world_size, rank, align=align) + else: + in_features = get_distribute_size(in_features, + world_size, + rank, + align=align) return in_features, out_features def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, @@ -220,7 +225,7 @@ def weight_loader(self, param: torch.nn.Parameter, if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) @@ -301,10 +306,10 @@ def __init__(self, w_bit: int, group_size: int, bias: bool, - replicate: List[bool] = None, + replicate: Optional[List[bool]] = None, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: List[int] = None): + out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) all_out_features = self._update_all_out_features( @@ -341,23 +346,24 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, return in_features, out_features def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, replicate: List[bool]): + group_size: int, + replicate: Optional[List[bool]]): """update all out features.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() new_all_out_features = [] align = max(32 // w_bit, group_size) for out_feat, rep in zip(all_out_features, replicate): if rep: new_all_out_features.append(out_feat) - new_out_feat = _update_feature_size(out_feat, world_size, rank, - align) + new_out_feat = get_distribute_size(out_feat, world_size, rank, + align) new_all_out_features.append(new_out_feat) return new_all_out_features def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """weight loader.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] if loaded_weight.dim() == 1: @@ -387,7 +393,7 @@ def weight_loader(self, param: torch.nn.Parameter, param_w.copy_(weight) -class QKVAwqLinear(MergedAwqLinear): +class QKVAwqLinear(MergedAwqLinear, QKVMixin): """qkv awq linear.""" def __init__(self, @@ -425,28 +431,12 @@ def __init__(self, is_tp=is_tp, out_names=out_names) - def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, - head_size: int, head_size_v: int): - """get io features.""" - all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, - num_kv_heads * head_size_v) - return all_out_features - def _update_all_out_features(self, all_out_features: List[int], w_bit: int, - group_size: int, replicate: List[bool]): + group_size: int, + replicate: Optional[List[bool]]): """update all out features.""" return all_out_features - def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update num heads.""" - world_size, rank = _get_world_rank() - num_q_heads = _update_feature_size(num_q_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) - - return num_q_heads, num_kv_heads - class W8A8Linear(nn.Module): """w8a8 linear.""" @@ -495,11 +485,11 @@ def __init__(self, def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """get io features.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if colwise: - out_features = _update_feature_size(out_features, world_size, rank) + out_features = get_distribute_size(out_features, world_size, rank) else: - in_features = _update_feature_size(in_features, world_size, rank) + in_features = get_distribute_size(in_features, world_size, rank) return in_features, out_features def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, @@ -526,7 +516,7 @@ def weight_loader(self, param: torch.nn.Parameter, if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) @@ -586,11 +576,11 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: List[bool] = None, + replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: List[int] = None): + out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) all_out_features = self._update_all_out_features( @@ -621,21 +611,21 @@ def _get_io_features(self, in_features: int, out_features: int, return in_features, out_features def _update_all_out_features(self, all_out_features: List[int], - replicate: List[bool]): + replicate: Optional[List[bool]]): """update all out features.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() new_all_out_features = [] for out_feat, rep in zip(all_out_features, replicate): if rep: new_all_out_features.append(out_feat) - new_out_feat = _update_feature_size(out_feat, world_size, rank) + new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """weight loader.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] if not self.replicate[shard_idx]: @@ -643,7 +633,7 @@ def weight_loader(self, param: torch.nn.Parameter, param_w.copy_(loaded_weight) -class QKVW8A8Linear(MergedW8A8Linear): +class QKVW8A8Linear(MergedW8A8Linear, QKVMixin): """qkv w8a8 linear.""" def __init__(self, @@ -681,28 +671,11 @@ def __init__(self, is_tp=is_tp, out_names=out_names) - def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, - head_size: int, head_size_v: int): - """get io features.""" - all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, - num_kv_heads * head_size_v) - return all_out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: List[bool]): + replicate: Optional[List[bool]]): """update all out features.""" return all_out_features - def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update num heads.""" - world_size, rank = _get_world_rank() - num_q_heads = _update_feature_size(num_q_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) - - return num_q_heads, num_kv_heads - class BaseLinear(nn.Module): """linear layer.""" @@ -747,11 +720,11 @@ def __init__(self, def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """get io features.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if colwise: - out_features = _update_feature_size(out_features, world_size, rank) + out_features = get_distribute_size(out_features, world_size, rank) else: - in_features = _update_feature_size(in_features, world_size, rank) + in_features = get_distribute_size(in_features, world_size, rank) return in_features, out_features def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, @@ -778,7 +751,7 @@ def weight_loader(self, param: torch.nn.Parameter, if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) @@ -830,11 +803,11 @@ def __init__(self, in_features: int, all_out_features: List[int], bias: bool, - replicate: List[bool] = None, + replicate: Optional[List[bool]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: List[int] = None): + out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) all_out_features = self._update_all_out_features( @@ -864,21 +837,21 @@ def _get_io_features(self, in_features: int, out_features: int, return in_features, out_features def _update_all_out_features(self, all_out_features: List[int], - replicate: List[bool]): + replicate: Optional[List[bool]]): """update all out features.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() new_all_out_features = [] for out_feat, rep in zip(all_out_features, replicate): if rep: new_all_out_features.append(out_feat) - new_out_feat = _update_feature_size(out_feat, world_size, rank) + new_out_feat = get_distribute_size(out_feat, world_size, rank) new_all_out_features.append(new_out_feat) return new_all_out_features def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """weight loader.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] if not self.replicate[shard_idx]: @@ -886,7 +859,7 @@ def weight_loader(self, param: torch.nn.Parameter, param_w.copy_(loaded_weight) -class QKVBaseLinear(MergedBaseLinear): +class QKVBaseLinear(MergedBaseLinear, QKVMixin): """qkv base linear.""" def __init__(self, @@ -924,28 +897,11 @@ def __init__(self, is_tp=is_tp, out_names=out_names) - def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, - head_size: int, head_size_v: int): - """get io features.""" - all_out_features = (num_q_heads * head_size, num_kv_heads * head_size, - num_kv_heads * head_size_v) - return all_out_features - def _update_all_out_features(self, all_out_features: List[int], - replicate: List[bool]): + replicate: Optional[List[bool]]): """update all out features.""" return all_out_features - def _update_num_heads(self, num_q_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update num heads.""" - world_size, rank = _get_world_rank() - num_q_heads = _update_feature_size(num_q_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = _update_feature_size(num_kv_heads, world_size, rank) - - return num_q_heads, num_kv_heads - def build_linear(in_features: int, out_features: int, @@ -957,7 +913,7 @@ def build_linear(in_features: int, quant_config: Any = None) -> nn.Module: """build linear.""" if is_tp: - world_size, _ = _get_world_rank() + world_size, _ = get_world_rank() is_tp = world_size > 1 if quant_config is None: @@ -1047,7 +1003,7 @@ def build_merged_colwise_linear( ): """merge linear.""" if is_tp: - world_size, _ = _get_world_rank() + world_size, _ = get_world_rank() is_tp = world_size > 1 if quant_config is None: @@ -1101,7 +1057,7 @@ def build_qkv_proj(in_features: int, is_tp: bool = True): """build qkv proj.""" if is_tp: - world_size, _ = _get_world_rank() + world_size, _ = get_world_rank() is_tp = world_size > 1 if head_size_v is None: diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 48ecb60c8..ffd68ddea 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -6,18 +6,7 @@ from torch import nn from ..backends import LayerType, get_backend - - -def _get_world_rank(): - """get current world size and rank.""" - world_size = 1 - rank = 0 - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - - return world_size, rank +from .utils import get_world_rank class SoftmaxTopK(nn.Module): @@ -46,7 +35,7 @@ def __init__(self, renormalize: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - is_tp: bool = False): + all_reduce: bool = True): super().__init__() if device is None: device = torch.device('cpu') @@ -77,11 +66,14 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - self.is_tp = is_tp + world_size, _ = get_world_rank() + if world_size == 1: + all_reduce = False + self.all_reduce = all_reduce def _update_args(self, hidden_dim: int, ffn_dim: int): """update args.""" - world_size, _ = _get_world_rank() + world_size, _ = get_world_rank() assert ffn_dim % world_size == 0 ffn_dim = ffn_dim // world_size return hidden_dim, ffn_dim @@ -115,7 +107,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """weight loader.""" - world_size, rank = _get_world_rank() + world_size, rank = get_world_rank() if shard_id == 'gate': param_data = param.data[expert_id, :self.ffn_dim] weight = loaded_weight.chunk(world_size, dim=0)[rank] @@ -131,5 +123,8 @@ def weight_loader(self, param: torch.nn.Parameter, def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - return self.impl.forward(hidden_states, topk_weights, topk_ids, - self.gate_up_weights, self.down_weights) + ret = self.impl.forward(hidden_states, topk_weights, topk_ids, + self.gate_up_weights, self.down_weights) + if self.all_reduce: + dist.all_reduce(ret) + return ret diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py new file mode 100644 index 000000000..c60ef835e --- /dev/null +++ b/lmdeploy/pytorch/nn/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch.distributed as dist + + +def div_up(a: int, b: int): + """div up.""" + return (a + b - 1) // b + + +def get_world_rank(): + """get current world size and rank.""" + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + +def get_distribute_size(feature_size: int, + world_size: int, + rank: int, + align: int = 1): + """update feature size.""" + assert feature_size % align == 0 + aligned_size = feature_size // align + align_per_rank = div_up(aligned_size, world_size) + prev_feats = align_per_rank * rank + updated_aligned_size = min(align_per_rank, aligned_size - prev_feats) + return updated_aligned_size * align From eca95ff9ec46c55e0603b99275b1b0c4ac3eeef3 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 28 Aug 2024 11:50:27 +0800 Subject: [PATCH 42/89] fix linear device --- lmdeploy/pytorch/nn/linear.py | 11 +---------- lmdeploy/pytorch/nn/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 232d836b2..12e60bf12 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -31,6 +31,7 @@ def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): class QKVMixin: + """qkv mixin.""" def _get_qkv_out_features(self, num_q_heads: int, num_kv_heads: int, head_size: int, head_size_v: int): @@ -408,8 +409,6 @@ def __init__(self, bias: bool = False, device: Optional[torch.device] = None, is_tp: bool = True): - if device is None: - device = torch.device('cpus') num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, @@ -647,10 +646,6 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True): - if device is None: - device = torch.device('cpus') - if dtype is None: - dtype = torch.float16 num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, @@ -873,10 +868,6 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True): - if device is None: - device = torch.device('cpus') - if dtype is None: - dtype = torch.float16 num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py index c60ef835e..ad2fde818 100644 --- a/lmdeploy/pytorch/nn/utils.py +++ b/lmdeploy/pytorch/nn/utils.py @@ -21,9 +21,9 @@ def get_world_rank(): def get_distribute_size(feature_size: int, - world_size: int, - rank: int, - align: int = 1): + world_size: int, + rank: int, + align: int = 1): """update feature size.""" assert feature_size % align == 0 aligned_size = feature_size // align From d14289e79f0c63678a5589c1e521f594ebd52c43 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 28 Aug 2024 14:20:28 +0800 Subject: [PATCH 43/89] support baichuan 7b 13b --- lmdeploy/pytorch/backends/attention.py | 6 +- lmdeploy/pytorch/backends/cuda/attention.py | 87 ++- lmdeploy/pytorch/models/baichuan.py | 691 +++++++++++--------- lmdeploy/pytorch/models/module_map.py | 14 +- lmdeploy/pytorch/nn/attention.py | 18 +- 5 files changed, 441 insertions(+), 375 deletions(-) diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index e07bed5b9..ef0ac0e9a 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -29,7 +29,7 @@ def __init__( scale: float = None, num_kv_heads: int = None, v_head_size: int = None, - alibi_scale: float = None, + alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, **kwargs, @@ -48,7 +48,7 @@ def __init__( self.scale = scale self.num_kv_heads = num_kv_heads self.v_head_size = v_head_size - self.alibi_scale = alibi_scale + self.alibi = alibi self.sliding_window = sliding_window self.logit_softcapping = logit_softcapping @@ -77,7 +77,7 @@ def build( scale: float = None, num_kv_heads: int = None, v_head_size: int = None, - alibi_scale: float = None, + alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, **kwargs, diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 75f2d3b05..d54b04850 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -1,9 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.distributed as dist from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata +def get_world_rank(): + """get current world size and rank.""" + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + class TritonAttentionMetadata(AttentionMetadata): """triton attention metadata.""" pass @@ -19,27 +32,34 @@ def __init__( scale: float = None, num_kv_heads: int = None, v_head_size: int = None, - alibi_scale: float = None, + alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, **kwargs, ): super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - v_head_size, - alibi_scale, - sliding_window, - logit_softcapping, + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_size=v_head_size, + alibi=alibi, + sliding_window=sliding_window, + logit_softcapping=logit_softcapping, **kwargs, ) - from lmdeploy.pytorch.kernels.cuda import (fill_kv_cache, + from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd, + fill_kv_cache, paged_attention_fwd) self.fill_kv_cache = fill_kv_cache self.paged_attention_fwd = paged_attention_fwd + self.alibi_paged_attention_fwd = alibi_paged_attention_fwd + + # for alibi attention + world_size, rank = get_world_rank() + self.alibi_head_offset = self.num_heads * rank + self.alibi_num_heads = self.num_heads * world_size def forward( self, @@ -79,20 +99,35 @@ def forward( o_shape = q_shape[:-1] + (self.v_head_size, ) attn_output = query.new_empty(o_shape) - self.paged_attention_fwd( - query, - k_cache, - v_cache, - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_seqlens=kv_seqlens, - max_seqlen=max_q_seqlen, - window_size=self.sliding_window, - sm_scale=self.scale, - logit_softcapping=self.logit_softcapping, - ) + if not self.alibi: + self.paged_attention_fwd( + query, + k_cache, + v_cache, + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logit_softcapping, + ) + else: + self.alibi_paged_attention_fwd( + query, + k_cache, + v_cache, + attn_output, + block_offsets, + b_start_loc=q_start_loc, + b_seq_len=q_seqlens, + b_kv_seq_len=kv_seqlens, + max_input_len=max_q_seqlen, + head_offset=self.alibi_head_offset, + num_heads=self.alibi_num_heads, + ) return attn_output @@ -107,7 +142,7 @@ def build( scale: float = None, num_kv_heads: int = None, v_head_size: int = None, - alibi_scale: float = None, + alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, **kwargs, @@ -118,7 +153,7 @@ def build( scale=scale, num_kv_heads=num_kv_heads, v_head_size=v_head_size, - alibi_scale=alibi_scale, + alibi=alibi, sliding_window=sliding_window, logical_softcapping=logical_softcapping, **kwargs) diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 1a7c31952..8d8462e39 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -1,392 +1,433 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union import torch -import torch.distributed as dist from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast -from ..kernels import apply_rotary_pos_emb -from ..kernels.alibi_pagedattention import alibi_paged_attention_fwd -from ..kernels.fill_kv_cache import fill_kv_cache -from ..kernels.pagedattention import paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, - rowwise_parallelize_linear) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class PatchedRMSNorm(nn.Module): - """Rewrite RMSNorm.""" +def _is_baichuan_13b(config: Any): + """is baichuan 13b.""" + return config.num_hidden_layers == 40 - def forward(self, hidden_states): - """forward.""" - from ..kernels import rms_norm - epsilon = getattr(self, 'epsilon', None) - if epsilon is None: - epsilon = getattr(self, 'variance_epsilon', 1e-10) - ret = rms_norm(hidden_states, self.weight, epsilon) - - return ret - - -def _attention_load_weights(mod, loader, rank: int = 0, world_size: int = 1): - """load weights.""" - w_pack_out = mod.W_pack.out_features - sections = [w_pack_out // 3] * 3 - colwise_split_parallelize_linear(mod.W_pack, - sections, - loader, - rank=rank, - world_size=world_size, - prefix='W_pack') - rowwise_parallelize_linear(mod.o_proj, - loader, - rank=rank, - world_size=world_size, - prefix='o_proj') - - -class Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper.""" - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - _attention_load_weights(self, loader, rank=rank, world_size=world_size) - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs +class BaichuanAttention(nn.Module): + """Rewrite module of Attention.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = num_heads + hidden_size = config.hidden_size + head_dim = hidden_size // num_heads + self.is_13b = _is_baichuan_13b(config) + + # packed qkv + self.W_pack = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of Attention.forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward( - hidden_states, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - world_size=world_size, + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + alibi=self.is_13b, ) - def _contiguous_batching_forward( + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, + attn_metadata: Any = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of Attention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - assert not output_attentions - context = self.context.context - max_kv_seq_length = context.max_kv_seq_length - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_heads // world_size - head_dim = self.head_dim - - def _qkv_proj(hidden_states): - """qkv proj.""" - proj = self.W_pack(hidden_states) - return proj.chunk(3, -1) - - def _rotary_emb_fn(query_states, key_states, value_states): - if hasattr(self, 'rotary_emb'): - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids, - context.position_ids_1d) - return query_states, key_states, value_states - - query_states, key_states, value_states = _qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = _rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, - ) + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.W_pack(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.W_pack.split_qkv( + qkv_states) + + # apply rotary embedding + if not self.is_13b: + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + attn_metadata, + inplace=True, ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) - hidden_size = num_heads * head_dim - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) - + # o proj attn_output = self.o_proj(attn_output) + return attn_output + + +class MLP(nn.Module): + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) - return attn_output, None, past_key_value - - -class BaichuanAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper.""" + # silu and mul + self.act_fn = SiluAndMul(inplace=True) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - _attention_load_weights(self, loader, rank=rank, world_size=world_size) + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class DecoderLayer(nn.Module): + """Baichuan decoder layer.""" + + def __init__(self, + config: Any, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = BaichuanAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of BaichuanAttention.forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward( - hidden_states, - past_key_value=past_key_value, - output_attentions=output_attentions, - world_size=world_size, - ) + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """forward.""" + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) - def _contiguous_batching_forward( - self, - hidden_states: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of BaichuanAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - assert not output_attentions - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_heads // world_size - head_dim = self.head_dim - - def _qkv_proj(hidden_states): - proj = self.W_pack(hidden_states) - return proj.chunk(3, -1) - - query_states, key_states, value_states = _qkv_proj(hidden_states) - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) - attn_output = query_states - - num_heads_full = num_heads - head_offset = 0 - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - num_heads_full = num_heads * world_size - head_offset = num_heads * rank - alibi_paged_attention_fwd(query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - b_start_loc=q_start_loc, - b_seq_len=q_seq_length, - b_kv_seq_len=kv_seq_length, - max_input_len=max_q_seq_length, - head_offset=head_offset, - num_heads=num_heads_full) - - hidden_size = num_heads * head_dim - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) - - attn_output = self.o_proj(attn_output) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) - return attn_output, None, past_key_value + outputs = (hidden_states, residual) + return outputs class BaichuanModel(nn.Module): + """Baichuan model.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + DecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + self.is_13b = _is_baichuan_13b(config) + if not self.is_13b: + # build rotary embedding in LlamaModel + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = 10000 + scaling_factor = 1.0 + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) - def _continuous_batching_forward_7b( + def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of 7b BaichuanModel.forward.""" - output_attentions = False - use_cache = True + """forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # Attention mask is not necessary in continuous batching - attention_mask = None - hidden_states = inputs_embeds - # decoder layers + # rotary embedding + rotary_pos_emb = (None, None) + if not self.is_13b: + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class BaichuanForCausalLM(nn.Module): + """rewrote model of LlamaForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: Any, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build BaichuanModel + self.model = BaichuanModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, past_key_values=past_key_values, - hidden_states=None, - attentions=None, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) - def _continuous_batching_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of BaichuanModel.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - use_cache = False - output_attentions = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # Attention mask is not necessary in continuous batching - attention_mask = None - hidden_states = inputs_embeds - - # decoder layers - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits - hidden_states = layer_outputs[0] + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() - hidden_states = self.norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, past_key_values=past_key_values, - hidden_states=None, - attentions=None, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - return_dict: Optional[bool] = True, - ): - """Rewrite of BaichuanModel.forward.""" - if position_ids is not None: - return self._continuous_batching_forward_7b( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - ) - else: - return self._continuous_batching_forward( - input_ids, - attention_mask, - past_key_values, - inputs_embeds, - ) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.W_pack' in name: + q, k, v = loaded_weight.chunk(3, 0) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + elif 'lm_head' in name: + loaded_weight = nn.functional.normalize(loaded_weight) + param = params_dict[name] + load_weight(param, loaded_weight) + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 169974b7e..70b971278 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -31,18 +31,8 @@ # baichuan MODULE_MAP.update({ - 'modeling_baichuan.Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', # noqa - 'modeling_baichuan.BaichuanModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanModel', # noqa - 'modeling_baichuan.Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.Attention', # noqa - 'modeling_baichuan.BaichuanAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanAttention', # noqa - 'modeling_baichuan.MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', # noqa - 'modeling_baichuan.RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.PatchedRMSNorm', + 'BaichuanForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanForCausalLM', }) # chatglm2 diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index a73496824..c290bc358 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -17,7 +17,7 @@ def __init__( scale: float = None, num_kv_heads: int = None, v_head_size: int = None, - alibi_scale: float = None, + alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, replicate_kv: bool = False, @@ -32,14 +32,14 @@ def __init__( LayerType.Attention) self.impl = impl_builder.build( - num_heads, - head_size, - scale, - num_kv_heads, - v_head_size, - alibi_scale, - sliding_window, - logit_softcapping, + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_size=v_head_size, + alibi=alibi, + sliding_window=sliding_window, + logit_softcapping=logit_softcapping, **kwargs, ) From 8424a21ffaae97247ebf5266deed2e8b8362b859 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 28 Aug 2024 21:11:31 +0800 Subject: [PATCH 44/89] support deepseekv2 no-tp --- .../backends/default/rotary_embedding.py | 123 +- lmdeploy/pytorch/backends/rotary_embedding.py | 11 + lmdeploy/pytorch/kernels/cuda/rms_norm.py | 2 + lmdeploy/pytorch/models/awq_modules.py | 67 -- lmdeploy/pytorch/models/baichuan.py | 10 +- lmdeploy/pytorch/models/deepseek_v2.py | 1041 +++++++++++------ lmdeploy/pytorch/models/llama.py | 10 +- lmdeploy/pytorch/models/mixtral.py | 12 +- lmdeploy/pytorch/models/module_map.py | 14 +- lmdeploy/pytorch/nn/__init__.py | 1 + lmdeploy/pytorch/nn/rotary_embedding.py | 4 +- 11 files changed, 865 insertions(+), 430 deletions(-) delete mode 100644 lmdeploy/pytorch/models/awq_modules.py diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 6625ff2b8..4e7e0ba1c 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -5,13 +5,20 @@ from torch import nn from ..rotary_embedding import (EmbeddingType, RotaryEmbeddingBuilder, - RotaryEmbeddingImpl) + RotaryEmbeddingImpl, YarnParameters) -def _rotary_embedding_fwd(position_ids: torch.Tensor, inv_freq: torch.Tensor, - scaling_factor: float, dtype: torch.dtype, - device_type: torch.device): +def _rotary_embedding_fwd(position_ids: torch.Tensor, + inv_freq: torch.Tensor, + scaling_factor: float, + mscale: float = 1.0, + dtype: torch.dtype = None, + device_type: torch.device = None): """rotary embedding forward.""" + if dtype is None: + dtype = torch.float16 + if device_type is None: + device_type = 'cuda' position_ids = position_ids.float() / scaling_factor inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], @@ -27,6 +34,11 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor, inv_freq: torch.Tensor, emb = freqs.repeat(1, 1, 2) cos = emb.cos() sin = emb.sin() + + if mscale != 1.0: + cos = cos * mscale + sin = sin * mscale + return cos.to(dtype=dtype), sin.to(dtype=dtype) @@ -139,6 +151,102 @@ def __init__( self.register_buffer('inv_freq_llama', inv_freq_llama) +def yarn_find_correction_dim(num_rotations, + dim, + base=10000, + max_position_embeddings=2048): + """yarn_find_correction_dim.""" + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range(low_rot, + high_rot, + dim, + base=10000, + max_position_embeddings=2048): + """yarn_find_correction_range.""" + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + """yarn_get_mscale.""" + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + """yarn_linear_ramp_mask.""" + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class YarnRotaryEmbeddingImpl(RotaryEmbeddingImpl): + """yarn rotary embedding implementation.""" + + def __init__(self, + dim: int, + base: int = 10000, + scaling_factor: float = 1.0, + original_max_position_embeddings: int = 4096, + yarn_params: YarnParameters = None): + super().__init__(dim, base, scaling_factor) + self.original_max_position_embeddings = \ + original_max_position_embeddings + assert yarn_params is not None + self.beta_fast = yarn_params.beta_fast + self.beta_slow = yarn_params.beta_slow + self.mscale = yarn_params.mscale + self.mscale_all_dim = yarn_params.mscale_all_dim + + # get inv_freq + freq_extra = 1.0 / (self.base**( + torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base**( + torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask( + low, high, dim // 2).to(dtype=torch.float32) + inv_freq = freq_inter * (1 - + inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer('inv_freq', inv_freq, persistent=False) + + # get mscale + self.mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) / + yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + """forward.""" + device_type = x.device.type + dtype = x.dtype + if self.inv_freq.device != x.device: + self.inv_freq = self.inv_freq.to(x.device) + return _rotary_embedding_fwd(position_ids, + self.inv_freq, + scaling_factor=1.0, + mscale=self.mscale, + dtype=dtype, + device_type=device_type) + + class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): """rotary embedding builder.""" @@ -150,6 +258,7 @@ def build( scaling_factor: float = 1.0, low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, + yarn_params: YarnParameters = None, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" @@ -162,6 +271,12 @@ def build( return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, low_freq_factor, high_freq_factor, max_position_embeddings) + elif emb_type == EmbeddingType.Yarn: + return YarnRotaryEmbeddingImpl(dim, + base, + scaling_factor, + max_position_embeddings, + yarn_params=yarn_params) else: raise NotImplementedError( f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 1a2cc919f..2dfae1e78 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum, auto @@ -9,6 +10,16 @@ class EmbeddingType(Enum): LinearScaling = auto() DynamicNTKScaling = auto() Llama3 = auto() + Yarn = auto() + + +@dataclass +class YarnParameters: + """Yarn parameters.""" + beta_fast: int = 32 + beta_slow: float = 1 + mscale: int = 1 + mscale_all_dim: int = 0 class RotaryEmbeddingImpl(ABC): diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py index 0a171f762..f0cea6fe7 100644 --- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py +++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py @@ -90,6 +90,8 @@ def rms_norm(hidden_states: Tensor, out: Tensor = None, out_residual: Tensor = None): """rms norm.""" + if not hidden_states.is_contiguous(): + hidden_states = hidden_states.contiguous() feat_size = weight.shape[0] seq_len = hidden_states.numel() // hidden_states.size(-1) diff --git a/lmdeploy/pytorch/models/awq_modules.py b/lmdeploy/pytorch/models/awq_modules.py deleted file mode 100644 index 19f800bc9..000000000 --- a/lmdeploy/pytorch/models/awq_modules.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch import nn - - -def wq_gemm_forward( - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, -): - """wq gemm forward.""" - from awq.modules.linear.gemm import AWQ_INSTALLED, dequantize_gemm - out_shape = x.shape[:-1] + (out_features, ) - input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() - - if AWQ_INSTALLED: - from awq.modules.linear.gemm import awq_ext - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - # TODO: remove event wait if awq kernel set stream - default_stream = torch.cuda.default_stream() - event_def = torch.cuda.Event() - event_def.record() - event_def.wait(default_stream) - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, - 0, 0, False) - event_def = torch.cuda.Event() - event_def.record(default_stream) - event_def.wait() - out = torch.matmul(x, out) - else: - x = x.flatten(0, -2) - if not x.is_contiguous(): - x = x.contiguous() - out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8) - else: - out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) - out = torch.matmul(x, out) - - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - - # always want 3D tensor if tensor is 2D - if len(out.shape) == 2: - out = out.unsqueeze(0) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - return out - - -class PatchedWQLinear_GEMM(nn.Module): - - def forward(self, x): - """forward.""" - out_features = self.scales.size(1) - return wq_gemm_forward(x, self.qweight, self.qzeros, self.scales, - self.w_bit, self.group_size, self.bias, - out_features) diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 8d8462e39..575a82c83 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, @@ -72,8 +71,7 @@ def forward( rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, attn_metadata: Any = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ): """Rewrite of LlamaAttention.forward.""" # qkv proj qkv_states = self.W_pack(hidden_states) @@ -188,7 +186,7 @@ def forward( past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + ): """forward.""" if residual is None: residual = hidden_states @@ -268,7 +266,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ): """forward.""" # token embedding diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 69f87329c..1b3cb2559 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1,373 +1,760 @@ # Copyright (c) OpenMMLab. All rights reserved. -import gc -from typing import Any, Optional +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torch.distributed as dist from torch import nn -from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.nn.rotary_embedding import YarnParameters +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def get_world_rank(): + """get current world size and rank.""" + import torch.distributed as dist + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + +class DeepseekV2BMM(nn.Module): + """wrapped bmm.""" + + def __init__(self, batch: int, in_features: int, out_features: int, + dtype: torch.dtype, device: torch.device): + super().__init__() + batch = self._update_batch(batch) + + weight = self.create_weight(batch, + in_features, + out_features, + dtype=dtype, + device=device) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.register_parameter('weight', weight) + weight.weight_loader = self.weight_loader + + self.batch = batch + self.in_features = in_features + self.out_features = out_features + self.dtype = dtype + self.device = device + + def _update_batch(self, batch: int): + """update out features.""" + world_size, _ = get_world_rank() + batch = batch // world_size + return batch + + def create_weight(self, batch: int, in_features: int, out_features: int, + dtype: torch.dtype, device: torch.device): + """create weight.""" + return torch.empty((batch, in_features, out_features), + dtype=dtype, + device=device) + + def weight_loader(self, param: nn.Parameter, weight: torch.Tensor): + """weight loader.""" + world_size, rank = get_world_rank() + weight = weight.chunk(world_size, 0)[rank] + param.data.copy_(weight) + + def forward(self, x: torch.Tensor, output: torch.Tensor): + """forward.""" + torch.bmm(x.transpose(0, 1), self.weight, out=output.transpose(0, 1)) + + +class DeepseekV2Attention(nn.Module): + """deepseekv2 attention.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = None + self.q_lora_rank = config.q_lora_rank + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim -from ..kernels import (apply_rotary_pos_emb, fill_kv_cache, - paged_attention_fwd, rms_norm) -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) + if self.q_lora_rank is None: + self.q_proj = build_colwise_linear( + self.hidden_size, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=True, + ) + else: + self.q_a_proj = build_colwise_linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.q_b_proj = build_colwise_linear( + config.q_lora_rank, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=True, + ) + + self.kv_a_proj_with_mqa = build_colwise_linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.kc = DeepseekV2BMM(self.num_heads, + config.qk_nope_head_dim, + config.kv_lora_rank, + dtype=dtype, + device=device) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + self.softmax_scale = self.q_head_dim**(-0.5) + + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0) + scaling_factor = config.rope_scaling['factor'] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_fwd = Attention( + self.num_heads, + config.kv_lora_rank + self.qk_rope_head_dim, + scale=self.softmax_scale, + num_kv_heads=1, + v_head_size=config.kv_lora_rank, + replicate_kv=True, + ) + self.vc = DeepseekV2BMM(self.num_heads, + config.kv_lora_rank, + self.v_head_dim, + dtype=dtype, + device=device) + self.o_proj = build_rowwise_linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=True, + ) -class PatchedDeepseekV2Attention(nn.Module): + def _q_proj(self, hidden_states, num_heads: int, nope_size: int, + pe_size: int): + """q proj.""" + q_len = hidden_states.size(1) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" + query_states = hidden_states.new_empty(q_len, num_heads, + nope_size + pe_size) - mod_name_list = ['kv_a_proj_with_mqa', 'kv_a_layernorm'] - if self.q_lora_rank is not None: - mod_name_list += ['q_a_proj', 'q_a_layernorm'] - for mod_name in mod_name_list: - with loader.prefix_context(mod_name): - loader.load_model_weights(getattr(self, mod_name), - rank=rank, - world_size=world_size, - device=device) - - mod_name_list = ['q_proj', 'kv_b_proj'] - if self.q_lora_rank is not None: - mod_name_list = ['q_b_proj', 'kv_b_proj'] - for mod_name in mod_name_list: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - def _update_model_fn(self): - """update model.""" - qk_nope_head_dim = self.qk_nope_head_dim - v_head_dim = self.v_head_dim - - def __update_pe(mod, head_dim, pe_dim_offset): - weight = mod.weight.data - # (num_heads, q_head_dim, input_dim) - weight = weight.unflatten(0, (-1, head_dim)) - # (num_heads, nope_head_dim, input_dim) - w_pe = weight[:, pe_dim_offset:] - # (num_heads, nope_head_dim//2, 2, input_dim) - new_w_pe = w_pe.unflatten(1, (-1, 2)) - # (num_heads, nope_head_dim, input_dim) - new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2) - weight[:, pe_dim_offset:] = new_w_pe - - # prevent shuffle before apply rotary embedding if self.q_lora_rank is None: - __update_pe(self.q_proj, self.q_head_dim, qk_nope_head_dim) + q = self.q_proj(hidden_states) else: - __update_pe(self.q_b_proj, self.q_head_dim, qk_nope_head_dim) - kv_dim = self.kv_lora_rank + self.qk_rope_head_dim - __update_pe(self.kv_a_proj_with_mqa, kv_dim, self.kv_lora_rank) - - kv_b_proj = self.kv_b_proj - w_kc, w_vc = kv_b_proj.weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim)).split( - [qk_nope_head_dim, v_head_dim], dim=1) - - self.register_parameter('w_kc', - torch.nn.Parameter(w_kc, requires_grad=False)) - w_vc = w_vc.transpose(1, 2).contiguous() - self.register_parameter('w_vc', - torch.nn.Parameter(w_vc, requires_grad=False)) - - delattr(self, 'kv_b_proj') - gc.collect() - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(q_len, num_heads, self.q_head_dim) + # q_pe: (q_len, num_heads, qk_rope_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # q_nope: (q_len, num_heads, kv_lora_rank) + q_nope_out = query_states[..., :nope_size] + self.kc(q_nope, q_nope_out) + return query_states, q_pe + + def _kv_proj(self, hidden_states, nope_size: int): + """kv proj.""" + # (q_len, 1, nope_size + pe_size) + key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None]) + # (q_len, 1, pe_size) + k_pe = key_states[..., nope_size:] + # kv_a_layernorm + value_states = key_states[..., :nope_size] + value_states = self.kv_a_layernorm(value_states) + key_states[..., :nope_size] = value_states + return key_states, value_states, k_pe + + def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int): + """qkv proj.""" + nope_size = self.kv_lora_rank + pe_size = self.qk_rope_head_dim + query_states, q_pe = self._q_proj(hidden_states, num_heads, nope_size, + pe_size) + key_states, value_states, k_pe = self._kv_proj(hidden_states, + nope_size) + + return query_states, key_states, value_states, q_pe, k_pe - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Any] = None, - world_size: int = 1, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, ): - """forward impl.""" - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length + """Rewrite of LlamaAttention.forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() num_heads = self.num_heads // world_size - q_len = hidden_states.size(1) - - def __q_proj(hidden_states, nope_size: int, pe_size: int): - """q proj.""" - query_states = hidden_states.new_empty(q_len, num_heads, - nope_size + pe_size) - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj( - self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(q_len, num_heads, self.q_head_dim) - # q_pe: (q_len, num_heads, qk_rope_head_dim) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # q_nope: (q_len, num_heads, kv_lora_rank) - q_nope_out = query_states[..., :nope_size] - torch.bmm(q_nope.transpose(0, 1), - self.w_kc, - out=q_nope_out.transpose(0, 1)) - return query_states, q_pe - - def __kv_proj(hidden_states, nope_size: int): - """kv proj.""" - # (q_len, 1, nope_size + pe_size) - key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None]) - # (q_len, 1, nope_size) - value_states = key_states[..., :nope_size] - # (q_len, 1, pe_size) - k_pe = key_states[..., nope_size:] - # inplace kv_a_layernorm - rms_norm(value_states, - weight=self.kv_a_layernorm.weight, - eps=self.kv_a_layernorm.variance_epsilon, - out=value_states) - return key_states, value_states, k_pe - - def __qkv_proj(hidden_states): - """qkv proj.""" - nope_size = self.kv_lora_rank - pe_size = self.qk_rope_head_dim - query_states, q_pe = __q_proj(hidden_states, nope_size, pe_size) - key_states, value_states, k_pe = __kv_proj(hidden_states, - nope_size) - - return query_states, key_states, value_states, q_pe, k_pe - - def __rotary_emb_fn(q_pe, k_pe, out_q_pe, out_k_pe): - """rope.""" - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb(q_pe, seq_len=max_kv_seq_length) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - - apply_rotary_pos_emb(q_pe, - k_pe, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=out_q_pe, - k_embed=out_k_pe) - return out_q_pe, out_k_pe - - query_states, key_states, value_states, q_pe, k_pe = __qkv_proj( - hidden_states) nope_size = self.kv_lora_rank - __rotary_emb_fn(q_pe, k_pe, query_states[..., nope_size:], - key_states[..., nope_size:]) + q_len = hidden_states.size(1) - fill_kv_cache( - key_states, - value_states[..., :0], - past_key_value[0], - past_key_value[0][..., :0], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + # qkv_proj + query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj( + hidden_states, num_heads=num_heads) + + cos, sin = rotary_pos_emb + q_pe, k_pe = self.apply_rotary_pos_emb( + q_pe, + k_pe, + cos, + sin, + inplace=False, ) + query_states[..., nope_size:] = q_pe + key_states[..., nope_size:] = k_pe - attn_output = query_states[..., :nope_size] - block_size = past_key_value[0].size(1) - shared_kv = block_size >= 64 - paged_attention_fwd( + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[0][..., :nope_size], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - sm_scale=self.softmax_scale, - shared_kv=shared_kv, + attn_metadata, + inplace=True, ) - - # (num_heads, q_len, v_head_dim) attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim) - torch.bmm(attn_output.transpose(0, 1), - self.w_vc, - out=attn_bmm_out.transpose(0, 1)) - # (1, q_len, o_proj_input) - attn_output = attn_bmm_out.flatten(-2, -1)[None] + self.vc(attn_output, attn_bmm_out) + attn_output = attn_bmm_out.flatten(-2, -1)[None] attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + + return attn_output + + +class DeepseekV2MoE(nn.Module): + """Deepseek v2 MoE.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.renormalize = self.top_k > 1 and self.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + self.experts = FusedMoE( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + all_reduce=False, + ) + + self.shared_experts = None + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + config=config, + intermediate_size=intermediate_size, + dtype=dtype, + device=device, + is_tp=True, + ) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + topk_weights, topk_ids = self.softmax_topk(router_logits) + if not self.renormalize: + topk_weights = topk_weights * self.routed_scaling_factor + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + if self.shared_experts is not None: + shared_states = self.shared_experts(hidden_states) + out_states += shared_states + out_states = out_states.reshape(batch_size, sequence_length, -1) + + return out_states + + +class DeepseekV2MLP(nn.Module): + """Deepseek v2 mlp.""" + + def __init__(self, + config: Any, + intermediate_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + is_tp: bool = True): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + if intermediate_size is None: + intermediate_size = config.intermediate_size + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=is_tp, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=is_tp) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class DeepseekV2DecoderLayer(nn.Module): + """Deepseekv2 decoder layer.""" + + def __init__(self, + config: Any, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = None + + # build attention layer + self.self_attn = DeepseekV2Attention(config, + dtype=dtype, + device=device) + + # mlp + self.mlp = (DeepseekV2MoE(config, dtype=dtype, device=device) if + (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0) else + DeepseekV2MLP(config, dtype=dtype, device=device)) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class DeepseekV2Model(nn.Module): + """mixtral model.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + self.layers = nn.ModuleList([ + DeepseekV2DecoderLayer(config, + layer_idx, + dtype=dtype, + device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=None, + dtype=dtype, + device=device) + + emb_type = EmbeddingType.LinearScaling + rope_dim = config.qk_rope_head_dim + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + scaling_factor = 1.0 + other_params = dict() + if config.rope_scaling is not None: + scaling_type = config.rope_scaling['type'] + scaling_factor = config.rope_scaling['factor'] + if scaling_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + elif scaling_type == 'yarn': + emb_type = EmbeddingType.Yarn + rope_max_pos_emb = config.rope_scaling.get( + 'original_max_position_embeddings', 4096) + kwargs = { + key: config.rope_scaling[key] + for key in [ + 'beta_fast', + 'beta_slow', + 'mscale', + 'mscale_all_dim', + ] if key in self.config.rope_scaling + } + yarn_params = YarnParameters(**kwargs) + other_params['yarn_params'] = yarn_params + self.rotary_emb = build_rotary_embedding(rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + **other_params) + + def forward( + self, + input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Any] = None, - output_attentions: bool = False, - use_cache: bool = False, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """forward.""" + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + residual = None + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + for idx, decoder_layer in enumerate(self.layers): + + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class DeepseekV2ForCausalLM(nn.Module): + """mixture model for causalLM.""" + + support_cuda_graph = True + + def __init__(self, + config: Any, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + self.model = DeepseekV2Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, ): - """rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, + hidden_states = self.model( + input_ids=input_ids, position_ids=position_ids, - past_key_value=past_key_value, - world_size=world_size, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits -def _div_up(a, b): - """div up.""" - return (a + b - 1) // b + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata -class PatchedDeepseekV2MoE(nn.Module): + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, + params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """load weight experts.""" + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, + loaded_weight, + expert_id=expert_id, + shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) - def __load_mlp(exp_id, exp): - """load mlp.""" - with loader.prefix_context(f'experts.{exp_id}'): - loader.load_model_weights( - exp, - rank=rank, - world_size=world_size, - device=device, - load_only=True, - ) - - def __drop_mlp(exp_id, exp): - """drop mlp.""" - for name, _ in exp.named_parameters(recurse=True): - loader.pop(f'experts.{exp_id}.{name}') - - num_experts = len(self.experts) - exp_per_rank = _div_up(num_experts, world_size) - first_exp = rank * exp_per_rank - last_exp = min(num_experts, first_exp + exp_per_rank) - for exp_id, exp in enumerate(self.experts): - if first_exp <= exp_id < last_exp: - __load_mlp(exp_id, exp) - else: - __drop_mlp(exp_id, exp) - self.experts = self.experts[first_exp:last_exp] - with loader.prefix_context('gate'): - loader.load_model_weights(self.gate, - rank=rank, - world_size=world_size, - device=device) + def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, + params_dict: Dict[str, nn.Parameter], + update_pe_mapping: List): + """load weight attention.""" - if self.config.n_shared_experts is not None: - with loader.prefix_context('shared_experts'): - loader.load_model_weights(self.shared_experts, - rank=rank, - world_size=world_size, - device=device) - - def _update_model_fn(self): - """update model.""" - num_experts = len(self.experts) - - def __get_meta(): - exp = self.experts[0] - ffn_dim = exp.gate_proj.weight.size(0) - hidden_dim = exp.down_proj.weight.size(0) - dtype = exp.gate_proj.weight.dtype - device = exp.gate_proj.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - - for exp_id, exp in enumerate(self.experts): - __copy_assign_param(exp.gate_proj.weight, - gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(exp.up_proj.weight, gate_up_weights[exp_id, - ffn_dim:]) - __copy_assign_param(exp.down_proj.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs + def __update_pe(weight, head_dim: int, pe_dim_offset: int): + # (num_heads, q_head_dim, input_dim) + weight = weight.unflatten(0, (-1, head_dim)) + # (num_heads, nope_head_dim, input_dim) + w_pe = weight[:, pe_dim_offset:] + # (num_heads, nope_head_dim//2, 2, input_dim) + new_w_pe = w_pe.unflatten(1, (-1, 2)) + # (num_heads, nope_head_dim, input_dim) + new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2) + weight[:, pe_dim_offset:] = new_w_pe + weight = weight.flatten(0, 1) + return weight + + for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping: + if mod_name not in name: + continue + weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) + param = params_dict[name] + load_weight(param, weight) + break + else: + if '.kv_b_proj' in name: + config = self.config + v_head_dim = config.v_head_dim + qk_nope_head_dim = config.qk_nope_head_dim + w_kc, w_vc = loaded_weight.unflatten( + 0, (-1, qk_nope_head_dim + v_head_dim)).split( + [qk_nope_head_dim, v_head_dim], dim=1) + w_vc = w_vc.transpose(1, 2).contiguous() + kc_param_name = name.replace('.kv_b_proj', '.kc') + param_kc = params_dict[kc_param_name] + load_weight(param_kc, w_kc) + vc_param_name = name.replace('.kv_b_proj', '.vc') + param_vc = params_dict[vc_param_name] + load_weight(param_vc, w_vc) + else: + param = params_dict[name] + load_weight(param, loaded_weight) - def forward(self, hidden_states): - identity = hidden_states - orig_shape = hidden_states.shape - topk_idx, topk_weight, _ = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - y = self.moe_infer(hidden_states, topk_idx, - topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts.forward(identity) - return y - - def moe_infer(self, x, topk_ids, topk_weight): - """moe infer.""" - world_size = 1 - rank = 0 - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - exp_per_rank = self.gate_up_weights.size(0) - expert_offset = rank * exp_per_rank - ret = fused_moe(x, - self.gate_up_weights, - self.down_weights, - topk_weight, - topk_ids, - topk=self.num_experts_per_tok, - expert_offset=expert_offset, - num_experts=world_size * exp_per_rank, - renormalize=False) - return ret + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + config = self.config + qk_rope_head_dim = config.qk_rope_head_dim + kv_lora_rank = config.kv_lora_rank + qk_nope_head_dim = config.qk_nope_head_dim + q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + kv_dim = kv_lora_rank + qk_rope_head_dim + update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), + ('q_b_proj', q_head_dim, qk_nope_head_dim), + ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)] + + num_experts = self.config.n_routed_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.gate_proj.weight', exp_id, + 'gate') + up_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') + down_param = ('.experts.down_weights', + f'.experts.{exp_id}.down_proj.weight', exp_id, + 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + if '.experts' in name: + self._load_weight_experts( + name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) + elif '.self_attn' in name: + # attention + self._load_weight_attention(name, loaded_weight, params_dict, + update_pe_mapping) + else: + # other + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index bbc1c9f4e..4d122bf9b 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama import LlamaConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager @@ -66,8 +65,7 @@ def forward( rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, attn_metadata: Any = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ): """Rewrite of LlamaAttention.forward.""" # qkv proj qkv_states = self.qkv_proj(hidden_states) @@ -180,7 +178,7 @@ def forward( past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + ): if residual is None: residual = hidden_states @@ -277,7 +275,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + ): """Rewrite of LlamaModel.forward.""" # token embedding diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 138c42ef6..365731019 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, @@ -67,8 +66,7 @@ def forward( rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, attn_metadata: Any = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: + ): """Rewrite of LlamaAttention.forward.""" qkv_states = self.qkv_proj(hidden_states) # (-1, heads, head_dim) @@ -192,7 +190,7 @@ def forward( past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + ): if residual is None: residual = hidden_states @@ -265,8 +263,8 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" + ): + """forward.""" if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 70b971278..47cdf50be 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -253,18 +253,8 @@ # deepseek-v2 MODULE_MAP.update({ - 'modeling_deepseek.DeepseekV2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.PatchedDeepseekV2Attention', - 'modeling_deepseek.DeepseekV2FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.PatchedDeepseekV2Attention', - 'modeling_deepseek.DeepseekV2Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'modeling_deepseek.DeepseekV2MoE': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.PatchedDeepseekV2MoE', - 'modeling_deepseek.DeepseekV2RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', - 'modeling_deepseek.DeepseekV2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', + 'DeepseekV2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' }) # cogvlm diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 4e04b0627..8303c27a6 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -6,4 +6,5 @@ from .norm import RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import EmbeddingType # noqa: F401 +from .rotary_embedding import YarnParameters # noqa: F401 from .rotary_embedding import build_rotary_embedding # noqa: F401 diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 66126cb99..c39eb21a5 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -2,7 +2,7 @@ from torch import Tensor, nn from ..backends import LayerType, get_backend -from ..backends.rotary_embedding import EmbeddingType +from ..backends.rotary_embedding import EmbeddingType, YarnParameters def build_rotary_embedding( @@ -12,6 +12,7 @@ def build_rotary_embedding( scaling_factor: float = 1.0, low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, + yarn_params: YarnParameters = None, emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: """build rotary embedding op.""" backend = get_backend() @@ -23,6 +24,7 @@ def build_rotary_embedding( scaling_factor, low_freq_factor=low_freq_factor, high_freq_factor=high_freq_factor, + yarn_params=yarn_params, emb_type=emb_type) From 6395a009fa9e28e0453bbcf3de300c658fd5bf8c Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 10:57:17 +0800 Subject: [PATCH 45/89] support deepseek v2 tp --- lmdeploy/pytorch/models/deepseek_v2.py | 29 ++++-- lmdeploy/pytorch/nn/linear.py | 98 ++++++++++++------- .../weight_loader/model_weight_loader.py | 12 +++ 3 files changed, 93 insertions(+), 46 deletions(-) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 1b3cb2559..b46cbd3d9 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -326,7 +326,13 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + all_reduce=False, ) + world_size, _ = get_world_rank() + if world_size > 1: + self._all_reduce = True + else: + self._all_reduce = False def forward(self, hidden_states: torch.Tensor): """forward.""" @@ -348,6 +354,9 @@ def forward(self, hidden_states: torch.Tensor): out_states += shared_states out_states = out_states.reshape(batch_size, sequence_length, -1) + if self._all_reduce: + dist.all_reduce(out_states) + return out_states @@ -359,7 +368,8 @@ def __init__(self, intermediate_size: int = None, dtype: torch.dtype = None, device: torch.device = None, - is_tp: bool = True): + is_tp: bool = True, + all_reduce: bool = True): super().__init__() quantization_config = getattr(config, 'quantization_config', None) # gate up @@ -379,13 +389,16 @@ def __init__(self, self.act_fn = SiluAndMul(inplace=True) # down - self.down_proj = build_rowwise_linear(intermediate_size, - config.hidden_size, - bias=False, - quant_config=quantization_config, - dtype=dtype, - device=device, - is_tp=is_tp) + self.down_proj = build_rowwise_linear( + intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=is_tp, + all_reduce=all_reduce, + ) def forward(self, x): """forward.""" diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 12e60bf12..9ccd0ef2a 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -112,15 +112,18 @@ def forward(self, x, base_output=None): class AwqLinear(nn.Module): """w4a16 linear.""" - def __init__(self, - in_features: int, - out_features: int, - w_bit: int, - group_size: int, - bias: bool, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False): + def __init__( + self, + in_features: int, + out_features: int, + w_bit: int, + group_size: int, + bias: bool, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + ): super().__init__() if device is None: device = torch.device('cpu') @@ -164,6 +167,7 @@ def __init__(self, self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + self.all_reduce = all_reduce def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool): @@ -283,17 +287,18 @@ def update_weights(self): def forward(self, x): """w4a16 forward.""" - is_tp = False if self.colwise else self.is_tp + all_reduce = False if self.colwise else self.is_tp + all_reduce = all_reduce and self.all_reduce if self.lora_adapters is None: return self.impl.forward(x, self.qweight, self.scales, self.qzeros, - self.bias, is_tp) + self.bias, all_reduce) out = self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, False) if self.lora_adapters is not None: for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) - if is_tp: + if all_reduce: dist.all_reduce(out) return out @@ -440,14 +445,17 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, class W8A8Linear(nn.Module): """w8a8 linear.""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + ): super().__init__() if device is None: device = torch.device('cpu') @@ -480,6 +488,7 @@ def __init__(self, self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + self.all_reduce = all_reduce def _get_io_features(self, in_features: int, out_features: int, colwise: bool): @@ -555,15 +564,16 @@ def update_weights(self): def forward(self, x): """forward of w8a8 linear.""" - is_tp = False if self.colwise else self.is_tp + all_reduce = False if self.colwise else self.is_tp + all_reduce = all_reduce and self.all_reduce if len(self.lora_adapters) == 0: return self.impl.forward(x, self.weight, self.scale, self.bias, - is_tp) + all_reduce) out = self.impl.forward(x, self.weight, self.scale, self.bias, False) for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) - if is_tp: + if all_reduce: dist.all_reduce(out) return out @@ -675,14 +685,17 @@ def _update_all_out_features(self, all_out_features: List[int], class BaseLinear(nn.Module): """linear layer.""" - def __init__(self, - in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + ): super().__init__() if device is None: device = torch.device('cpu') @@ -711,6 +724,7 @@ def __init__(self, self.lora_adapters = [] self.is_tp = is_tp self.colwise = colwise + self.all_reduce = all_reduce def _get_io_features(self, in_features: int, out_features: int, colwise: bool): @@ -779,14 +793,15 @@ def update_weights(self): def forward(self, x): """forward of linear layer.""" - is_tp = False if self.colwise else self.is_tp + all_reduce = False if self.colwise else self.is_tp + all_reduce = all_reduce and self.all_reduce if len(self.lora_adapters) == 0: - return self.impl.forward(x, self.weight, self.bias, is_tp) + return self.impl.forward(x, self.weight, self.bias, all_reduce) out = self.impl.forward(x, self.weight, self.bias, False) for lora_adapter in self.lora_adapters: out = lora_adapter(x, out) - if is_tp: + if all_reduce: dist.all_reduce(out) return out @@ -901,7 +916,8 @@ def build_linear(in_features: int, device: Optional[torch.device] = None, colwise: bool = True, is_tp: bool = False, - quant_config: Any = None) -> nn.Module: + quant_config: Any = None, + all_reduce: bool = True) -> nn.Module: """build linear.""" if is_tp: world_size, _ = get_world_rank() @@ -916,6 +932,7 @@ def build_linear(in_features: int, device=device, colwise=colwise, is_tp=is_tp, + all_reduce=all_reduce, ) quant_method = quant_config['quant_method'] @@ -931,6 +948,7 @@ def build_linear(in_features: int, device=device, colwise=colwise, is_tp=is_tp, + all_reduce=all_reduce, ) if quant_method == 'w8a8': return W8A8Linear( @@ -941,6 +959,7 @@ def build_linear(in_features: int, device=device, colwise=colwise, is_tp=is_tp, + all_reduce=all_reduce, ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -961,7 +980,8 @@ def build_colwise_linear(in_features: int, device=device, colwise=True, is_tp=is_tp, - quant_config=quant_config) + quant_config=quant_config, + all_reduce=False) def build_rowwise_linear(in_features: int, @@ -970,7 +990,8 @@ def build_rowwise_linear(in_features: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = False, - quant_config: Any = None) -> nn.Module: + quant_config: Any = None, + all_reduce: bool = True) -> nn.Module: """build rowwise parallel linear layer.""" return build_linear(in_features=in_features, out_features=out_features, @@ -979,7 +1000,8 @@ def build_rowwise_linear(in_features: int, device=device, colwise=False, is_tp=is_tp, - quant_config=quant_config) + quant_config=quant_config, + all_reduce=all_reduce) def build_merged_colwise_linear( diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index bf48b2b78..b97988262 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -3,6 +3,7 @@ import os.path as osp import torch +import torch.distributed as dist from transformers.modeling_utils import load_state_dict from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME) @@ -12,6 +13,14 @@ logger = get_logger('lmdeploy') +def _get_rank(): + """get rank.""" + rank = 0 + if dist.is_initialized(): + rank = dist.get_rank() + return rank + + def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs): """load weight.""" @@ -130,7 +139,10 @@ def load_model_weights( """load model weights implementation.""" assert hasattr(model, 'load_weights') paths = self._shard_paths + rank = _get_rank() for path in paths: + if rank == 0: + logger.debug(f'loading weights: {path}') state_dict = self._load_shard(path) model.load_weights(state_dict.items()) if device is not None: From d34a2b423615373d0ba6c4fdb1658f768c1d80ec Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 11:12:56 +0800 Subject: [PATCH 46/89] add log --- .../weight_loader/model_weight_loader.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/weight_loader/model_weight_loader.py b/lmdeploy/pytorch/weight_loader/model_weight_loader.py index b97988262..76f7c741c 100644 --- a/lmdeploy/pytorch/weight_loader/model_weight_loader.py +++ b/lmdeploy/pytorch/weight_loader/model_weight_loader.py @@ -13,12 +13,14 @@ logger = get_logger('lmdeploy') -def _get_rank(): +def _get_world_rank(): """get rank.""" rank = 0 + world_size = 1 if dist.is_initialized(): rank = dist.get_rank() - return rank + world_size = dist.get_world_size() + return world_size, rank def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -139,10 +141,17 @@ def load_model_weights( """load model weights implementation.""" assert hasattr(model, 'load_weights') paths = self._shard_paths - rank = _get_rank() + world_size, rank = _get_world_rank() for path in paths: - if rank == 0: - logger.debug(f'loading weights: {path}') + + # log + file_name = osp.split(path)[1] + msg = f'loading weights - "{file_name}"' + if world_size > 1: + msg = f'rank[{rank}] {msg}' + logger.info(msg) + + # process state_dict = self._load_shard(path) model.load_weights(state_dict.items()) if device is not None: From fc7a78a42350eb740625f737608789eb273de485 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 11:18:51 +0800 Subject: [PATCH 47/89] fix ut --- lmdeploy/pytorch/adapter/adapter.py | 1 - lmdeploy/pytorch/paging/block_manager/base_block_manager.py | 2 +- lmdeploy/pytorch/paging/scheduler.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index 70815aba7..da7149cb9 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -358,7 +358,6 @@ class AdapterManager: def __init__(self, adapters: Dict[str, str], target_infos: Dict[str, LoRATargetInfo], block_numel: int): - self.target_infos = target_infos self.block_numel = block_numel if adapters is None: diff --git a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py index 9146ec9a5..0630c7a31 100644 --- a/lmdeploy/pytorch/paging/block_manager/base_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/base_block_manager.py @@ -248,7 +248,7 @@ def __init__(self, self.block_tables: Dict[int, BlockTable] = {} if adapter_manager is None: - adapter_manager = AdapterManager(dict(), 0) + adapter_manager = AdapterManager(dict(), dict(), 0) self.adapter_manager = adapter_manager @classmethod diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index f9c3186be..79b4833d8 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -49,7 +49,7 @@ def __init__(self, self.actived_adapters: Set[str] = set() if adapter_manager is None: - adapter_manager = AdapterManager(dict(), 0) + adapter_manager = AdapterManager(dict(), dict(), 0) self.adapter_manager = adapter_manager self.block_manager = build_block_manager(cache_config, adapter_manager) From 9a6855d3d3dcef554e4a539d0c5804c8f301c193 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 17:56:14 +0800 Subject: [PATCH 48/89] support chatglm --- lmdeploy/pytorch/models/chatglm2.py | 794 ++++++++++++++----------- lmdeploy/pytorch/models/deepseek_v2.py | 21 +- lmdeploy/pytorch/models/module_map.py | 16 +- lmdeploy/pytorch/models/patch.py | 12 + 4 files changed, 483 insertions(+), 360 deletions(-) diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index d2009d30b..e32c3c7b6 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -1,405 +1,507 @@ # Copyright (c) OpenMMLab. All rights reserved. -# Adapted from https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py # noqa: E501 - -from typing import List, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist -import torch.nn as nn -import torch.utils.checkpoint -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) +from torch import nn +from transformers.configuration_utils import PretrainedConfig -from ..kernels import fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, - rowwise_parallelize_linear) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class PatchedRMSNorm(nn.Module): - """Rewrite RMSNorm.""" +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. - def forward(self, hidden_states): - """forward.""" - # torch.nn.functional.normalize based implementation might leads - # to wrong output - from ..kernels import rms_norm + Self-attention layer takes input with size [s, b, h] and returns output of + the same size. + """ + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + + self.projection_size = config.kv_channels * config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_kv_heads = self.num_attention_heads + self.head_size = (self.projection_size // config.num_attention_heads) + self.multi_query_attention = config.multi_query_attention + if self.multi_query_attention: + self.num_kv_heads = config.multi_query_group_num + self.query_key_value = build_qkv_proj( + config.hidden_size, + num_q_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + bias=config.add_bias_linear or config.add_qkv_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # apply rotary + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + self.num_attention_heads, + self.head_size, + num_kv_heads=self.num_kv_heads, + ) + + # o_proj + self.dense = build_rowwise_linear(self.projection_size, + config.hidden_size, + bias=config.add_bias_linear + or config.add_qkv_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + @staticmethod + def _extract_rope(states: torch.Tensor): + """extract rope.""" + rope = states.chunk(2, -1)[0] + rope = rope.unflatten(-1, (-1, 2)) + rope = rope.transpose(-2, -1).flatten(-2, -1).contiguous() + return rope + + @staticmethod + def _fill_rope(states: torch.Tensor, rope: torch.Tensor): + """fill rope.""" + rope_part = states.chunk(2, -1)[0] + rope = rope.unflatten(-1, (2, -1)) + rope = rope.transpose(-2, -1).flatten(-2, -1) + rope_part.copy_(rope) + return states + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + (query_states, key_states, + value_states) = self.query_key_value.split_qkv(qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + q_rope = self._extract_rope(query_states) + k_rope = self._extract_rope(key_states) + q_rope, k_rope = self.apply_rotary_pos_emb( + q_rope, + k_rope, + cos, + sin, + inplace=True, + ) + query_states = self._fill_rope(query_states, q_rope) + key_states = self._fill_rope(key_states, k_rope) + + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.dense(attn_output) + return attn_output - ret = rms_norm(hidden_states, self.weight, self.eps) - return ret +class MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + + self.add_bias = config.add_bias_linear + # gate up + self.dense_h_to_4h = build_merged_colwise_linear( + config.hidden_size, + [config.ffn_hidden_size, config.ffn_hidden_size], + bias=self.add_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. + # silu and mul + self.act_fn = SiluAndMul(inplace=True) - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. + # down + self.dense_4h_to_h = build_rowwise_linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - Returns: - A list of Tensors - """ - tensor_list = tensor.chunk(num_partitions, dim=-1) - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -def apply_rotary_pos_emb(x: torch.Tensor, - rope_cache: torch.Tensor) -> torch.Tensor: - # x: [b, sq, np, hn] - # rope_cache: [b, sq, dim/4, 2] - sq, hn = x.size(1), x.size(-1) - xslice = x[..., :hn // 2] - rope_cache = rope_cache[:, :sq] - xshaped = xslice.unflatten(-1, (-1, 2)) - rope_cache = rope_cache.unsqueeze(2) - - # inplace - torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + - xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - out=xshaped, - ) - return x - - -class PatchedSelfAttention(nn.Module): - """Parallel self-attention layer abstract class. + def forward(self, x): + """forward.""" + gate_up = self.dense_h_to_4h(x) + act = self.act_fn(gate_up) + return self.dense_4h_to_h(act) - Self-attention layer takes input with size [s, b, h] and returns output of + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an output of the same size. """ - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - sections = [ - self.num_attention_heads_per_partition * - self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * - self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * - self.hidden_size_per_attention_head, - ] - colwise_split_parallelize_linear(self.query_key_value, - sections, - loader, - rank=rank, - world_size=world_size, - prefix='query_key_value') - rowwise_parallelize_linear(self.dense, - loader, - rank=rank, - world_size=world_size, - prefix='dense') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward( + def __init__(self, + config: PretrainedConfig, + layer_number: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_number = layer_number + self.apply_residual_connection_post_layernorm = \ + config.apply_residual_connection_post_layernorm + assert not self.apply_residual_connection_post_layernorm + + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attention = SelfAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.layernorm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.layernorm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( self, hidden_states: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - kv_cache: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # hidden_states: [b, sq, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)] - - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - - context = self.context.context - max_q_seq_length = context.max_q_seq_length - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * - self.hidden_size_per_attention_head // world_size, - self.num_multi_query_groups_per_partition * - self.hidden_size_per_attention_head // world_size, - self.num_multi_query_groups_per_partition * - self.hidden_size_per_attention_head // world_size, - ], - dim=-1, - ) - query_layer = query_layer.unflatten( - -1, (-1, self.hidden_size_per_attention_head)) - key_layer = key_layer.unflatten( - -1, (-1, self.hidden_size_per_attention_head)) - value_layer = value_layer.unflatten( - -1, (-1, self.hidden_size_per_attention_head)) + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + layernorm_output = self.input_layernorm(hidden_states) else: - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition // world_size, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] - (query_layer, key_layer, - value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - cache_k, cache_v = kv_cache - fill_kv_cache( - key_layer[0], - value_layer[0], - cache_k, - cache_v, - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + layernorm_output, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + layernorm_input = self.self_attention( + hidden_states=layernorm_output, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) - # ================================== - # core attention computation - # ================================== + # Fully Connected + layernorm_output, residual = self.post_attention_layernorm( + layernorm_input, residual) + mlp_output = self.mlp(layernorm_output) + + outputs = (mlp_output, residual) + return outputs + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + self.num_layers = config.num_layers + self.post_layer_norm = config.post_layer_norm - context_layer = query_layer - paged_attention_fwd(query_layer, - cache_k, - cache_v, - context_layer, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length) + def build_layer(layer_number): + """build layer.""" + return GLMBlock(config, layer_number, dtype=dtype, device=device) - context_layer = context_layer.flatten(-2) + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) - # ================= - # Output. [b, sq, h] - # ================= - output = self.dense(context_layer) + if self.post_layer_norm: + assert config.rmsnorm + self.final_layernorm = RMSNorm(config.hidden_size, + config.layernorm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) - return output, kv_cache + def _get_layer(self, layer_number: int): + """get layer.""" + return self.layers[layer_number] - def forward(self, + def forward( + self, + hidden_states: torch.LongTensor, + rotary_pos_emb: List[torch.Tensor], + past_key_values: Optional[List[torch.FloatTensor]], + attn_metadata: Any, + ): + """forward.""" + residual = None + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states, residual = layer( hidden_states, - attention_mask, rotary_pos_emb, - kv_cache=None, - use_cache=True, - output_attentions=False, - **kwargs): - return self._contiguous_batching_forward( - hidden_states, - rotary_pos_emb, - kv_cache, - ) + past_key_value=past_key_values[index], + residual=residual, + attn_metadata=attn_metadata, + ) + if self.post_layer_norm: + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states -class MLP(nn.Module): - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - w_pack_out = self.dense_h_to_4h.out_features - sections = [w_pack_out // 2] * 2 - colwise_split_parallelize_linear(self.dense_h_to_4h, - sections, - loader, - rank=rank, - world_size=world_size, - prefix='dense_h_to_4h') - rowwise_parallelize_linear(self.dense_4h_to_h, - loader, - rank=rank, - world_size=world_size, - prefix='dense_4h_to_h') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs +class Embedding(nn.Module): + """Language model embeddings.""" + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding(config.padded_vocab_size, + self.hidden_size, + dtype=dtype, + device=device) + self.fp32_residual_connection = config.fp32_residual_connection -class PatchedChatGLMModel(nn.Module): - - def _contiguous_batching_forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], - ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None): - assert input_ids is not None - context = self.context.context - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + def forward(self, input_ids): + """Rewrite to not transpose hidden_statens for all models.""" + # Embeddings. + embeddings = self.word_embeddings(input_ids) + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.embedding = Embedding(config, dtype=dtype, device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rotary_dim = (config.hidden_size // config.num_attention_heads + if config.kv_channels is None else config.kv_channels) + rope_max_pos_emb = 1 << 20 + rope_base = 10000 * getattr(config, 'rope_ratio', 1.0) + self.rotary_pos_emb = build_rotary_embedding( + rotary_dim // 2, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + # build encoder + self.encoder = GLMTransformer(config, dtype=dtype, device=device) - output_hidden_states = False - use_cache = True + # output_layers + self.output_layer = build_rowwise_linear(config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=dtype, + device=device) - batch_size, seq_length = input_ids.shape + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: - # multi-modality - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) - hidden_states = inputs_embeds - if getattr(self, 'pre_seq_len', None) is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - # glm-4v - if getattr(self, 'vision', None) is not None: - from .cogvlm import _get_cogvlm_position_ids - position_ids_1d = _get_cogvlm_position_ids(context) - else: - position_ids_1d = context.position_ids_1d - - rotary_pos_emb = rotary_pos_emb[position_ids_1d[None]] - - # Run encoder. - (hidden_states, presents, all_hidden_states, - all_self_attentions) = self.encoder( - inputs_embeds, - full_attention_mask, - rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, - use_cache=use_cache, - output_hidden_states=output_hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + # rotary embedding + cos, sin = self.rotary_pos_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) - def forward(self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, - torch.Tensor], - ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs): - return self._contiguous_batching_forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, + hidden_states = self.encoder( + hidden_states, + rotary_pos_emb=rotary_pos_emb, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, + attn_metadata=attn_metadata, ) + return hidden_states -class PatchedEmbedding(nn.Module): + def get_input_embeddings(self): + """get input embeddings.""" + return self.embedding - def forward(self, input_ids): - """Rewrite to not transpose hidden_statens for all models.""" - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # If the input flag for fp32 residual connection is set, - # convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings +class ChatGLMForConditionalGeneration(nn.Module): + """rewrote model of LlamaForCausalLM.""" -class PatchedChatGLMForConditionalGeneration(nn.Module): - - def forward(self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - **kwargs): - """rewrite to not transpose logits for all models.""" - transformer_outputs = self.transformer( + support_cuda_graph = True + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build Model + self.transformer = ChatGLMModel(config, dtype=dtype, device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.transformer( input_ids=input_ids, position_ids=position_ids, - attention_mask=attention_mask, past_key_values=past_key_values, + attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, ) - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) + logits = self.transformer.output_layer(hidden_states) + logits = logits.float() + return logits - loss = None + def get_input_embeddings(self): + """get input embeddings.""" + return self.transformer.get_input_embeddings() - if not return_dict: - output = (lm_logits, ) + transformer_outputs[1:] - return ((loss, ) + output) if loss is not None else output + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + + config = self.config + num_attention_heads = config.num_attention_heads + projection_size = config.kv_channels * num_attention_heads + num_kv_heads = num_attention_heads + head_size = (projection_size // num_attention_heads) + multi_query_attention = config.multi_query_attention + if multi_query_attention: + num_kv_heads = config.multi_query_group_num + qkv_section = [ + head_size * num_attention_heads, head_size * num_kv_heads, + head_size * num_kv_heads + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_pos_emb.inv_freq' in name: + continue + if ('rotary_pos_emb.cos_cached' in name + or 'rotary_pos_emb.sin_cached' in name): + continue + if (self.config.tie_word_embeddings + and 'output_layer.weight' in name): + continue + if '.query_key_value' in name: + q, k, v = loaded_weight.split(qkv_section) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + elif '.dense_h_to_4h' in name: + gate, up = loaded_weight.chunk(2) + param = params_dict[name] + load_weight(param, gate, shard_id=0) + load_weight(param, up, shard_id=1) + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index b46cbd3d9..5a6bfb9f9 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -293,6 +293,9 @@ def __init__(self, self.norm_topk_prob = config.norm_topk_prob self.routed_scaling_factor = config.routed_scaling_factor self.renormalize = self.top_k > 1 and self.norm_topk_prob + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group self.gate = build_rowwise_linear( self.hidden_dim, @@ -340,7 +343,23 @@ def forward(self, hidden_states: torch.Tensor): hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - topk_weights, topk_ids = self.softmax_topk(router_logits) + if self.topk_method == 'greedy': + topk_weights, topk_ids = self.softmax_topk(router_logits) + elif self.topk_method == 'group_limited_greedy': + grouped_logits = router_logits.unflatten(-1, (self.n_group, -1)) + group_scores = (grouped_logits.max(-1).values) + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + group_mask = ~group_mask.bool()[..., None] + grouped_logits = grouped_logits.masked_fill(group_mask, 0.0) + router_logits = grouped_logits.flatten(1, 2) + topk_weights, topk_ids = self.softmax_topk(router_logits) + else: + raise RuntimeError(f'Unsupported topk_method: {self.topk_method}') if not self.renormalize: topk_weights = topk_weights * self.routed_scaling_factor out_states = self.experts( diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index ae3784984..056b2aea4 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -35,20 +35,10 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanForCausalLM', }) -# chatglm2 +# chatglm MODULE_MAP.update({ - 'modeling_chatglm.SelfAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedSelfAttention', - 'modeling_chatglm.ChatGLMModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMModel', - 'modeling_chatglm.MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.MLP', - 'modeling_chatglm.RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedRMSNorm', - 'modeling_chatglm.Embedding': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedEmbedding', - 'modeling_chatglm.ChatGLMForConditionalGeneration': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.PatchedChatGLMForConditionalGeneration', # noqa: E501 + 'ChatGLMForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.ChatGLMForConditionalGeneration', # noqa: E501 }) # internlm diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index dfe3660c9..0015b4974 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -154,6 +154,18 @@ def update_custom_module_map(module_map_path: str): def _get_model_class(config, module_map): + """get model class.""" + auto_map = getattr(config, 'auto_map', dict()) + if 'AutoModelForCausalLM' in auto_map: + mapname = auto_map['AutoModelForCausalLM'] + if '.' in mapname: + mapname = mapname.split('.')[-1] + if mapname in module_map: + qualname = module_map[mapname] + module_cls = _class_from_qualname(qualname) + return module_cls + raise RuntimeError(f'Can not found rewrite for auto_map: {mapname}') + architectures = getattr(config, 'architectures', []) for arch in architectures: if arch in module_map: From 694bb040ddb2ad0578bf952a6eb4475fa9489dda Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 18:39:40 +0800 Subject: [PATCH 49/89] support llava --- lmdeploy/pytorch/config.py | 11 +----- lmdeploy/pytorch/configurations/builder.py | 27 ------------- lmdeploy/pytorch/configurations/chatglm.py | 5 +-- lmdeploy/pytorch/configurations/cogvlm.py | 1 - .../pytorch/configurations/deepseek_v2.py | 4 +- lmdeploy/pytorch/configurations/internvl.py | 38 +------------------ lmdeploy/pytorch/configurations/llava.py | 11 +----- lmdeploy/pytorch/configurations/llava_hf.py | 14 +------ .../pytorch/configurations/phi3_vision.py | 3 +- lmdeploy/pytorch/models/llama.py | 19 ++++++++++ lmdeploy/pytorch/models/llava.py | 34 +++++++++++++---- lmdeploy/pytorch/models/module_map.py | 25 ++++++------ lmdeploy/pytorch/models/patch.py | 19 +++++++--- 13 files changed, 80 insertions(+), 131 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 652d93864..38384c0e6 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from dataclasses import dataclass, field -from typing import Any, Dict, List +from dataclasses import dataclass +from typing import Any, List import torch -from transformers import AutoModelForCausalLM def _update_torch_dtype(config: 'ModelConfig', default: str = 'float16'): @@ -97,10 +96,6 @@ class ModelConfig: multi_query_attention: bool = False vocab_size: int = 40000 hf_config: Any = None - init_kwargs: Dict[str, Any] = field(default_factory=dict) - model_arch: str = None - unused_modules: List[str] = None - auto_model_cls: Any = AutoModelForCausalLM cogvlm_style: bool = False custom_module_map: str = None @@ -132,8 +127,6 @@ def from_hf_config(cls, hf_config: Any, model_path: str = None): assert model_config.head_dim is not None model_config.v_head_dim = model_config.head_dim - model_arch = model_config.hf_config.architectures[0] - model_config.model_arch = model_arch # should after setting `hf_config` and `model_arch` attributes model_config = _update_torch_dtype(model_config) diff --git a/lmdeploy/pytorch/configurations/builder.py b/lmdeploy/pytorch/configurations/builder.py index 1e739af14..89bf51ca4 100644 --- a/lmdeploy/pytorch/configurations/builder.py +++ b/lmdeploy/pytorch/configurations/builder.py @@ -51,30 +51,3 @@ def build(cls, hf_config, model_path: str = None): cfg.hf_config = hf_config return cfg - - -class ProxyAutoModel: - """wrapper of auto model class.""" - - def __init__(self, model_cls=None) -> None: - """init.""" - if model_cls is None: - from transformers import AutoModelForCausalLM - model_cls = AutoModelForCausalLM - self.model_cls = model_cls - - def from_config(self, *args, **kwargs): - """wrap from_config.""" - if hasattr(self.model_cls, '_from_config'): - if 'trust_remote_code' in kwargs: - kwargs.pop('trust_remote_code') - return self.model_cls._from_config(*args, **kwargs) - else: - return self.model_cls.from_config(*args, **kwargs) - - def from_pretrained(self, *args, **kwargs): - """wrap from_pretrained.""" - if hasattr(self.model_cls, - '_from_config') and 'trust_remote_code' in kwargs: - kwargs.pop('trust_remote_code') - return self.model_cls.from_pretrained(*args, **kwargs) diff --git a/lmdeploy/pytorch/configurations/chatglm.py b/lmdeploy/pytorch/configurations/chatglm.py index 045406afb..7911c985d 100644 --- a/lmdeploy/pytorch/configurations/chatglm.py +++ b/lmdeploy/pytorch/configurations/chatglm.py @@ -18,7 +18,6 @@ def build(cls, hf_config, model_path: str = None): bos_token_id = hf_config.bos_token_id if bos_token_id is None: bos_token_id = hf_config.pad_token_id - init_kwargs = dict(empty_init=False) cfg = ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_layers, num_attention_heads=hf_config.num_attention_heads, @@ -26,10 +25,8 @@ def build(cls, hf_config, model_path: str = None): bos_token_id=bos_token_id, eos_token_id=hf_config.eos_token_id, head_dim=head_dim, - vocab_size=hf_config.padded_vocab_size, - init_kwargs=init_kwargs) + vocab_size=hf_config.padded_vocab_size) # glm-4v if hasattr(hf_config, 'vision_config'): - cfg.unused_modules = ['transformer.vision'] cfg.cogvlm_style = True return cfg diff --git a/lmdeploy/pytorch/configurations/cogvlm.py b/lmdeploy/pytorch/configurations/cogvlm.py index 7a5b12489..f5fe5695e 100644 --- a/lmdeploy/pytorch/configurations/cogvlm.py +++ b/lmdeploy/pytorch/configurations/cogvlm.py @@ -18,7 +18,6 @@ def build(cls, hf_config, model_path: str = None): cfg = DefaultModelConfigBuilder.build(hf_config) if getattr(hf_config, 'num_multi_query_heads', None): cfg.num_key_value_heads = hf_config.num_multi_query_heads - cfg.unused_modules = ['model.vision'] cfg.cogvlm_style = True torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported( ) else 'float16' diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index 8c89cd50a..37aa4b0d6 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -19,7 +19,6 @@ def build(cls, hf_config, model_path: str = None): v_head_dim = 0 num_attention_heads = hf_config.num_attention_heads num_key_value_heads = 1 - init_kwargs = dict(attn_implementation='eager') return ModelConfig(hidden_size=hf_config.hidden_size, num_layers=hf_config.num_hidden_layers, num_attention_heads=num_attention_heads, @@ -30,5 +29,4 @@ def build(cls, hf_config, model_path: str = None): k_head_dim=k_head_dim, v_head_dim=v_head_dim, vocab_size=hf_config.vocab_size, - multi_query_attention=True, - init_kwargs=init_kwargs) + multi_query_attention=True) diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py index 12be23be1..929b928dd 100644 --- a/lmdeploy/pytorch/configurations/internvl.py +++ b/lmdeploy/pytorch/configurations/internvl.py @@ -3,41 +3,6 @@ from .default import DefaultModelConfigBuilder -class ProxyAutoModel: - """wrapper of auto model class.""" - - def __init__(self) -> None: - """init.""" - from transformers import AutoModel - self.model_cls = AutoModel - - def from_config(self, *args, **kwargs): - """wrap from_config.""" - if hasattr(self.model_cls, '_from_config'): - if 'trust_remote_code' in kwargs: - kwargs.pop('trust_remote_code') - return self.model_cls._from_config(*args, **kwargs) - else: - return self.model_cls.from_config(*args, **kwargs) - - def from_pretrained(self, *args, **kwargs): - """wrap from_pretrained.""" - from transformers import AutoConfig - if hasattr(self.model_cls, - '_from_config') and 'trust_remote_code' in kwargs: - kwargs.pop('trust_remote_code') - config = AutoConfig.from_pretrained(args[0], trust_remote_code=True) - quantization_config = getattr(config.llm_config, 'quantization_config', - None) - if quantization_config is not None: - quantization_config['modules_to_not_convert'] = [ - 'lm_head', 'vision_model' - ] - config.quantization_config = quantization_config - kwargs['config'] = config - return self.model_cls.from_pretrained(*args, **kwargs) - - class InternVLModelConfigBuilder(AutoModelConfigBuilder): @classmethod @@ -49,7 +14,6 @@ def condition(cls, hf_config): def build(cls, hf_config, model_path: str = None): """build llava hf.""" cfg = DefaultModelConfigBuilder.build(hf_config.llm_config) - cfg.unused_modules = ['InternVisionModel'] + # cfg.unused_modules = ['InternVisionModel'] cfg.hf_config = hf_config - cfg.auto_model_cls = ProxyAutoModel() return cfg diff --git a/lmdeploy/pytorch/configurations/llava.py b/lmdeploy/pytorch/configurations/llava.py index c2bf4dd3e..aaeeeeadf 100644 --- a/lmdeploy/pytorch/configurations/llava.py +++ b/lmdeploy/pytorch/configurations/llava.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import AutoModelConfigBuilder, ProxyAutoModel +from .builder import AutoModelConfigBuilder from .default import DefaultModelConfigBuilder @@ -16,18 +16,11 @@ def condition(cls, hf_config): def build(cls, hf_config, model_path: str = None): """build.""" arch = hf_config.architectures[0] - if arch == 'LlavaLlamaForCausalLM': + if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']: from llava.model.language_model.llava_llama import LlavaConfig - from llava.model.language_model.llava_llama import \ - LlavaLlamaForCausalLM as LlavaModel # reload hf_config due to model_type='llava' is already # registered in transformers hf_config = LlavaConfig.from_pretrained(model_path) - elif arch == 'LlavaMistralForCausalLM': - from llava.model.language_model.llava_mistral import \ - LlavaMistralForCausalLM as LlavaModel cfg = DefaultModelConfigBuilder.build(hf_config) - cfg.auto_model_cls = ProxyAutoModel(LlavaModel) - cfg.unused_modules = ['model.vision_tower', 'model.mm_projector'] return cfg diff --git a/lmdeploy/pytorch/configurations/llava_hf.py b/lmdeploy/pytorch/configurations/llava_hf.py index a7ff062bc..4cc007e31 100644 --- a/lmdeploy/pytorch/configurations/llava_hf.py +++ b/lmdeploy/pytorch/configurations/llava_hf.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from lmdeploy.pytorch.config import ModelConfig -from .builder import AutoModelConfigBuilder, ProxyAutoModel +from .builder import AutoModelConfigBuilder class LlavaHfModelConfigBuilder(AutoModelConfigBuilder): @@ -17,16 +17,6 @@ def condition(cls, hf_config): @classmethod def build(cls, hf_config, model_path: str = None): """build llava hf.""" - arch = hf_config.architectures[0] - if arch == 'LlavaForConditionalGeneration': - from transformers import \ - LlavaForConditionalGeneration as _LlavaModel - elif arch == 'LlavaNextForConditionalGeneration': - from transformers import \ - LlavaNextForConditionalGeneration as _LlavaModel - else: - raise RuntimeError(f'Unsupported Llava model arch: {arch}') - text_config = hf_config.text_config hidden_size = getattr(text_config, 'hidden_size', 4096) num_attention_heads = getattr(text_config, 'num_attention_heads', 32) @@ -45,7 +35,5 @@ def build(cls, hf_config, model_path: str = None): eos_token_id=eos_token_id, head_dim=head_dim, vocab_size=text_config.vocab_size, - unused_modules=['vision_tower', 'multi_modal_projector'], hf_config=hf_config, - auto_model_cls=ProxyAutoModel(_LlavaModel), ) diff --git a/lmdeploy/pytorch/configurations/phi3_vision.py b/lmdeploy/pytorch/configurations/phi3_vision.py index 5f633ca57..21f76f21a 100644 --- a/lmdeploy/pytorch/configurations/phi3_vision.py +++ b/lmdeploy/pytorch/configurations/phi3_vision.py @@ -14,6 +14,5 @@ def condition(cls, hf_config): def build(cls, hf_config, model_path: str = None): """build.""" cfg = DefaultModelConfigBuilder.build(hf_config) - cfg.unused_modules = ['model.vision_embed_tokens'] - cfg.init_kwargs = dict(_attn_implementation=None) + # cfg.unused_modules = ['model.vision_embed_tokens'] return cfg diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 4d122bf9b..6f18f94bb 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -433,3 +433,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + +class LlavaLlamaForCausalLM(LlamaForCausalLM): + """llava llama for causallm.""" + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + + new_weights = dict() + for key, val in weights: + if key.startswith('model.vision_tower'): + continue + if key.startswith('model.mm_projector'): + continue + if key.startswith('model.image_newline'): + continue + new_weights[key] = val + + super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index ff6658cac..68b090b55 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -1,24 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List +from typing import Any, Iterable, List, Tuple import torch from torch import nn +from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from .patch import get_rewrite_cls +from .patch import build_model_from_hf_config -class PatchedLlavaForConditionalGeneration(nn.Module): +class LlavaForConditionalGeneration(nn.Module): support_cuda_graph = True - def __init__(self, origin: nn.Module, ctx_mgr: StepContextManager): + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): super().__init__() + self.config = config self.ctx_mgr = ctx_mgr - language_model_cls = get_rewrite_cls(origin.language_model) - self.language_model = language_model_cls(origin.language_model, - ctx_mgr) + text_config = config.text_config + self.language_model = build_model_from_hf_config(text_config, + dtype=dtype, + device=device) def forward( self, @@ -67,3 +74,16 @@ def prepare_inputs_for_generation( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + + prefix_length = len('language_model.') + new_weights = dict() + for key, val in weights: + if not key.startswith('language_model.'): + continue + new_key = key[prefix_length:] + new_weights[new_key] = val + + self.language_model.load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 056b2aea4..d2a6d6e9b 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -262,20 +262,17 @@ }) # llava -MODULE_MAP.update({ - 'llava.model.language_model.llava_llama.LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', - 'llava.model.language_model.llava_llama.LlavaLlamaModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'llava.model.language_model.llava_mistral.LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaLlamaForCausalLM', - 'llava.model.language_model.llava_mistral.LlavaMistralModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'transformers.models.llava.modeling_llava.LlavaForConditionalGeneration': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaForConditionalGeneration', # noqa: E501 - 'transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaForConditionalGeneration' -}) +MODULE_MAP.update( + { + 'LlavaLlamaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', + 'llava.model.language_model.llava_mistral.LlavaMistralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaLlamaForCausalLM', + 'LlavaForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 + 'LlavaNextForConditionalGeneration': # noqa: E501 + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' + }) # internvl MODULE_MAP.update({ diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 0015b4974..cae8f43cf 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List import torch +from transformers.configuration_utils import PretrainedConfig from lmdeploy.utils import get_logger @@ -177,20 +178,28 @@ def _get_model_class(config, module_map): f'Can not found rewrite for architectures: {architectures}') -@torch.inference_mode() -def build_patched_model(config: ModelConfig, device: torch.device = None): - """build patched model.""" +def build_model_from_hf_config(model_config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + """build model from hf config.""" from lmdeploy.pytorch.model_inputs import StepContextManager ctx_mgr = StepContextManager() module_map = _get_module_map() - model_config = config.hf_config if device is None: device = torch.device('cuda') model_cls = _get_model_class(model_config, module_map) - model = model_cls(model_config, ctx_mgr, dtype=config.dtype, device=device) + model = model_cls(model_config, ctx_mgr, dtype=dtype, device=device) return model.eval() +@torch.inference_mode() +def build_patched_model(config: ModelConfig, device: torch.device = None): + """build patched model.""" + model_config = config.hf_config + dtype = config.dtype + return build_model_from_hf_config(model_config, dtype=dtype, device=device) + + @torch.inference_mode() def add_adapters(model: torch.nn.Module, kv_caches: List[List[torch.Tensor]], From 6d47b6373c99ae58f0015e832acd99fa697736d0 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 29 Aug 2024 20:02:25 +0800 Subject: [PATCH 50/89] add falcon --- lmdeploy/pytorch/models/falcon.py | 724 +++++++++++++++----------- lmdeploy/pytorch/models/functional.py | 159 ------ lmdeploy/pytorch/models/internlm.py | 138 ----- lmdeploy/pytorch/models/llama.py | 3 + lmdeploy/pytorch/models/module_map.py | 29 +- lmdeploy/pytorch/nn/linear.py | 45 +- 6 files changed, 468 insertions(+), 630 deletions(-) delete mode 100644 lmdeploy/pytorch/models/functional.py delete mode 100644 lmdeploy/pytorch/models/internlm.py diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index d68a0c32e..145c765d2 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -1,344 +1,466 @@ # Copyright (c) OpenMMLab. All rights reserved. -# Adapted from: -# https://huggingface.co/tiiuae/falcon-7b-instruct -# https://github.com/huggingface/transformers/blob/v4.33-release/src/transformers/models/falcon/modeling_falcon.py # noqa - -from typing import Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist -import torch.nn as nn -import torch.utils.checkpoint -from transformers.modeling_outputs import \ - BaseModelOutputWithPastAndCrossAttentions +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class FalconAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] and returns output of + the same size. + """ + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_kv_heads = self.num_attention_heads + self.head_size = (self.hidden_size // config.num_attention_heads) + self.multi_query_attention = config.multi_query + if self.multi_query_attention: + self.num_kv_heads = 1 + self.query_key_value = build_qkv_proj( + config.hidden_size, + num_q_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + bias=config.bias, + replicate_kv=self.multi_query_attention, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) -from ..kernels import (alibi_paged_attention_fwd, apply_rotary_pos_emb, - fill_kv_cache, fused_rotary_emb, paged_attention_fwd) -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) + # apply rotary + self.apply_rotary_pos_emb = ApplyRotaryEmb() + self.rotary = config.rotary + # attention + self.attn_fwd = Attention( + self.num_attention_heads, + self.head_size, + num_kv_heads=self.num_kv_heads, + alibi=config.alibi, + ) -class PatchedFalconAttention(nn.Module): + # o_proj + self.dense = build_rowwise_linear(self.hidden_size, + config.hidden_size, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + tp_align_size=self.head_size) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - if self.multi_query: - weight = loader.pop('query_key_value.weight') - weight = weight.unflatten(0, (-1, self.head_dim)) - q_weight = weight[:self.num_heads] - q_weight = q_weight.chunk(world_size, 0)[rank] - kv_weight = weight[-2:] - weight = torch.cat([q_weight, kv_weight]) - weight = torch.nn.Parameter(weight.flatten(0, 1), - requires_grad=False) - self.query_key_value.register_parameter('weight', weight) - else: - colwise_parallelize_linear(self.query_key_value, - loader, - rank=rank, - world_size=world_size, - prefix='query_key_value') - - # dense - weight = loader.pop('dense.weight') - if self.multi_query: - weight = weight.reshape(self.hidden_size, -1, self.head_dim) - else: - weight = weight.reshape( - self.hidden_size, - -1, # num groups - self.num_heads * self.head_dim, + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + (query_states, key_states, + value_states) = self.query_key_value.split_qkv(qkv_states) + + # apply rotary embedding + if self.rotary: + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, ) - weight = weight.chunk(world_size, 1)[rank] - weight = torch.nn.Parameter(weight.reshape(self.hidden_size, -1), - requires_grad=False) - self.dense.register_parameter('weight', weight) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - def _split_heads( - self, fused_qkv: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Split the last dimension into (num_heads, head_dim), results share - same memory storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`, *required*): - [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] - key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - if self.new_decoder_architecture: - # e.g. 40b-instruct model - batch, seq_len, _ = fused_qkv.shape - qkv = fused_qkv.view(batch, seq_len, -1, - self.num_heads // self.num_kv_heads + 2, - self.head_dim) - query = qkv[:, :, :, :-2] - key = qkv[:, :, :, [-2]] - value = qkv[:, :, :, [-1]] - # because cache_engine & kernel - # already handled grouped attention - # removing broadcast make it faster and more memory-saving - # key = torch.broadcast_to(key, query.shape) - # value = torch.broadcast_to(value, query.shape) - - query, key, value = [x.flatten(2, 3) for x in (query, key, value)] - return query, key, value - elif not self.multi_query: - # e.g. rw-1b model - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, - self.num_heads // dist.get_world_size(), - 3, self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., - 2, :] + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.dense(attn_output) + return attn_output + + +class FalconMLP(nn.Module): + """Falcon mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + + self.add_bias = config.bias + ffn_hidden_size = getattr(config, 'ffn_hidden_size', + config.hidden_size * 4) + # gate up + self.dense_h_to_4h = build_colwise_linear( + config.hidden_size, + ffn_hidden_size, + bias=self.add_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = nn.GELU() + + # down + self.dense_4h_to_h = build_rowwise_linear( + ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.dense_h_to_4h(x) + act = self.act_fn(gate_up) + return self.dense_4h_to_h(act) + + +class FalconDecoderLayer(nn.Module): + """falcon decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.config = config + hidden_size = config.hidden_size + + # build attention layer + self.self_attention = FalconAttention(config, + dtype=dtype, + device=device) + + # builf MLP + self.mlp = FalconMLP(config, dtype=dtype, device=device) + + if not hasattr(config, 'num_ln_in_parallel_attn'): + config.num_ln_in_parallel_attn = None + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) + self.input_layernorm = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) else: - # e.g. 7b-instruct model - fused_qkv = fused_qkv.unflatten(-1, (-1, self.head_dim)) - split_shape = (fused_qkv.size(-2) - 2, 1, 1) - return fused_qkv.split(split_shape, dim=-2) + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) + # The layer norm before the MLP + self.ln_mlp = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) + else: + self.input_layernorm = nn.LayerNorm( + hidden_size, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) - def _contiguous_batching_forward( + def forward( self, hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - output_attentions: bool = False, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + attn_metadata: Any = None, ): - # prepare inputs for continuous batch forwarding - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - max_q_seq_length = context.max_q_seq_length - block_offsets = context.block_offsets - position_ids_1d = context.position_ids_1d - max_kv_seq_length = context.max_kv_seq_length - - def __maybe_rotary_fn(query_states, key_states, value_states): - scaling_factor = 1.0 - inv_freq = self.maybe_rotary.inv_freq - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - position_ids_1d[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None]) - return query_states[0], key_states[0], value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - cos, sin = self.rotary_emb(value_states.transpose(0, 1), - max_kv_seq_length) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, context.position_ids, - position_ids_1d) - return query_states, key_states, value_states - - fused_qkv = self.query_key_value(hidden_states) - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - query_layer = query_layer.flatten(0, 1) - key_layer = key_layer.flatten(0, 1) - value_layer = value_layer.flatten(0, 1) - if hasattr(self, 'maybe_rotary'): - query_layer, key_layer, value_layer = __maybe_rotary_fn( - query_layer, key_layer, value_layer) - elif hasattr(self, 'rotary_emb'): - query_layer, key_layer, value_layer = __rotary_emb_fn( - query_layer, key_layer, value_layer) - - past_key, past_value = layer_past - fill_kv_cache( - key_layer.contiguous(), - value_layer.contiguous(), - past_key, - past_value, - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + + residual = hidden_states + if (self.config.new_decoder_architecture + and self.config.num_ln_in_parallel_attn == 2): + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self Attention + attention_output = self.self_attention( + hidden_states=attention_layernorm_out, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) - attn_output = query_layer + # Fully Connected + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = attention_output + residual + mlp_layernorm_out = self.post_attention_layernorm(residual) - if not alibi: - paged_attention_fwd(q=query_layer, - k=past_key, - v=past_value, - o=attn_output, - block_offsets=block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length) + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out - else: - num_heads_full = self.num_heads - head_offset = 0 - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - head_offset = self.num_heads // world_size * rank - alibi_paged_attention_fwd(q=query_layer, - k=past_key, - v=past_value, - o=attn_output, - block_offsets=block_offsets, - b_start_loc=q_start_loc, - b_seq_len=q_seq_length, - b_kv_seq_len=kv_seq_length, - max_input_len=max_q_seq_length, - head_offset=head_offset, - num_heads=num_heads_full, - alibi_scale=self.inv_norm_factor) - - attn_output = attn_output[None].flatten(-2, -1) - output_tensor = self.dense(attn_output) - - if output_attentions: - return output_tensor, layer_past, None - else: - return output_tensor, layer_past + # MLP. + mlp_output = self.mlp(mlp_layernorm_out) - def forward(self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs): - return self._contiguous_batching_forward(hidden_states, alibi, - layer_past) + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output += attention_output + outputs = mlp_output + residual + return outputs -class PatchedFalconMLP(nn.Module): - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['dense_h_to_4h']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(self.dense_4h_to_h, - loader, - rank=rank, - world_size=world_size, - prefix='dense_4h_to_h') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs +class FalconModel(nn.Module): + """falcon model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.embed_dim = config.hidden_size + self.word_embeddings = nn.Embedding(config.vocab_size, + self.embed_dim, + dtype=dtype, + device=device) -class PatchedFalconModel(nn.Module): + # build all decode layers + self.h = nn.ModuleList([ + FalconDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) - def _contiguous_batching_forward( + self.ln_f = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_epsilon, + dtype=dtype, + device=device) + + scaling_factor = 1.0 + if not hasattr(config, 'rope_scaling'): + emb_type = EmbeddingType.LinearScaling + else: + rope_scaling = config.rope_scaling + rope_type = rope_scaling['rope_type'] + if rope_type == 'linear': + emb_type = EmbeddingType.LinearScaling + elif rope_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) + + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = getattr(config, 'max_position_embeddings', 2048) + rope_base = getattr(config, 'rope_base', 10000) + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) + + def forward( self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], - ...]] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - ) -> Union[Tuple[torch.Tensor, ...], - BaseModelOutputWithPastAndCrossAttentions]: - - output_attentions = False - use_cache = True - use_alibi = getattr(self, 'use_alibi', getattr(self, 'alibi', False)) + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - head_mask = self.get_head_mask(head_mask, - self.config.num_hidden_layers) - hidden_states = inputs_embeds - # Compute alibi tensor: check build_alibi_tensor documentation + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - outputs = block( + # decoding + for idx, decoder_layer in enumerate(self.h): + past_key_value = past_key_values[idx] + hidden_states = decoder_layer( hidden_states, - layer_past=layer_past, - attention_mask=None, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=use_alibi, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) - hidden_states = outputs[0] - # Add last hidden state + # norm hidden_states = self.ln_f(hidden_states) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - ) + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.word_embeddings + + +class FalconForCausalLM(nn.Module): + """rewrote model of FalconForCausalLM.""" + + support_cuda_graph = True + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build LLamaModel + self.transformer = FalconModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], - ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], - BaseModelOutputWithPastAndCrossAttentions]: - return self._contiguous_batching_forward( - input_ids=input_ids, past_key_values=past_key_values) - - -class PatchedFalconForCausalLM(nn.Module): + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) - def forward( + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.transformer.get_input_embeddings() + + def prepare_inputs_for_generation( self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], - ...]] = None, - return_dict: Optional[bool] = True, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = False, - use_origin: Optional[bool] = True, - ) -> Union[Tuple[torch.Tensor, ...], - BaseModelOutputWithPastAndCrossAttentions]: - """Forward function, patched to ignore position_ids.""" - - outputs = self.origin_mod(input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) - return outputs + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + + config = self.config + num_attention_heads = config.num_attention_heads + hidden_size = config.hidden_size + num_kv_heads = num_attention_heads + head_size = (hidden_size // num_attention_heads) + multi_query_attention = config.multi_query + if multi_query_attention: + num_kv_heads = 1 + qkv_section = [ + head_size * num_attention_heads, head_size * num_kv_heads, + head_size * num_kv_heads + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_pos_emb.inv_freq' in name: + continue + if ('rotary_pos_emb.cos_cached' in name + or 'rotary_pos_emb.sin_cached' in name): + continue + if (self.config.tie_word_embeddings + and 'output_layer.weight' in name): + continue + if '.query_key_value' in name: + q, k, v = loaded_weight.split(qkv_section) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/functional.py b/lmdeploy/pytorch/models/functional.py deleted file mode 100644 index 04b8acb0f..000000000 --- a/lmdeploy/pytorch/models/functional.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math -from typing import Any - -import torch - -from ..kernels import apply_rotary_pos_emb - -__all__ = ['apply_rotary_pos_emb'] - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, - repeats=n_rep). - - The hidden states go from (num_key_value_heads, seqlen, head_dim) to - (num_attention_heads, seqlen, head_dim) - """ - if n_rep == 1: - return hidden_states - num_key_value_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, - None, :, :].expand(num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(num_key_value_heads * n_rep, slen, head_dim) - - -def generate_batched_mask(q_lens, - k_lens, - max_q_len: int = None, - max_k_len: int = None, - device='cuda'): - """Generate batched mask.""" - if max_q_len is None: - max_q_len = max(q_lens) - - if max_k_len is None: - max_k_len = max(k_lens) - - q_range = torch.arange(max_q_len).to(device) - k_range = torch.arange(max_k_len).to(device) - - cross = k_range.unsqueeze(0) - q_range.unsqueeze(1) - cross = cross.unsqueeze(0) - - threshold = (k_lens - q_lens).view(-1, 1, 1) - mask = torch.where(cross <= threshold, 1, 0).to(device) - for idx, q_len in enumerate(q_lens): - mask[idx, q_len:, :] = 0 - return mask - - -def get_slopes(n: int): - """Get alibi slopes.""" - - def _get_interleave_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return _get_interleave_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return ( - _get_interleave_power_of_2(closest_power_of_2) + - get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - -@torch.no_grad() -def get_alibi_biases(n_heads: int, mask: torch.Tensor): - """Get alibi bias.""" - m = torch.tensor(get_slopes(n_heads)).to(mask.device) - distance = mask.cumsum(dim=-1) - 1 - return distance * m[None, :, None, None] - - -def quant_kv(key: torch.Tensor, value: torch.Tensor, out_type: torch.dtype): - """Quantize key and value of attention to `out_type`. - - Args: - key (torch.Tensor): Attention key. - value (torch.Tensor): Attention value. - out_type (torch.dtype): Output data type. - """ - assert out_type is torch.int8 - # quantize key and value - _min = torch.min(key, axis=-1).values - _max = torch.max(key, axis=-1).values - key_zp = (_min + _max) / 2 - key_scale = (_max - key_zp) / 127 - key_int8 = torch.round( - (key - key_zp[:, :, None]) / key_scale[:, :, None]).to(out_type) - - _min = torch.min(value, axis=-1).values - _max = torch.max(value, axis=-1).values - value_zp = (_min + _max) / 2 - value_scale = (_max - value_zp) / 127 - value_int8 = torch.round( - (value - value_zp[:, :, None]) / value_scale[:, :, None]).to(out_type) - - # wrap zp and scale to qparams - qparams = { - 'key_zp': key_zp, - 'key_scale': key_scale, - 'value_zp': value_zp, - 'value_scale': value_scale, - } - return key_int8, value_int8, qparams - - -def dequant_kv(context: Any, layer_id: str, key_int8: torch.Tensor, - value_int8: torch.Tensor, out_type: torch.dtype): - """Dequantize key and value of attention to `out_type`. - - Args: - context (Any): StepContext during inference. - layer_id (str): Layer object id. - key (torch.Tensor): Quantized attention key. - value (torch.Tensor): Quantized attention value. - out_type (torch.dtype): output data type. - """ - qparams = context.get_output(layer_id) - - key_scale = qparams['key_scale'] - key_zp = qparams['key_zp'] - key_float = (key_int8 * key_scale[:, :, None] + - key_zp[:, :, None]).to(out_type) - - value_scale = qparams['value_scale'] - value_zp = qparams['value_zp'] - value_float = (value_int8 * value_scale[:, :, None] + - value_zp[:, :, None]).to(out_type) - return key_float, value_float - - -def sync_qparam_to_context(context: Any, layer_id: str, qparams: dict): - """Merge quantization param to context. - - Args: - context (Any): StepContext during inference. - layer_id (str): Layer object id. - qparams (dict): Quantization param of current step. - """ - if context.inputs.meta is not None: - last_qparam = context.inputs.meta[layer_id] - for _k in last_qparam.keys(): - _v = torch.concat([last_qparam[_k], qparams[_k]], axis=0) - last_qparam[_k] = _v - context.set_output(layer_id, last_qparam) - else: - context.set_output(layer_id, qparams) diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py deleted file mode 100644 index 4e4e14037..000000000 --- a/lmdeploy/pytorch/models/internlm.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple - -import torch -import torch.distributed as dist -from torch import nn - -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) - - -class PatchedInternLMAttention(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(self.o_proj, - loader, - rank=rank, - world_size=world_size, - prefix='o_proj') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of LlamaAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = num_heads - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - cos, sin = self.rotary_emb(value_states, seq_len=max_kv_seq_length) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids, - context.position_ids_1d) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, - ) - - attn_output = query_states - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - world_size=world_size, - ) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 6f18f94bb..85ef0cac0 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -102,6 +102,7 @@ def forward( class LlamaMLP(nn.Module): + """llama mlp.""" def __init__(self, config: LlamaConfig, @@ -140,6 +141,7 @@ def forward(self, x): class LlamaDecoderLayer(nn.Module): + """llama decoder layer.""" def __init__(self, config: LlamaConfig, @@ -205,6 +207,7 @@ def forward( class LlamaModel(nn.Module): + """llama model.""" def __init__(self, config: LlamaConfig, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index d2a6d6e9b..7b223d8d9 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -10,23 +10,8 @@ # Falcon Models in transformer / on hub MODULE_MAP.update({ - 'modeling_falcon.FalconAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention', - 'modeling_falcon.FalconModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel', - 'modeling_falcon.FalconMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP', - 'modeling_falcon.FalconForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconForCausalLM', - # for old implementations on hub - 'modelling_RW.Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention', - 'modelling_RW.MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP', - 'modelling_RW.RWModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel', - 'modelling_RW.RotaryEmbedding': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconRotaryEmbedding', + 'FalconForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.FalconForCausalLM', }) # baichuan @@ -43,14 +28,8 @@ # internlm MODULE_MAP.update({ - 'modeling_internlm.InternLMAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.PatchedInternLMAttention', - 'modeling_internlm.InternLMModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'modeling_internlm.InternLMMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'modeling_internlm.InternLMRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'InternLMForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', }) # internlm2 diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 9ccd0ef2a..c5ca22185 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -695,8 +695,10 @@ def __init__( colwise: bool = True, is_tp: bool = False, all_reduce: bool = True, + tp_align_size: int = 1, ): super().__init__() + self.tp_align_size = tp_align_size if device is None: device = torch.device('cpu') if dtype is None: @@ -731,16 +733,23 @@ def _get_io_features(self, in_features: int, out_features: int, """get io features.""" world_size, rank = get_world_rank() if colwise: - out_features = get_distribute_size(out_features, world_size, rank) + out_features = get_distribute_size(out_features, + world_size, + rank, + align=self.tp_align_size) else: - in_features = get_distribute_size(in_features, world_size, rank) + in_features = get_distribute_size(in_features, + world_size, + rank, + align=self.tp_align_size) return in_features, out_features def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int, world_size: int): """weight loader for colwise linear.""" - weight = loaded_weight.chunk(world_size, 0)[rank] + weight = _chunk_align(loaded_weight, world_size, 0, + self.tp_align_size)[rank] return default_weight_loader(param, weight) def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, @@ -748,7 +757,8 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, world_size: int): """weight loader for rowwise linear.""" if loaded_weight.dim() == 2: - weight = loaded_weight.chunk(world_size, 1)[rank] + weight = _chunk_align(loaded_weight, world_size, 1, + self.tp_align_size)[rank] return default_weight_loader(param, weight) else: # bias @@ -908,6 +918,21 @@ def _update_all_out_features(self, all_out_features: List[int], """update all out features.""" return all_out_features + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = get_world_rank() + shard_idx = self.out_names_map[shard_id] + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + if shard_idx in [0, 1]: + loaded_weight = _chunk_align(loaded_weight, world_size, 0, + self.head_size)[rank] + if shard_idx == 2: + loaded_weight = _chunk_align(loaded_weight, world_size, 0, + self.head_size_v)[rank] + param_w.copy_(loaded_weight) + def build_linear(in_features: int, out_features: int, @@ -917,7 +942,8 @@ def build_linear(in_features: int, colwise: bool = True, is_tp: bool = False, quant_config: Any = None, - all_reduce: bool = True) -> nn.Module: + all_reduce: bool = True, + tp_align_size: int = 1) -> nn.Module: """build linear.""" if is_tp: world_size, _ = get_world_rank() @@ -933,6 +959,7 @@ def build_linear(in_features: int, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce, + tp_align_size=tp_align_size, ) quant_method = quant_config['quant_method'] @@ -971,6 +998,7 @@ def build_colwise_linear(in_features: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = False, + tp_align_size: int = 1, quant_config: Any = None) -> nn.Module: """build columnwise parallel linear layer.""" return build_linear(in_features=in_features, @@ -981,7 +1009,8 @@ def build_colwise_linear(in_features: int, colwise=True, is_tp=is_tp, quant_config=quant_config, - all_reduce=False) + all_reduce=False, + tp_align_size=tp_align_size) def build_rowwise_linear(in_features: int, @@ -990,6 +1019,7 @@ def build_rowwise_linear(in_features: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = False, + tp_align_size: int = 1, quant_config: Any = None, all_reduce: bool = True) -> nn.Module: """build rowwise parallel linear layer.""" @@ -1001,7 +1031,8 @@ def build_rowwise_linear(in_features: int, colwise=False, is_tp=is_tp, quant_config=quant_config, - all_reduce=all_reduce) + all_reduce=all_reduce, + tp_align_size=tp_align_size) def build_merged_colwise_linear( From 49e51cacf86e0184b93dd2f11e7d2f796a7001fd Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 30 Aug 2024 11:38:56 +0800 Subject: [PATCH 51/89] add internlm2 and mistral --- .../backends/default/rotary_embedding.py | 2 +- lmdeploy/pytorch/models/internlm2.py | 762 +++++++++--------- lmdeploy/pytorch/models/mistral.py | 554 +++++++++---- lmdeploy/pytorch/models/module_map.py | 65 +- 4 files changed, 792 insertions(+), 591 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 4e7e0ba1c..fa151ea87 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -108,7 +108,7 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): cos, sin = _rotary_embedding_fwd(position_ids, inv_freq, - scaling_factor=self.scaling_factor, + scaling_factor=1.0, dtype=dtype, device_type=device_type) return cos, sin diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 68f4ea7f2..2526f6709 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -1,442 +1,436 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist -from einops import rearrange from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast - -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) - - -class PatchedInternLM2Attention(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['wqkv']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['wo']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of LlamaAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - position_ids_1d = context.position_ids_1d - max_kv_seq_length = context.max_kv_seq_length - - def __qkv_proj(hidden_states): - """qkv_proj.""" - qkv_states = self.wqkv(hidden_states) - qkv_states = rearrange( - qkv_states, - 'b q (h gs d) -> (b q) h gs d', - gs=2 + self.num_key_value_groups, - d=self.head_dim, - ) - query_states = qkv_states[..., :self.num_key_value_groups, :] - query_states = query_states.flatten(1, 2) - key_states = qkv_states[..., -2, :] - value_states = qkv_states[..., -1, :] - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - # compat - if not hasattr(self, '_use_old_rotary_emb'): - import inspect - args = inspect.getargspec(self.rotary_emb.forward)[0] - self._use_old_rotary_emb = 'seq_len' in args - - if not hasattr(context, '_cos'): - if self._use_old_rotary_emb: - kwargs = dict(seq_len=max_kv_seq_length) - else: - kwargs = dict(position_ids=position_ids_1d[None]) - - cos, sin = self.rotary_emb(value_states.transpose(0, 1), - **kwargs) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - - if self._use_old_rotary_emb: - _position_ids_1d = position_ids_1d - else: - _position_ids_1d = torch.arange(0, - len(position_ids_1d), - device=query_states.device) - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - position_ids_1d=_position_ids_1d, - q_embed=query_states, - k_embed=key_states) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class InternLM2Attention(nn.Module): + """Rewrite module of InternLM2Attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = hidden_size // num_heads + + # packed qkv + self.wqkv = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, ) - attn_output = query_states - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() - attn_output = self.wo(attn_output) - - return attn_output, None, past_key_value - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - world_size=world_size, + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, ) + # o_proj + self.wo = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) -class PatchedInternLM2AttentionAscend(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['wqkv']: - colwise_parallelize_linear( - getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name, - ) - for mod_name in ['wo']: - rowwise_parallelize_linear( - getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name, - ) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of LlamaAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - position_ids_1d = context.position_ids_1d - max_kv_seq_length = context.max_kv_seq_length - - def __qkv_proj(hidden_states): - """qkv_proj.""" - qkv_states = self.wqkv(hidden_states) - qkv_states = rearrange( - qkv_states, - 'b q (h gs d) -> (b q) h gs d', - gs=2 + self.num_key_value_groups, - d=self.head_dim, - ) - query_states = qkv_states[..., :self.num_key_value_groups, :] - query_states = query_states.flatten(1, 2) - key_states = qkv_states[..., -2, :] - value_states = qkv_states[..., -1, :] - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - # compat - if not hasattr(self, '_use_old_rotary_emb'): - import inspect - - args = inspect.getargspec(self.rotary_emb.forward)[0] - self._use_old_rotary_emb = 'seq_len' in args - - if not hasattr(context, '_cos'): - if self._use_old_rotary_emb: - kwargs = dict(seq_len=max_kv_seq_length) - else: - kwargs = dict(position_ids=position_ids_1d[None]) - - cos, sin = self.rotary_emb(value_states.transpose(0, 1), - **kwargs) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - - if self._use_old_rotary_emb: - _position_ids_1d = position_ids_1d - else: - _position_ids_1d = torch.arange(0, - len(position_ids_1d), - device=query_states.device) - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - position_ids_1d=_position_ids_1d, - q_embed=query_states, - k_embed=key_states, - context=context, - ) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of InternLM2Attention.forward.""" + # qkv proj + qkv_states = self.wqkv(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.wqkv.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, - context=context, + cos, + sin, + inplace=True, ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, key_states, value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - context=context, + attn_metadata, + inplace=True, ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.wo(attn_output) + return attn_output + + +class InternLM2MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.w2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - return attn_output, None, past_key_value + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.w2(act) + + +class InternLM2DecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.attention = InternLM2Attention(config, dtype=dtype, device=device) + + # builf MLP + self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.attention_norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.ffn_norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - world_size=world_size, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + + # Self Attention + hidden_states = self.attention( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) -class PatchedInternLM2MLP(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['w1', 'w3']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['w2']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) + outputs = (hidden_states, residual) return outputs -class PatchedInternLM2Model(nn.Module): +class InternLM2Model(nn.Module): + """internlm2 model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.tok_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + InternLM2DecoderLayer(config, + layer_idx, + dtype=dtype, + device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding in Model + rope_scaling = config.rope_scaling + scaling_factor = 1.0 + emb_type = EmbeddingType.LinearScaling + if rope_scaling is not None: + scaling_factor = rope_scaling.get('factor', scaling_factor) + rope_type = rope_scaling['type'] + if rope_type == 'linear': + emb_type = EmbeddingType.LinearScaling + if rope_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) - def _continuous_batching_forward( + def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - context = self.context.context - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + ): + """Rewrite of forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.tok_embeddings(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) - - # Attention mask is not necessary in continuous batching - attention_mask = None hidden_states = inputs_embeds - # decoder layers + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None for idx, decoder_layer in enumerate(self.layers): - past_key_value = (past_key_values[idx] - if past_key_values is not None else None) - layer_outputs = decoder_layer( + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_value, - hidden_states=None, - attentions=None, - ) + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.tok_embeddings + + +class InternLM2ForCausalLM(nn.Module): + """rewrote model of InternLM2ForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'gate_up_proj': [ + 'w1', + 'w3', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build Model + self.model = InternLM2Model(config, dtype=dtype, device=device) + # build lm_head + self.output = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - use_cache, - output_attentions, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + logits = self.output(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.w1', 0), + ('.gate_up_proj', '.w3', 1), + ] + + config = self.config + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = hidden_size // num_heads + num_key_value_groups = num_heads // num_key_value_heads + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.wqkv' in name: + loaded_weight = loaded_weight.unflatten( + 0, (-1, num_key_value_groups + 2, head_dim)) + q = loaded_weight[:, :num_key_value_groups].flatten(0, 2) + k = loaded_weight[:, -2].flatten(0, 1) + v = loaded_weight[:, -1].flatten(0, 1) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 19c86d505..db06d7d29 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -1,176 +1,436 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn +from transformers.configuration_utils import PretrainedConfig -from ..kernels import apply_rotary_pos_emb -from ..kernels.fill_kv_cache import fill_kv_cache -from ..kernels.pagedattention import paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class MistralFlashAttention2(nn.Module): +class MistralAttention(nn.Module): + """Rewrite module of MistralAttention.""" - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() - def _contiguous_batching_forward_impl( + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - attention_mask: Optional[torch.Tensor] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - position_ids_1d = context.position_ids_1d - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - if hasattr(self, 'rotary_emb'): - if not hasattr(self, '_use_old_rotary_emb'): - import inspect - args = inspect.getargspec(self.rotary_emb.forward)[0] - self._use_old_rotary_emb = 'seq_len' in args - if not hasattr(context, '_cos'): - if self._use_old_rotary_emb: - kwargs = dict(seq_len=max_kv_seq_length) - else: - kwargs = dict(position_ids=position_ids_1d[None]) - - cos, sin = self.rotary_emb(value_states, **kwargs) - - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - - if self._use_old_rotary_emb: - _position_ids_1d = position_ids_1d - else: - _position_ids_1d = torch.arange(0, - len(position_ids_1d), - device=query_states.device) - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - position_ids_1d=_position_ids_1d) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - attn_output = query_states - window_size = self.config.sliding_window - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - window_size=window_size, + attn_metadata, + inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) + return attn_output + + +class MistralMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class MistralDecoderLayer(nn.Module): + """llama decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = MistralAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = MistralMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) - return attn_output, None, past_key_value + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class MistralModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + MistralDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class MistralForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = MistralModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - attention_mask=attention_mask, - world_size=world_size, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + +class LlavaMistralForCausalLM(MistralForCausalLM): + """llava forcausallm.""" + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + + new_weights = dict() + for key, val in weights: + if key.startswith('model.vision_tower'): + continue + if key.startswith('model.mm_projector'): + continue + if key.startswith('model.image_newline'): + continue + new_weights[key] = val + + super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 7b223d8d9..5fa231e03 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -34,32 +34,14 @@ # internlm2 MODULE_MAP.update({ - 'modeling_internlm2.InternLM2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Attention', - 'modeling_internlm2.InternLM2FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Attention', - 'modeling_internlm2.InternLM2Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2Model', - 'modeling_internlm2.InternLM2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2MLP', - 'modeling_internlm2.InternLM2RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'InternLM2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.InternLM2ForCausalLM', }) # mistral MODULE_MAP.update({ - 'transformers.models.mistral.modeling_mistral.MistralAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2', - 'transformers.models.mistral.modeling_mistral.MistralFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2', - 'transformers.models.mistral.modeling_mistral.MistralSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralFlashAttention2', - 'transformers.models.mistral.modeling_mistral.MistralModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'transformers.models.mistral.modeling_mistral.MistralMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.mistral.modeling_mistral.MistralRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'MistralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralForCausalLM', }) # gemma @@ -160,14 +142,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', }) -# peft -MODULE_MAP.update({ - 'peft.tuners.lora.layer.Linear': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear', - 'peft.tuners.lora.awq.AwqLoraLinear': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.peft.LoRALinear' -}) - # mixtral MODULE_MAP.update({ 'MixtralForCausalLM': @@ -245,8 +219,8 @@ { 'LlavaLlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', - 'llava.model.language_model.llava_mistral.LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.PatchedLlavaLlamaForCausalLM', + 'LlavaMistralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', 'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 'LlavaNextForConditionalGeneration': # noqa: E501 @@ -259,12 +233,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.PatchedInternVLChatModel' }) -# awq -MODULE_MAP.update({ - 'awq.modules.linear.gemm.WQLinear_GEMM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.awq_modules.PatchedWQLinear_GEMM' -}) - # phi3 vision MODULE_MAP.update({ 'modeling_phi3_v.Phi3Attention': @@ -286,27 +254,6 @@ DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP) -# ascend llama -ASCEND_MODULE_MAP.update({ - 'transformers.models.llama.modeling_llama.LlamaFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttentionAscend', - 'transformers.models.llama.modeling_llama.LlamaSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttentionAscend', - 'transformers.models.llama.modeling_llama.LlamaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttentionAscend', - # support modeling rewritten in lmdeploy - 'modeling_llama.LlamaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaAttentionAscend', -}) - -# ascend internlm2 -ASCEND_MODULE_MAP.update({ - 'modeling_internlm2.InternLM2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', - 'modeling_internlm2.InternLM2FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', -}) - # phi-3.5-moe MODULE_MAP.update({ 'modeling_phimoe.PhiMoEAttention': From 8550f041585e7d0255f8cc9714aaa7b8af9cc192 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 30 Aug 2024 16:18:15 +0800 Subject: [PATCH 52/89] add gemma/gemma2 --- lmdeploy/pytorch/models/gemma.py | 627 ++++++++++++++++++-------- lmdeploy/pytorch/models/module_map.py | 28 +- 2 files changed, 439 insertions(+), 216 deletions(-) diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index c7bbbf1af..3c02e8c3c 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -1,241 +1,484 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +import math +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast - -from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) - - -class PatchedGemmaRMSNorm(nn.Module): - """Rewrite RMSNorm.""" - - def forward(self, x): - """forward.""" - # torch.nn.functional.normalize based implementation might leads - # to wrong output - from ..kernels import rms_norm - ret = rms_norm(x.contiguous(), self.weight + 1, self.eps) - - return ret - - -def _make_inv_freq(self, device: torch.device): - if self.inv_freq is None: - self.inv_freq = 1.0 / (self.base**(torch.arange( - 0, self.dim, 2, dtype=torch.int64, device=device).float() / - self.dim)) - - -class PatchedGemmaAttention(nn.Module): +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class GemmaAttention(nn.Module): + """Rewrite module of GemmaAttention.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = config.head_dim + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - rowwise_parallelize_linear(self.o_proj, - loader, - rank=rank, - world_size=world_size, - prefix='o_proj') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.scaling = 1 / math.sqrt(config.head_dim) + if hasattr(config, 'query_pre_attn_scalar'): + self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = (getattr(config, 'sliding_window', -1) + if not bool(layer_idx % 2) else -1) + logit_softcapping = getattr(config, 'attn_logit_softcapping', None) + self.attn_fwd = Attention(num_heads, + head_dim, + scale=self.scaling, + num_kv_heads=num_key_value_heads, + sliding_window=self.sliding_window, + logit_softcapping=logit_softcapping) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - attention_mask: Optional[torch.Tensor] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - scaling_factor = 1.0 - _make_inv_freq(self.rotary_emb, query_states.device) - inv_freq = self.rotary_emb.inv_freq - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - context.position_ids_1d[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None]) - return query_states[0], key_states[0], value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - logit_softcapping = getattr(self.config, 'attn_logit_softcapping', - None) - window_size = getattr(self, 'sliding_window', None) - sm_scale = getattr(self, 'scaling', None) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - window_size=window_size, - sm_scale=sm_scale, - logit_softcapping=logit_softcapping, + attn_metadata, + inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) + return attn_output + + +class GemmaMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) - return attn_output, None, past_key_value + hidden_activation = config.hidden_activation + if hidden_activation is None: + hidden_activation = 'gelu_pytorch_tanh' + assert hidden_activation == 'gelu_pytorch_tanh' + self.act_fn = nn.GELU(approximate='tanh') + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, -1) + act = self.act_fn(gate) * up + out = self.down_proj(act) + return out + + +class GemmaDecoderLayer(nn.Module): + """llama decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = GemmaAttention(config, + layer_idx, + dtype=dtype, + device=device) + + # builf MLP + self.mlp = GemmaMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + self.model_type = config.model_type + if self.model_type == 'gemma2': + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - attention_mask=attention_mask, - world_size=world_size, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) + # Fully Connected + + if self.model_type == 'gemma2': + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs -class PatchedGemmaModel(nn.Module): - def _continuous_batching_forward( +class GemmaModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + GemmaDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + rope_scaling = getattr(config, 'rope_scaling', None) + emb_type = EmbeddingType.LinearScaling + scaling_factor = 1.0 + if rope_scaling is not None: + rope_type = rope_scaling['rope_type'] + if rope_type == 'linear': + emb_type = EmbeddingType.LinearScaling + if rope_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) + + rope_dim = config.head_dim + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) + + def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - output_attentions = False - use_cache = True - # Attention mask is not necessary in continuous batching - attention_mask = None + ): + """Rewrite of LlamaModel.forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - # This is Gemma only! hidden_states = hidden_states * (self.config.hidden_size**0.5) - for idx, decoder_layer in enumerate(self.layers): - past_key_value = (past_key_values[idx] - if past_key_values is not None else None) - layer_outputs = decoder_layer( + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class GemmaForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = GemmaModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, past_key_values=past_key_values, - hidden_states=None, - attentions=None, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) - def forward( + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, - position_ids, - past_key_values, - inputs_embeds, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + def update_weights(self): + """update weights.""" + self.lm_head.weight = self.model.embed_tokens.weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + norm_layers = [ + '.norm', '.input_layernorm', '.post_attention_layernorm', + 'pre_feedforward_layernorm', 'post_feedforward_layernorm' + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if 'lm_head' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + for weight_name in norm_layers: + if weight_name not in name: + continue + loaded_weight += 1 + break + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 5fa231e03..fe609fd8a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -46,34 +46,14 @@ # gemma MODULE_MAP.update({ - 'transformers.models.gemma.modeling_gemma.GemmaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma.modeling_gemma.GemmaFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma.modeling_gemma.GemmaSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma.modeling_gemma.GemmaModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel', - 'transformers.models.gemma.modeling_gemma.GemmaMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.gemma.modeling_gemma.GemmaRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm', + 'GemmaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM', }) # gemma2 MODULE_MAP.update({ - 'transformers.models.gemma2.modeling_gemma2.Gemma2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma2.modeling_gemma2.Gemma2FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma2.modeling_gemma2.Gemma2SdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention', - 'transformers.models.gemma2.modeling_gemma2.Gemma2Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel', - 'transformers.models.gemma2.modeling_gemma2.Gemma2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm', + 'Gemma2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM', }) # deepseek From 6889cc69fd1c28828ae31285ee7a89424530a93d Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 30 Aug 2024 17:43:09 +0800 Subject: [PATCH 53/89] add deepseek, qwen1 --- lmdeploy/pytorch/models/deepseek.py | 785 ++++++++++++++++++-------- lmdeploy/pytorch/models/mistral.py | 2 +- lmdeploy/pytorch/models/module_map.py | 28 +- lmdeploy/pytorch/models/qwen.py | 602 +++++++++++++------- 4 files changed, 938 insertions(+), 479 deletions(-) diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py index 331c68968..187d081e9 100644 --- a/lmdeploy/pytorch/models/deepseek.py +++ b/lmdeploy/pytorch/models/deepseek.py @@ -1,279 +1,574 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torch.distributed as dist from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +def get_world_rank(): + """get current world size and rank.""" + import torch.distributed as dist + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + +class DeepseekAttention(nn.Module): + """Rewrite module of MistralAttention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) -from lmdeploy.pytorch.kernels.fused_moe import fused_moe - -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + ) -class PatchedDeepseekAttention(nn.Module): + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - if hasattr(self, 'rotary_emb'): - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=query_states, - k_embed=key_states) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + attn_metadata, + inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) + return attn_output + + +class DeepseekMoE(nn.Module): + """Deepseek MoE.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.renormalize = self.top_k > 1 and self.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + self.experts = FusedMoE( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + all_reduce=False, + ) + + self.shared_experts = None + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekMLP( + config=config, + intermediate_size=intermediate_size, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False, + ) + world_size, _ = get_world_rank() + if world_size > 1: + self._all_reduce = True + else: + self._all_reduce = False + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) - return attn_output, None, past_key_value + topk_weights, topk_ids = self.softmax_topk(router_logits) + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + if self.shared_experts is not None: + shared_states = self.shared_experts(hidden_states) + out_states += shared_states + out_states = out_states.reshape(batch_size, sequence_length, -1) + + if self._all_reduce: + dist.all_reduce(out_states) + + return out_states + + +class DeepseekMLP(nn.Module): + """Deepseek mlp.""" + + def __init__(self, + config: Any, + intermediate_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + is_tp: bool = True, + all_reduce: bool = True): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + if intermediate_size is None: + intermediate_size = config.intermediate_size + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=is_tp, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear( + intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=is_tp, + all_reduce=all_reduce, + ) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class DeepseekDecoderLayer(nn.Module): + """llama decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = DeepseekAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = (DeepseekMoE(config, dtype=dtype, device=device) if + (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0) else + DeepseekMLP(config, dtype=dtype, device=device)) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - world_size=world_size, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) -def _div_up(a, b): - """div up.""" - return (a + b - 1) // b + outputs = (hidden_states, residual) + return outputs -class PatchedDeepseekMoE(nn.Module): +class DeepseekModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + DeepseekDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + rope_scaling = getattr(config, 'rope_scaling', None) + emb_type = EmbeddingType.LinearScaling + scaling_factor = 1.0 + if rope_scaling is not None: + rope_type = rope_scaling['type'] + if rope_type == 'linear': + emb_type = EmbeddingType.LinearScaling + if rope_type == 'dynamic': + emb_type = EmbeddingType.DynamicNTKScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + scaling_factor = rope_scaling.get('factor', scaling_factor) + + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class DeepseekForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = DeepseekModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) - def __load_mlp(exp_id, exp): - """load mlp.""" - with loader.prefix_context(f'experts.{exp_id}'): - loader.load_model_weights( - exp, - rank=rank, - world_size=world_size, - device=device, - load_only=True, - ) - - def __drop_mlp(exp_id, exp): - """drop mlp.""" - for name, _ in exp.named_parameters(recurse=True): - loader.pop(f'experts.{exp_id}.{name}') - - num_experts = len(self.experts) - exp_per_rank = _div_up(num_experts, world_size) - first_exp = rank * exp_per_rank - last_exp = min(num_experts, first_exp + exp_per_rank) - for exp_id, exp in enumerate(self.experts): - if first_exp <= exp_id < last_exp: - __load_mlp(exp_id, exp) + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, + params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """load weight experts.""" + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, + loaded_weight, + expert_id=expert_id, + shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + num_experts = self.config.n_routed_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.gate_proj.weight', exp_id, + 'gate') + up_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') + down_param = ('.experts.down_weights', + f'.experts.{exp_id}.down_proj.weight', exp_id, + 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + if '.experts' in name: + self._load_weight_experts( + name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) else: - __drop_mlp(exp_id, exp) - self.experts = self.experts[first_exp:last_exp] - with loader.prefix_context('gate'): - loader.load_model_weights(self.gate, - rank=rank, - world_size=world_size, - device=device) - - if self.config.n_shared_experts is not None: - with loader.prefix_context('shared_experts'): - loader.load_model_weights(self.shared_experts, - rank=rank, - world_size=world_size, - device=device) - - def _update_model_fn(self): - """update model.""" - num_experts = len(self.experts) - - def __get_meta(): - exp = self.experts[0] - ffn_dim = exp.gate_proj.weight.size(0) - hidden_dim = exp.down_proj.weight.size(0) - dtype = exp.gate_proj.weight.dtype - device = exp.gate_proj.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - - for exp_id, exp in enumerate(self.experts): - __copy_assign_param(exp.gate_proj.weight, - gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(exp.up_proj.weight, gate_up_weights[exp_id, - ffn_dim:]) - __copy_assign_param(exp.down_proj.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) - - def forward(self, hidden_states): - """forward.""" - world_size = 1 - rank = 0 - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - exp_per_rank = self.gate_up_weights.size(0) - expert_offset = rank * exp_per_rank - - identity = hidden_states - orig_shape = hidden_states.shape - topk_idx, topk_weight, _ = self.gate(hidden_states) - hidden_states = hidden_states.flatten(0, 1) - flat_topk_idx = topk_idx.flatten() - y = fused_moe(hidden_states, - self.gate_up_weights, - self.down_weights, - topk_weight, - flat_topk_idx, - topk=self.num_experts_per_tok, - expert_offset=expert_offset, - num_experts=world_size * exp_per_rank, - renormalize=False).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts.forward(identity) - if dist.is_initialized(): - dist.all_reduce(y) - return y + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index db06d7d29..03240b7b9 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -142,7 +142,7 @@ def forward(self, x): class MistralDecoderLayer(nn.Module): - """llama decoder layer.""" + """decoder layer.""" def __init__(self, config: PretrainedConfig, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index fe609fd8a..67a3fcfb5 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -58,34 +58,14 @@ # deepseek MODULE_MAP.update({ - 'modeling_deepseek.DeepseekAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention', - 'modeling_deepseek.DeepseekFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention', - 'modeling_deepseek.DeepseekSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekAttention', - 'modeling_deepseek.DeepseekModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'modeling_deepseek.DeepseekMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'modeling_deepseek.DeepseekRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', - 'modeling_deepseek.DeepseekMoE': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.PatchedDeepseekMoE', + 'DeepseekForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.DeepseekForCausalLM', }) # qwen MODULE_MAP.update({ - 'modeling_qwen.QWenAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.PatchedQWenAttention', - 'modeling_qwen.FlashSelfAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.PatchedQWenAttention', - 'modeling_qwen.QWenModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.PatchedQWenModel', - 'modeling_qwen.QWenMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.PatchedQWenMLP', - 'modeling_qwen.RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.PatchedRMSNorm', + 'QWenLMHeadModel': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.QWenLMHeadModel', }) # qwen1.5 diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index 631bd4d78..2357251fb 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -1,254 +1,438 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast - -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - colwise_split_parallelize_linear, - rowwise_parallelize_linear) - - -class PatchedQWenAttention(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - for mod_name in ['c_attn']: - w_pack_out = self.c_attn.out_features - sections = [w_pack_out // 3] * 3 - colwise_split_parallelize_linear(getattr(self, mod_name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['c_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class QWenAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] and returns output of + the same size. + """ + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + + self.hidden_size = config.hidden_size + self.split_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.projection_size = config.kv_channels * config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_kv_heads = self.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.c_attn = build_qkv_proj( + config.hidden_size, + num_q_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) - def _contiguous_batching_forward_impl( + # apply rotary + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + self.num_attention_heads, + self.head_dim, + num_kv_heads=self.num_kv_heads, + ) + + # o_proj + self.c_proj = build_rowwise_linear(self.projection_size, + config.hidden_size, + bias=not config.no_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( self, hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of QWenAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - block_offsets = context.block_offsets - position_ids = context.position_ids - position_ids_1d = context.position_ids_1d - max_kv_seq_length = context.max_kv_seq_length - max_q_seq_length = context.max_q_seq_length - kv_seq_length = context.kv_seq_length - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - - def __qkv_proj(hidden_states): - """qkv_proj.""" - qkv_states = self.c_attn(hidden_states) - b, seq_len, _ = qkv_states.size() - query_states, key_states, value_states = qkv_states.chunk(3, dim=2) - num_heads = self.num_heads // world_size - query_states = query_states.view(b, seq_len, num_heads, - self.head_dim) - key_states = key_states.view(b, seq_len, num_heads, self.head_dim) - value_states = value_states.view(b, seq_len, num_heads, - self.head_dim) - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - assert len(rotary_pos_emb_list) == 1, 'do not support dynamic ntk' - cos, sin = rotary_pos_emb_list[0] - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids=position_ids, - position_ids_1d=position_ids_1d) - - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - if rotary_pos_emb_list is not None: - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - if max_kv_seq_length > self.seq_length and self.use_logn_attn: - if self.logn_tensor.device != query_states.device or \ - self.logn_tensor.dtype != query_states.dtype: - self.logn_tensor = self.logn_tensor.to( - query_states.device).type_as(query_states) - logn_tensor = self.logn_tensor[:, position_ids_1d, :, :] - query_states = query_states * logn_tensor.expand_as(query_states) - - query_states = query_states.flatten(0, 1) - key_states = key_states.flatten(0, 1) - value_states = value_states.flatten(0, 1) - - fill_kv_cache(key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets) - - attn_output = query_states - - paged_attention_fwd( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.c_attn(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + (query_states, key_states, + value_states) = self.c_attn.split_qkv(qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + attn_metadata, + inplace=True, ) - attn_output = attn_output.flatten(1, 2) - attn_output = self.c_proj(attn_output) attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) - return attn_output, None, past_key_value + + # o proj + attn_output = self.c_proj(attn_output) + return attn_output + + +class QWenMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + ff_dim_in = config.intermediate_size // 2 + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [ff_dim_in, ff_dim_in], + bias=not config.no_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.c_proj = build_rowwise_linear(ff_dim_in, + config.hidden_size, + bias=not config.no_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.c_proj(act) + + +class QWenBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an output of + the same size. + """ + + def __init__(self, + config: PretrainedConfig, + layer_number: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_number = layer_number + hidden_size = config.hidden_size + self.bf16 = config.bf16 + + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.attn = QWenAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = QWenMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.ln_1 = RMSNorm(hidden_size, + config.layer_norm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.ln_2 = RMSNorm(hidden_size, + config.layer_norm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - past_key_value=layer_past, - rotary_pos_emb_list=rotary_pos_emb_list, - world_size=world_size, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + layernorm_output = self.ln_1(hidden_states) + else: + layernorm_output, residual = self.ln_1(hidden_states, residual) + + # Self Attention + layernorm_input = self.attn( + hidden_states=layernorm_output, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) + # Fully Connected + layernorm_output, residual = self.ln_2(layernorm_input, residual) + mlp_output = self.mlp(layernorm_output) -class PatchedQWenMLP(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['w1', 'w2']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['c_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) + outputs = (mlp_output, residual) return outputs -class PatchedQWenModel(nn.Module): +class QWenModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + self.wte = nn.Embedding(self.vocab_size, + self.embed_dim, + dtype=dtype, + device=device) + + # build all decode layers + self.h = nn.ModuleList([ + QWenBlock(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + if config.rotary_pct == 1.0: + self.rotary_ndims = None + else: + assert config.rotary_pct < 1 + self.rotary_ndims = int(config.kv_channels * config.rotary_pct) + rope_dim = (self.rotary_ndims + if self.rotary_ndims is not None else config.kv_channels) + rope_max_pos_emb = getattr(config, 'max_position_embeddings', 4096) + rope_base = config.rotary_emb_base + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + self.ln_f = RMSNorm(self.embed_dim, + eps=config.layer_norm_epsilon, + quant_config=quantization_config, + dtype=dtype, + device=device) - def _continuous_batching_forward( + def forward( self, input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" + ): + """forward.""" + + # token embedding if inputs_embeds is None: inputs_embeds = self.wte(input_ids) - # Attention mask is not necessary in continuous batching hidden_states = inputs_embeds - context = self.context.context - max_kv_seq_length = context.max_kv_seq_length - # do not support use_dynamic_ntk - ntk_alpha_list = [1.0] - self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list - rotary_pos_emb_list = [ - self.rotary_emb(max_kv_seq_length, ntk_alpha=ntk_alpha) - for ntk_alpha in ntk_alpha_list - ] - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - outputs = block( + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.h): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( hidden_states, - layer_past=layer_past, - rotary_pos_emb_list=rotary_pos_emb_list, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = outputs[0] - hidden_states = self.ln_f(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - ) + # norm + hidden_states, residual = self.ln_f(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.wte + + +class QWenLMHeadModel(nn.Module): + """rewrote model.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'gate_up_proj': [ + 'w2', + 'w1', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build Model + self.transformer = QWenModel(config, dtype=dtype, device=device) + + # output_layers + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.transformer( + input_ids=input_ids, + position_ids=position_ids, past_key_values=past_key_values, + attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, ) + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits -class PatchedRMSNorm(nn.Module): - """Rewrite RMSNorm.""" + def get_input_embeddings(self): + """get input embeddings.""" + return self.transformer.get_input_embeddings() - def forward(self, hidden_states): - """forward.""" - # torch.nn.functional.normalize based implementation might leads - # to wrong output - from ..kernels import rms_norm - ret = rms_norm(hidden_states, self.weight, self.eps) + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.w2', 0), + ('.gate_up_proj', '.w1', 1), + ] + + config = self.config + num_attention_heads = config.num_attention_heads + projection_size = config.kv_channels * num_attention_heads + num_kv_heads = num_attention_heads + head_size = (projection_size // num_attention_heads) + qkv_section = [ + head_size * num_attention_heads, head_size * num_kv_heads, + head_size * num_kv_heads + ] - return ret + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_pos_emb.inv_freq' in name: + continue + if ('rotary_pos_emb.cos_cached' in name + or 'rotary_pos_emb.sin_cached' in name): + continue + if (self.config.tie_word_embeddings and 'lm_head.weight' in name): + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.c_attn' in name: + q, k, v = loaded_weight.split(qkv_section) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) From 481182d5a9f5d5c596f83279e1e6649744ac13a4 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 30 Aug 2024 18:16:47 +0800 Subject: [PATCH 54/89] remove request timeout --- lmdeploy/pytorch/engine/request.py | 28 +++++------- lmdeploy/pytorch/models/module_map.py | 65 ++++++++++++++------------- 2 files changed, 43 insertions(+), 50 deletions(-) diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 1a1c0d424..18bd2193d 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -281,7 +281,7 @@ def send_async(self, req_type: RequestType, data: Any) -> int: """send request asynchronize.""" return self.batched_send_async(req_types=[req_type], data=[data])[0] - async def async_recv_any(self, que_timeout: float = None) -> Response: + async def async_recv_any(self) -> Response: """receive any response.""" self._prefetch_resps() for req_id in self.resp_dict: @@ -290,9 +290,9 @@ async def async_recv_any(self, que_timeout: float = None) -> Response: return ret return await self._async_resp_get() - def recv_any(self, que_timeout: float = None) -> Response: + def recv_any(self) -> Response: """receive any response.""" - coro = self.async_recv_any(que_timeout) + coro = self.async_recv_any() return self.run_until_complete(coro) def recv_all(self, req_id: int, block: bool = True): @@ -301,9 +301,7 @@ def recv_all(self, req_id: int, block: bool = True): resps = self.resp_dict.pop(req_id, []) return resps - async def async_recv(self, - req_id: int, - que_timeout: float = None) -> Response: + async def async_recv(self, req_id: int) -> Response: """receive response of given request id async.""" ret = self._pop_resp(req_id, default=None) if ret is not None: @@ -317,13 +315,13 @@ async def async_recv(self, else: return resp - def recv(self, req_id: int, que_timeout: float = None) -> Response: + def recv(self, req_id: int) -> Response: """receive response of given request id. Different behavior in threadsafe mode. """ if not self.is_thread_safe(): - coro = self.async_recv(req_id, que_timeout) + coro = self.async_recv(req_id) return self.run_until_complete(coro) ret = self._pop_resp(req_id, default=None) @@ -338,21 +336,15 @@ def recv(self, req_id: int, que_timeout: float = None) -> Response: else: return resp - async def async_send(self, - req_type: RequestType, - data: Any, - que_timeout: float = None): + async def async_send(self, req_type: RequestType, data: Any): """send and receive synchronize.""" req_id = await self.async_send_async(req_type, data) - return await self.async_recv(req_id, que_timeout=que_timeout) + return await self.async_recv(req_id) - def send(self, - req_type: RequestType, - data: Any, - que_timeout: float = None) -> Response: + def send(self, req_type: RequestType, data: Any) -> Response: """send and receive synchronize.""" req_id = self.send_async(req_type, data) - return self.recv(req_id, que_timeout=que_timeout) + return self.recv(req_id) def response_callback(self, resp: Response): """response callback.""" diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 67a3fcfb5..385d0eba5 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -2,11 +2,17 @@ LMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models' +# ascend module +MODULE_MAP = dict() +ASCEND_MODULE_MAP = dict() + +DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP) + # llama -MODULE_MAP = { +MODULE_MAP.update({ 'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', -} +}) # Falcon Models in transformer / on hub MODULE_MAP.update({ @@ -44,6 +50,12 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralForCausalLM', }) +# mixtral +MODULE_MAP.update({ + 'MixtralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM', +}) + # gemma MODULE_MAP.update({ 'GemmaForCausalLM': @@ -62,6 +74,25 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.DeepseekForCausalLM', }) +# deepseek-v2 +MODULE_MAP.update({ + 'DeepseekV2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' +}) + +# llava +MODULE_MAP.update( + { + 'LlavaLlamaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', + 'LlavaMistralForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', + 'LlavaForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 + 'LlavaNextForConditionalGeneration': # noqa: E501 + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' + }) + # qwen MODULE_MAP.update({ 'QWenLMHeadModel': @@ -102,12 +133,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', }) -# mixtral -MODULE_MAP.update({ - 'MixtralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM', -}) - # dbrx MODULE_MAP.update({ 'modeling_dbrx.DbrxAttention': @@ -154,12 +179,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', }) -# deepseek-v2 -MODULE_MAP.update({ - 'DeepseekV2ForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' -}) - # cogvlm MODULE_MAP.update({ 'modeling_cogvlm.RMSNorm': @@ -174,19 +193,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedCogVLMModel', }) -# llava -MODULE_MAP.update( - { - 'LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', - 'LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', - 'LlavaForConditionalGeneration': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 - 'LlavaNextForConditionalGeneration': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' - }) - # internvl MODULE_MAP.update({ 'modeling_internvl_chat.InternVLChatModel': @@ -209,11 +215,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', }) -# ascend module -ASCEND_MODULE_MAP = dict() - -DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP) - # phi-3.5-moe MODULE_MAP.update({ 'modeling_phimoe.PhiMoEAttention': From 71d55a43551a5ee850f6f87beed18ff09e2f4306 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 2 Sep 2024 11:11:44 +0800 Subject: [PATCH 55/89] add qwen2, qwen-moe --- lmdeploy/pytorch/models/module_map.py | 30 +- lmdeploy/pytorch/models/qwen2.py | 505 ++++++++++++++++----- lmdeploy/pytorch/models/qwen2_moe.py | 629 ++++++++++++++++++++++---- 3 files changed, 924 insertions(+), 240 deletions(-) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 385d0eba5..0d3baaae0 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -101,36 +101,14 @@ # qwen1.5 MODULE_MAP.update({ - 'transformers.models.qwen2.modeling_qwen2.Qwen2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2.modeling_qwen2.Qwen2Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'transformers.models.qwen2.modeling_qwen2.Qwen2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'Qwen2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.Qwen2ForCausalLM', }) # qwen2 moe MODULE_MAP.update({ - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2Attention', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.PatchedQwen2MoeModel', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.PatchedQwen2MoeSparseMoeBlock', - 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeRMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'Qwen2MoeForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.Qwen2MoeForCausalLM', }) # dbrx diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 02185c703..2001c365d 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -1,148 +1,417 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn +from transformers.configuration_utils import PretrainedConfig -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class PatchedQwen2Attention(nn.Module): +class Qwen2Attention(nn.Module): + """Rewrite module of Qwen2Attention.""" - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() - def _contiguous_batching_forward_impl( + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - if hasattr(self, 'rotary_emb'): - cos, sin = self.rotary_emb(value_states, - seq_len=max_kv_seq_length) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids, - context.position_ids_1d) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + + # attention + attn_output = self.attn_fwd( + query_states, key_states, value_states, past_key_value[0], past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + attn_metadata, + inplace=True, ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) - attn_output = query_states + # o proj + attn_output = self.o_proj(attn_output) + return attn_output - use_sliding_windows = (getattr(self.config, 'sliding_window', None) - is not None and self.config.use_sliding_window) - window_size = self.config.sliding_window - if not use_sliding_windows: - window_size = -1 - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - window_size=window_size, + +class Qwen2MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + # silu and mul + self.act_fn = SiluAndMul(inplace=True) - attn_output = self.o_proj(attn_output) + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class Qwen2DecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = Qwen2Attention(config, dtype=dtype, device=device) - return attn_output, None, past_key_value + # builf MLP + self.mlp = Qwen2MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class Qwen2Model(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Qwen2DecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class Qwen2ForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Qwen2Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - world_size=world_size, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index 5db3b68d5..c46104f5e 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -1,141 +1,578 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.distributed as dist import torch.nn.functional as F -from torch import distributed as dist from torch import nn +from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class PatchedQwen2MoeSparseMoeBlock(nn.Module): +def get_world_rank(): + """get current world size and rank.""" + import torch.distributed as dist + world_size = 1 + rank = 0 - def _update_model_fn(self): - """update model.""" - num_experts = len(self.experts) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() - def __get_meta(): - exp = self.experts[0] - ffn_dim = exp.gate_proj.weight.size(0) - hidden_dim = exp.down_proj.weight.size(0) - dtype = exp.gate_proj.weight.dtype - device = exp.gate_proj.weight.device - return ffn_dim, hidden_dim, dtype, device + return world_size, rank - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - ffn_dim, hidden_dim, dtype, device = __get_meta() +class Qwen2MoeAttention(nn.Module): + """Rewrite module of Qwen2MoeAttention.""" - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - for exp_id, exp in enumerate(self.experts): - __copy_assign_param(exp.gate_proj.weight, - gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(exp.up_proj.weight, gate_up_weights[exp_id, - ffn_dim:]) - __copy_assign_param(exp.down_proj.weight, down_weights[exp_id]) + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) - torch.cuda.empty_cache() + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen2MoeMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + intermediate_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + is_tp: bool = True, + all_reduce: bool = True): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + if intermediate_size is None: + intermediate_size = config.intermediate_size + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=is_tp, + ) - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) + # silu and mul + self.act_fn = SiluAndMul(inplace=True) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """moe forward.""" + # down + self.down_proj = build_rowwise_linear(intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=is_tp, + all_reduce=all_reduce) - _, sequence_length, hidden_dim = hidden_states.shape + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class Qwen2MoeSparseMoeBlock(nn.Module): + """moe block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.renormalize = self.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + self.experts = FusedMoE( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + all_reduce=False, + ) + + intermediate_size = config.shared_expert_intermediate_size + self.shared_expert = Qwen2MoeMLP( + config=config, + intermediate_size=intermediate_size, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False, + ) + self.shared_expert_gate = build_rowwise_linear(config.hidden_size, + 1, + bias=False, + dtype=dtype, + device=device, + all_reduce=False) + world_size, _ = get_world_rank() + if world_size > 1: + self._all_reduce = True + else: + self._all_reduce = False + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1, - sorted=False) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - out_states = fused_moe(hidden_states, - self.gate_up_weights, - self.down_weights, - routing_weights, - selected_experts, - topk=self.top_k, - renormalize=False) - - # all reduce of shared_expert is not necessary - shared_expert_output = self.shared_expert.forward(hidden_states) - shared_expert_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_expert_output - - out_states = out_states + shared_expert_output - out_states = out_states.unflatten(0, (-1, sequence_length)) - - if dist.is_initialized(): + topk_weights, topk_ids = self.softmax_topk(router_logits) + + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + shared_states = self.shared_expert(hidden_states) + shared_states = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_states + out_states += shared_states + out_states = out_states.reshape(batch_size, sequence_length, -1) + + if self._all_reduce: dist.all_reduce(out_states) - return out_states, router_logits + return out_states + + +class Qwen2MoeDecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = Qwen2MoeAttention(config, dtype=dtype, device=device) + + # builf MLP + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0) and ((layer_idx + 1) % + config.decoder_sparse_step == 0): + self.mlp = Qwen2MoeSparseMoeBlock(config, + layer_idx=layer_idx, + dtype=dtype, + device=device) + else: + self.mlp = Qwen2MoeMLP(config, + intermediate_size=config.intermediate_size, + dtype=dtype, + device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + +class Qwen2MoeModel(nn.Module): + """model.""" -class PatchedQwen2MoeModel(nn.Module): + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Qwen2MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) - def _continuous_batching_forward( - self, - input_ids: torch.LongTensor = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None): - """Rewrite implementation of Qwen2MoeModel.forward.""" - from transformers.modeling_outputs import MoeModelOutputWithPast + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens - return MoeModelOutputWithPast(last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - router_logits=None) + +class Qwen2MoeForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Qwen2MoeModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, - input_ids: torch.LongTensor = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, ): - """Rewrite of Qwen2MoeModel.forward.""" - return self._continuous_batching_forward( - input_ids, - position_ids, - past_key_values, - inputs_embeds, + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, + params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """load weight experts.""" + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, + loaded_weight, + expert_id=expert_id, + shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert map + num_experts = self.config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.gate_proj.weight', exp_id, + 'gate') + up_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.up_proj.weight', exp_id, 'up') + down_param = ('.experts.down_weights', + f'.experts.{exp_id}.down_proj.weight', exp_id, + 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + if '.experts' in name: + self._load_weight_experts( + name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) From 90a4a63a7387f0e0ae495fb24ca97142a06216d9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 2 Sep 2024 15:54:48 +0800 Subject: [PATCH 56/89] add starcoder2 phi-3 phi-3 vision --- .../backends/default/rotary_embedding.py | 66 +- lmdeploy/pytorch/backends/rotary_embedding.py | 12 + lmdeploy/pytorch/check_env/__init__.py | 2 +- lmdeploy/pytorch/config.py | 4 + .../pytorch/configurations/phi3_vision.py | 18 - lmdeploy/pytorch/models/module_map.py | 40 +- lmdeploy/pytorch/models/phi3.py | 611 +++++++++++------- lmdeploy/pytorch/models/starcoder2.py | 538 ++++++++++----- lmdeploy/pytorch/nn/rotary_embedding.py | 6 +- 9 files changed, 869 insertions(+), 428 deletions(-) delete mode 100644 lmdeploy/pytorch/configurations/phi3_vision.py diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index fa151ea87..0abd6e195 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -4,8 +4,9 @@ import torch from torch import nn -from ..rotary_embedding import (EmbeddingType, RotaryEmbeddingBuilder, - RotaryEmbeddingImpl, YarnParameters) +from ..rotary_embedding import (EmbeddingType, LongRoPEScalingParameters, + RotaryEmbeddingBuilder, RotaryEmbeddingImpl, + YarnParameters) def _rotary_embedding_fwd(position_ids: torch.Tensor, @@ -247,6 +248,59 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): device_type=device_type) +class LongRoPEScalingRotaryEmbeddingImpl(RotaryEmbeddingImpl): + """yarn rotary embedding implementation.""" + + def __init__( + self, + dim: int, + base: int = 10000, + max_position_embeddings: int = 4096, + longrope_params: LongRoPEScalingParameters = None, + ): + super().__init__(dim, base) + short_factor = torch.tensor(longrope_params.short_factor, + dtype=torch.float32) + long_factor = torch.tensor(longrope_params.long_factor, + dtype=torch.float32) + self.register_buffer('short_factor', short_factor, persistent=False) + self.register_buffer('long_factor', long_factor, persistent=False) + self.original_max_position_embeddings = \ + longrope_params.original_max_position_embeddings + scale = (max_position_embeddings / + self.original_max_position_embeddings) + if scale <= 1.0: + self.mscale = 1.0 + else: + self.mscale = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + """rope forward.""" + dtype = x.dtype + device = position_ids.device + if self.short_factor.device != device: + self.register_buffer('short_factor', + self.short_factor.to(device), + persistent=False) + self.register_buffer('long_factor', + self.long_factor.to(device), + persistent=False) + + max_pos_ids = position_ids.max() + 1 + ext_factors = torch.where( + max_pos_ids > self.original_max_position_embeddings, + self.long_factor, self.short_factor) + inv_freq = self.inv_freq * (1.0 / ext_factors) + return _rotary_embedding_fwd(position_ids, + inv_freq, + scaling_factor=1.0, + mscale=self.mscale, + dtype=dtype, + device_type=device) + + class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder): """rotary embedding builder.""" @@ -259,6 +313,7 @@ def build( low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" @@ -277,6 +332,13 @@ def build( scaling_factor, max_position_embeddings, yarn_params=yarn_params) + elif emb_type == EmbeddingType.LongRoPEScaling: + return LongRoPEScalingRotaryEmbeddingImpl( + dim, + base, + max_position_embeddings=max_position_embeddings, + longrope_params=longrope_params, + ) else: raise NotImplementedError( f'Unsupported embedding type: {emb_type}') diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index 2dfae1e78..bb4918e2c 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum, auto +from typing import List class EmbeddingType(Enum): @@ -11,6 +12,7 @@ class EmbeddingType(Enum): DynamicNTKScaling = auto() Llama3 = auto() Yarn = auto() + LongRoPEScaling = auto() @dataclass @@ -22,6 +24,14 @@ class YarnParameters: mscale_all_dim: int = 0 +@dataclass +class LongRoPEScalingParameters: + """Long Ropescaling parameters.""" + short_factor: List[int] + long_factor: List[int] + original_max_position_embeddings: int + + class RotaryEmbeddingImpl(ABC): """rotary embedding implementation api.""" @@ -43,6 +53,8 @@ def build( scaling_factor: float = 1.0, low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, + yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7675956ab..3f7e3d840 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -98,7 +98,7 @@ def check_env(device_type: str): MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.41.2' +MAX_TRANSFORMERS_VERSION = '4.44.1' def check_awq(hf_config): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 38384c0e6..dc23bcd6a 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -111,6 +111,10 @@ def from_pretrained(cls, from transformers import AutoConfig hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + if getattr(hf_config, 'model_type', None) == 'phi3': + # phi3 + trust_remote_code leads to error when tp. + hf_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path) return cls.from_hf_config(hf_config, pretrained_model_name_or_path) @classmethod diff --git a/lmdeploy/pytorch/configurations/phi3_vision.py b/lmdeploy/pytorch/configurations/phi3_vision.py deleted file mode 100644 index 21f76f21a..000000000 --- a/lmdeploy/pytorch/configurations/phi3_vision.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .builder import AutoModelConfigBuilder -from .default import DefaultModelConfigBuilder - - -class Phi3VisionModelConfigBuilder(AutoModelConfigBuilder): - - @classmethod - def condition(cls, hf_config): - """config.""" - return hf_config.model_type == 'phi3_v' - - @classmethod - def build(cls, hf_config, model_path: str = None): - """build.""" - cfg = DefaultModelConfigBuilder.build(hf_config) - # cfg.unused_modules = ['model.vision_embed_tokens'] - return cfg diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 0d3baaae0..d62371595 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -129,32 +129,14 @@ # starcoder2 MODULE_MAP.update({ - 'transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.PatchedStarcoder2Attention', - 'transformers.models.starcoder2.modeling_starcoder2.Starcoder2FlashAttention2': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.PatchedStarcoder2Attention', - 'transformers.models.starcoder2.modeling_starcoder2.Starcoder2SdpaAttention': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.PatchedStarcoder2Attention', - 'transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaModel', - 'transformers.models.starcoder2.modeling_starcoder2.Starcoder2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.PatchedStarcoder2MLP', + 'Starcoder2ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM', }) # phi-3 MODULE_MAP.update({ - 'modeling_phi3.Phi3Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3.Phi3FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3.Phi3SdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3.Phi3Model': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Model', - 'modeling_phi3.Phi3MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3MLP', - 'modeling_phi3.Phi3RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'Phi3ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3ForCausalLM', }) # cogvlm @@ -179,18 +161,8 @@ # phi3 vision MODULE_MAP.update({ - 'modeling_phi3_v.Phi3Attention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3_v.Phi3FlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3_v.Phi3SdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Attention', - 'modeling_phi3_v.Phi3VModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3Model', - 'modeling_phi3_v.Phi3MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.PatchedPhi3MLP', - 'modeling_phi3_v.Phi3RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', + 'Phi3VForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM', }) # phi-3.5-moe diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index 5a4329d69..56110908e 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -1,267 +1,446 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import BaseModelOutputWithPast - -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, - rowwise_parallelize_linear) - +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, + build_rotary_embedding) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class Phi3Attention(nn.Module): + """Rewrite module of Phi3Attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) -class PatchedPhi3Attention(nn.Module): + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + sliding_window = getattr(config, 'sliding_window', None) + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=sliding_window, + ) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - sections = [ - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - ] - for mod_name in ['qkv_proj']: - colwise_split_parallelize_linear(getattr(self, mod_name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of LlamaAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - position_ids_1d = context.position_ids_1d - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - qkv_states = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split( - (num_heads * head_dim, num_kv_heads * head_dim, - num_kv_heads * head_dim), - dim=-1, - ) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb( - value_states, - position_ids=position_ids_1d[None, :], - seq_len=max_kv_seq_length) - context._cos = cos - context._sin = sin - cos = context._cos - sin = context._sin - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos[0], - sin[0], - position_ids, - torch.arange(0, - len(position_ids_1d), - device=query_states.device), - q_embed=query_states, - k_embed=key_states) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - # inplace rotary - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, - window_size=self.config.sliding_window, + attn_metadata, + inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output + + +class Phi3MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class Phi3DecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = Phi3Attention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = Phi3MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - world_size=world_size, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) -class PatchedPhi3MLP(nn.Module): - - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['gate_up_proj']: - out_size = self.gate_up_proj.out_features - sections = [out_size // 2] * 2 - colwise_split_parallelize_linear(getattr(self, mod_name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['down_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) + outputs = (hidden_states, residual) return outputs -class PatchedPhi3Model(nn.Module): +class Phi3Model(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Phi3DecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + rope_scaling = config.rope_scaling + if rope_scaling is not None: + scaling_type = rope_scaling['type'] + assert scaling_type in ['longrope', 'su'] + emb_type = EmbeddingType.LongRoPEScaling + ori_pos_emb = getattr(config, 'original_max_position_embeddings', + rope_max_pos_emb) + longrope_params = LongRoPEScalingParameters( + short_factor=rope_scaling['short_factor'], + long_factor=rope_scaling['long_factor'], + original_max_position_embeddings=ori_pos_emb) + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + longrope_params=longrope_params, + emb_type=emb_type, + ) + else: + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) - def _continuous_batching_forward( + def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - output_attentions = True - use_cache = True - context = self.context.context - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + ): + """Rewrite of LlamaModel.forward.""" + # token embedding if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) - - # Attention mask is not necessary in continuous batching - attention_mask = None hidden_states = inputs_embeds - # decoder layers + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None for idx, decoder_layer in enumerate(self.layers): - past_key_value = (past_key_values[idx] - if past_key_values is not None else None) - layer_outputs = decoder_layer( + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + residual=residual, + attn_metadata=attn_metadata, ) - hidden_states = layer_outputs[0] - - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_value, - hidden_states=None, - attentions=None, - ) + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class Phi3ForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Phi3Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """rewrite of forward.""" - return self._continuous_batching_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + + config = self.config + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + qkv_section = [ + head_dim * num_heads, head_dim * num_key_value_heads, + head_dim * num_key_value_heads + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + if 'vision_embed_tokens' in name: + continue + if '.qkv_proj' in name: + q, k, v = loaded_weight.split(qkv_section) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + elif '.gate_up_proj' in name: + gate, up = loaded_weight.chunk(2) + param = params_dict[name] + load_weight(param, gate, shard_id=0) + load_weight(param, up, shard_id=1) + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + +class Phi3VForCausalLM(Phi3ForCausalLM): + ... diff --git a/lmdeploy/pytorch/models/starcoder2.py b/lmdeploy/pytorch/models/starcoder2.py index 1a2d9d748..ea1a7d526 100644 --- a/lmdeploy/pytorch/models/starcoder2.py +++ b/lmdeploy/pytorch/models/starcoder2.py @@ -1,186 +1,412 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist from torch import nn +from transformers.configuration_utils import PretrainedConfig -from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_parallelize_linear, - rowwise_parallelize_linear) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -class PatchedStarcoder2MLP(nn.Module): +class Starcoder2Attention(nn.Module): + """Rewrite module of Starcoder2Attention.""" - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['c_fc']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['c_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - """rewrite of starcoder2mlp forward.""" - hidden_states = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - return hidden_states + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.use_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) -class PatchedStarcoder2Attention(nn.Module): + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - for mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - for mod_name in ['o_proj']: - rowwise_parallelize_linear(getattr(self, mod_name), - loader, - rank=rank, - world_size=world_size, - prefix=mod_name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( + # attention + sliding_window = getattr(config, 'sliding_window', None) + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.use_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of LlamaAttention.forward. - - Add continuous batching support. Add paged attention support. TP - support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - max_kv_seq_length = context.max_kv_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - hidden_size = num_heads * head_dim - - def __qkv_proj(hidden_states): - """qkv proj.""" - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - if not hasattr(context, '_cos'): - cos, sin = self.rotary_emb(value_states.transpose(0, 1), - seq_len=max_kv_seq_length) - context._cos = cos - context._sin = sin - else: - cos = context._cos - sin = context._sin - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - position_ids, - context.position_ids_1d, - q_embed=query_states, - k_embed=key_states) - return query_states, key_states, value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + attn_metadata, + inplace=True, ) - attn_output = attn_output.reshape(*hidden_states.shape[:-1], - hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.o_proj(attn_output) + return attn_output + - return attn_output, None, past_key_value +class Starcoder2MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.c_fc = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=config.use_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + hidden_act = config.hidden_act + if hidden_act is None: + hidden_act = 'gelu_pytorch_tanh' + assert hidden_act == 'gelu_pytorch_tanh' + self.act_fn = nn.GELU(approximate='tanh') + + # down + self.c_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.c_fc(x) + act = self.act_fn(gate_up) + return self.c_proj(act) + + +class Starcoder2DecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + + # build attention layer + self.self_attn = Starcoder2Attention(config, + dtype=dtype, + device=device) + + # builf MLP + self.mlp = Starcoder2MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + attn_metadata: Any = None, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Starcoder2Model(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Starcoder2DecoderLayer(config, + layer_idx, + dtype=dtype, + device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids, - past_key_value, - output_attentions, - world_size=world_size, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states = self.norm(hidden_states) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class Starcoder2ForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Starcoder2Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def update_weights(self): + """update weights.""" + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index c39eb21a5..9d035c961 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -2,7 +2,9 @@ from torch import Tensor, nn from ..backends import LayerType, get_backend -from ..backends.rotary_embedding import EmbeddingType, YarnParameters +from ..backends.rotary_embedding import (EmbeddingType, + LongRoPEScalingParameters, + YarnParameters) def build_rotary_embedding( @@ -13,6 +15,7 @@ def build_rotary_embedding( low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: """build rotary embedding op.""" backend = get_backend() @@ -25,6 +28,7 @@ def build_rotary_embedding( low_freq_factor=low_freq_factor, high_freq_factor=high_freq_factor, yarn_params=yarn_params, + longrope_params=longrope_params, emb_type=emb_type) From 5f772ab0e1864a2ff117692ca369ebd62ba06bcb Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 2 Sep 2024 19:56:54 +0800 Subject: [PATCH 57/89] support phi3 moe --- .../backends/default/rotary_embedding.py | 38 +- lmdeploy/pytorch/backends/rotary_embedding.py | 2 + lmdeploy/pytorch/config.py | 2 +- lmdeploy/pytorch/models/module_map.py | 16 +- lmdeploy/pytorch/models/phi3_moe.py | 510 +++++++++++++++--- lmdeploy/pytorch/nn/linear.py | 6 + 6 files changed, 469 insertions(+), 105 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 0abd6e195..1003e678d 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -12,7 +12,7 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor, inv_freq: torch.Tensor, scaling_factor: float, - mscale: float = 1.0, + mscale: float = None, dtype: torch.dtype = None, device_type: torch.device = None): """rotary embedding forward.""" @@ -36,7 +36,7 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor, cos = emb.cos() sin = emb.sin() - if mscale != 1.0: + if mscale is not None: cos = cos * mscale sin = sin * mscale @@ -233,6 +233,8 @@ def __init__(self, self.mscale = float( yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)) + if self.mscale == 1.0: + self.mscale = None def forward(self, x: torch.Tensor, position_ids: torch.Tensor): """forward.""" @@ -267,14 +269,18 @@ def __init__( self.register_buffer('long_factor', long_factor, persistent=False) self.original_max_position_embeddings = \ longrope_params.original_max_position_embeddings - scale = (max_position_embeddings / - self.original_max_position_embeddings) - if scale <= 1.0: - self.mscale = 1.0 - else: - self.mscale = math.sqrt( - 1 + math.log(scale) / - math.log(self.original_max_position_embeddings)) + self.mscale = None + self.short_mscale = longrope_params.short_mscale + self.long_mscale = longrope_params.long_mscale + if self.short_mscale is None and self.long_mscale is None: + scale = (max_position_embeddings / + self.original_max_position_embeddings) + if scale <= 1.0: + self.mscale = 1.0 + else: + self.mscale = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) def forward(self, x: torch.Tensor, position_ids: torch.Tensor): """rope forward.""" @@ -289,14 +295,18 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): persistent=False) max_pos_ids = position_ids.max() + 1 - ext_factors = torch.where( - max_pos_ids > self.original_max_position_embeddings, - self.long_factor, self.short_factor) + mask = max_pos_ids > self.original_max_position_embeddings + ext_factors = torch.where(mask, self.long_factor, self.short_factor) + + mscale = self.mscale + if mscale is None: + mscale = torch.where(mask, self.long_mscale, self.short_mscale) + inv_freq = self.inv_freq * (1.0 / ext_factors) return _rotary_embedding_fwd(position_ids, inv_freq, scaling_factor=1.0, - mscale=self.mscale, + mscale=mscale, dtype=dtype, device_type=device) diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index bb4918e2c..ce85b1846 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -30,6 +30,8 @@ class LongRoPEScalingParameters: short_factor: List[int] long_factor: List[int] original_max_position_embeddings: int + long_mscale: float = None + short_mscale: float = None class RotaryEmbeddingImpl(ABC): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index dc23bcd6a..de5aeb0b1 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -111,7 +111,7 @@ def from_pretrained(cls, from transformers import AutoConfig hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - if getattr(hf_config, 'model_type', None) == 'phi3': + if getattr(hf_config, 'model_type', None) in ['phi3']: # phi3 + trust_remote_code leads to error when tp. hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index d62371595..c692e79f3 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -167,20 +167,8 @@ # phi-3.5-moe MODULE_MAP.update({ - 'modeling_phimoe.PhiMoEAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', - 'modeling_phimoe.PhiMoEFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', - 'modeling_phimoe.PhiMoESdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttention', - 'modeling_phimoe.PhiMoEModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralModel', - 'modeling_phimoe.PhiMoEBlockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP', - 'modeling_phimoe.PhiMoEBLockSparseTop2MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralBLockSparseTop2MLP', - 'modeling_phimoe.PhiMoESparseMoeBlock': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_moe.PatchedPhiMoESparseMoeBlock', + 'PhiMoEForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_moe.PhiMoEForCausalLM', }) CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py index 6fb07272d..4b00a6161 100644 --- a/lmdeploy/pytorch/models/phi3_moe.py +++ b/lmdeploy/pytorch/models/phi3_moe.py @@ -1,16 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Iterable, List, Optional, Tuple + import torch -from torch import distributed as dist from torch import nn -from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, EmbeddingType +from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.nn.moe import FusedMoE +from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, + build_rotary_embedding) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -# from https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/blob/482a9ba0eb0e1fa1671e3560e009d7cec2e5147c/modeling_phimoe.py#L883 # noqa: E501 def sparsemixer(scores, top_k, jitter_eps): assert top_k == 2 - final_multipliers = scores.new_empty((scores.shape[0], top_k)) - final_experts = torch.empty_like(final_multipliers) + with torch.no_grad(): # compute mask for sparsity mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) @@ -22,11 +27,11 @@ def sparsemixer(scores, top_k, jitter_eps): masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf')) selected_experts = max_ind - final_experts[:, 0:1] = max_ind # compute scores for gradients masked_gates = torch.softmax(masked_gates, dim=-1) - final_multipliers[:, 0:1] = masked_gates.gather(dim=-1, - index=selected_experts) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + multiplier = multiplier_o + # masked out first expert masked_scores = torch.scatter( scores, @@ -42,85 +47,438 @@ def sparsemixer(scores, top_k, jitter_eps): mask_logits_threshold = ( (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) - final_experts[:, 1:2] = max_ind # apply mask masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf')) selected_experts_top2 = max_ind # compute scores for gradients masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) - final_multipliers[:, 1:2] = masked_gates_top2.gather( - dim=-1, index=selected_experts_top2) - return final_multipliers, final_experts - - -class PatchedPhiMoESparseMoeBlock(nn.Module): - - def _update_model_fn(self): - """update model.""" - num_experts = self.num_experts - - def __get_meta(): - exp = self.experts[0] - ffn_dim = exp.w1.weight.size(0) - hidden_dim = exp.w2.weight.size(0) - dtype = exp.w1.weight.dtype - device = exp.w1.weight.device - return ffn_dim, hidden_dim, dtype, device - - def __copy_assign_param(param, weight): - """copy assign.""" - weight.copy_(param.data) - param.data = weight - - ffn_dim, hidden_dim, dtype, device = __get_meta() - - gate_up_weights = torch.empty(num_experts, - ffn_dim * 2, - hidden_dim, - device=device, - dtype=dtype) - down_weights = torch.empty(num_experts, - hidden_dim, - ffn_dim, - device=device, - dtype=dtype) - for exp_id, exp in enumerate(self.experts): - __copy_assign_param(exp.w1.weight, - gate_up_weights[exp_id, :ffn_dim]) - __copy_assign_param(exp.w3.weight, gate_up_weights[exp_id, - ffn_dim:]) - __copy_assign_param(exp.w2.weight, down_weights[exp_id]) - - torch.cuda.empty_cache() - - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """rewrite moe forward.""" + multiplier_top2_o = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier_top2 = multiplier_top2_o + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class PhiMoEAttention(nn.Module): + """PhiMoE attention.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = None + + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = hidden_size // num_heads + + # qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.sliding_window = getattr(config, 'sliding_window', None) + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + sliding_window=self.sliding_window, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + attn_output = self.o_proj(attn_output) + return attn_output + + +class PhiMoESparseMoeBlock(nn.Module): + """PhiMoE sparse moe block.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.experts = FusedMoE( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=2, + renormalize=False, + dtype=dtype, + device=device, + all_reduce=True, + ) + + self.router_jitter_noise = config.router_jitter_noise + self.input_jitter_noise = config.input_jitter_noise + + def forward(self, hidden_states: torch.Tensor): + """forward.""" batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.input_jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise) hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights, selected_experts = sparsemixer( - router_logits, - top_k=2, - jitter_eps=self.router_jitter_noise, + + topk_weights, topk_ids = sparsemixer( + router_logits, top_k=2, jitter_eps=self.router_jitter_noise) + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, ) - out_states = fused_moe(hidden_states, - self.gate_up_weights, - self.down_weights, - routing_weights, - selected_experts, - topk=2, - renormalize=False) out_states = out_states.reshape(batch_size, sequence_length, -1) return out_states, router_logits + + +class PhiMoEDecoderLayer(nn.Module): + """PhiMoE decoder layer.""" + + def __init__(self, + config: Any, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + + # build attention layer + self.self_attn = PhiMoEAttention(config, dtype=dtype, device=device) + self.block_sparse_moe = PhiMoESparseMoeBlock(config, + dtype=dtype, + device=device) + + # build input layer norm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + attn_metadata: Any = None, + ): + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class PhiMoEModel(nn.Module): + """PhiMoE model.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + self.layers = nn.ModuleList([ + PhiMoEDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + rope_scaling = config.rope_scaling + if rope_scaling is not None: + scaling_type = rope_scaling['type'] + assert scaling_type in ['longrope', 'su'] + emb_type = EmbeddingType.LongRoPEScaling + ori_pos_emb = getattr(config, 'original_max_position_embeddings', + rope_max_pos_emb) + longrope_params = LongRoPEScalingParameters( + short_factor=rope_scaling['short_factor'], + long_factor=rope_scaling['long_factor'], + original_max_position_embeddings=ori_pos_emb, + short_mscale=rope_scaling['short_mscale'], + long_mscale=rope_scaling['long_mscale']) + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + longrope_params=longrope_params, + emb_type=emb_type, + ) + else: + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """forward.""" + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + for idx, decoder_layer in enumerate(self.layers): + + past_key_value = past_key_values[idx] + hidden_states = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + hidden_states = self.norm(hidden_states) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class PhiMoEForCausalLM(nn.Module): + """mixture model for causalLM.""" + + support_cuda_graph = True + + def __init__(self, + config: Any, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + self.model = PhiMoEModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=config.lm_head_bias, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + num_experts = self.config.num_local_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.w1.weight', exp_id, 'gate') + up_param = ('.experts.gate_up_weights', + f'.experts.{exp_id}.w3.weight', exp_id, 'up') + down_param = ('.experts.down_weights', + f'.experts.{exp_id}.w2.weight', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, expert_id, + shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, + loaded_weight, + expert_id=expert_id, + shard_id=shard_id) + break + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index c5ca22185..53f9f717d 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -212,6 +212,8 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, """weight loader for rowwise linear.""" if loaded_weight.dim() == 1: # bias + if rank == 0: + loaded_weight = torch.zeros_like(loaded_weight) return default_weight_loader(param, loaded_weight) if loaded_weight.size(0) == self.in_features: @@ -516,6 +518,8 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, return default_weight_loader(param, weight) else: # bias + if rank == 0: + loaded_weight = torch.zeros_like(loaded_weight) return default_weight_loader(param, loaded_weight) def weight_loader(self, param: torch.nn.Parameter, @@ -762,6 +766,8 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, return default_weight_loader(param, weight) else: # bias + if rank != 0: + loaded_weight = torch.zeros_like(loaded_weight) return default_weight_loader(param, loaded_weight) def weight_loader(self, param: torch.nn.Parameter, From 1d3b27d9b3615bad0c98621cc8ba6146de992d98 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 11:05:03 +0800 Subject: [PATCH 58/89] support dbrx --- lmdeploy/pytorch/models/dbrx.py | 743 +++++++++++++++++--------- lmdeploy/pytorch/models/module_map.py | 14 +- lmdeploy/pytorch/models/peft.py | 280 ---------- 3 files changed, 504 insertions(+), 533 deletions(-) delete mode 100644 lmdeploy/pytorch/models/peft.py diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index 729dfc917..c9b790a6c 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -1,296 +1,557 @@ # Copyright (c) OpenMMLab. All rights reserved. - -from typing import Any, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch -import torch.distributed as dist -import torch.nn as nn -import torch.utils.checkpoint -from transformers.cache_utils import Cache -from transformers.modeling_outputs import MoeModelOutputWithPast - -from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class DbrxAttention(nn.Module): + """Rewrite module of DbrxAttention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + attn_config = config.attn_config + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.d_model + num_heads = config.n_heads + num_key_value_heads = attn_config.kv_n_heads + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.Wqkv = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) -from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, - rowwise_parallelize_linear) + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + ) -class PatchedDbrxAttention(nn.Module): + # o_proj + self.out_proj = build_rowwise_linear(hidden_size, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - sections = [ - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - ] - colwise_split_parallelize_linear(self.Wqkv, - sections, - loader, - rank=rank, - world_size=world_size, - prefix='Wqkv') - rowwise_parallelize_linear(self.out_proj, - loader, - rank=rank, - world_size=world_size, - prefix='out_proj') - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Implement of attention forward.""" - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - - num_heads = self.num_heads // world_size - num_kv_heads = self.num_key_value_heads // world_size - head_dim = self.head_dim - - def __qkv_proj(hidden_states): - """qkv_proj.""" - qkv_states = self.Wqkv(hidden_states) - if self.clip_qkv is not None: - qkv_states = qkv_states.clamp(min=-self.clip_qkv, - max=self.clip_qkv) - - query_states, key_states, value_states = qkv_states.split( - [ - num_heads * head_dim, - num_kv_heads * head_dim, - num_kv_heads * head_dim, - ], - dim=-1, - ) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - scaling_factor = 1.0 - rotary_emb = self.rotary_emb - if rotary_emb.inv_freq is None: - rotary_emb.inv_freq = 1.0 / (rotary_emb.base**(torch.arange( - 0, - rotary_emb.dim, - 2, - dtype=torch.int64, - device=query_states.device).float() / rotary_emb.dim)) - inv_freq = rotary_emb.inv_freq - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - context.position_ids_1d[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None]) - return query_states[0], key_states[0], value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + attn_metadata: Any = None, + ): + """Rewrite of forward.""" + # qkv proj + qkv_states = self.Wqkv(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.Wqkv.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + cos, + sin, + inplace=True, ) - attn_output = query_states - paged_attention_fwd( + # attention + attn_output = self.attn_fwd( query_states, + key_states, + value_states, past_key_value[0], past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + attn_metadata, + inplace=True, ) attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + # o proj attn_output = self.out_proj(attn_output) + return attn_output + + +class DbrxRouter(nn.Module): + """router.""" + + def __init__(self, + hidden_size: int, + moe_num_experts: int, + moe_top_k: int, + moe_normalize_expert_weights: Optional[float], + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + self.layer = build_rowwise_linear( + hidden_size, + moe_num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(moe_top_k) - return attn_output, None, past_key_value + self.moe_normalize_expert_weights = moe_normalize_expert_weights + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + weights = self.layer(hidden_states) + top_weights, top_experts = self.softmax_topk(weights) + + top_weights_scale = (torch.norm(top_weights, + p=self.moe_normalize_expert_weights, + dim=-1, + keepdim=True) + if self.moe_normalize_expert_weights is not None + else 1.0) + top_weights = top_weights / top_weights_scale + + return top_weights, top_experts + + +class DbrxExperts(nn.Module): + """experts.""" + + def __init__(self, + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int, + ffn_act_fn: dict, + moe_top_k: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + act_fn_name = ffn_act_fn.get('name', None) + assert act_fn_name == 'silu' + + self.mlp = FusedMoE( + hidden_size, + ffn_hidden_size, + moe_num_experts, + top_k=moe_top_k, + renormalize=False, + dtype=dtype, + device=device, + all_reduce=True, + ) + + def forward(self, hidden_states: torch.Tensor, top_weights: torch.Tensor, + top_experts: torch.Tensor): + """forward.""" + batch_size = hidden_states.size(0) + hidden_states = hidden_states.flatten(0, 1) + out_states = self.mlp( + hidden_states, + top_weights, + top_experts, + ) + out_states = out_states.unflatten(0, (batch_size, -1)) + + return out_states + + +class DbrxFFN(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + ffn_config = config.ffn_config + self.router = DbrxRouter( + hidden_size=config.d_model, + moe_num_experts=ffn_config.moe_num_experts, + moe_top_k=ffn_config.moe_top_k, + moe_normalize_expert_weights=ffn_config. + moe_normalize_expert_weights, + dtype=dtype, + device=device, + ) + + self.experts = DbrxExperts( + hidden_size=config.d_model, + ffn_hidden_size=ffn_config.ffn_hidden_size, + moe_num_experts=ffn_config.moe_num_experts, + ffn_act_fn=ffn_config.ffn_act_fn, + moe_top_k=ffn_config.moe_top_k, + dtype=dtype, + device=device, + ) + + def forward(self, x): + """forward.""" + top_weights, top_experts = self.router(x) + out = self.experts(x, top_weights, top_experts) + return out + + +class DbrxNormAttentionNorm(nn.Module): + """dbrx norm attention norm.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + + self.norm_1 = nn.LayerNorm( + config.d_model, + bias=False, + dtype=dtype, + device=device, + ) + self.attn = DbrxAttention( + config=config, + dtype=dtype, + device=device, + ) + self.norm_2 = nn.LayerNorm( + config.d_model, + bias=False, + dtype=dtype, + device=device, + ) def forward( self, hidden_states: torch.Tensor, - position_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Any, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): """forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( + residual_states = hidden_states + hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype) + + hidden_states = self.attn( hidden_states, + rotary_pos_emb, past_key_value, - world_size=world_size, + attn_metadata, ) + hidden_states = hidden_states + residual_states -class PatchedDbrxExpertGLU(nn.Module): + residual_states = hidden_states + hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype) + return hidden_states, residual_states - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - def __partition(name, param): - param = loader.pop(name) - param = param.unflatten(0, (self.moe_num_experts, -1)) - param = param.chunk(world_size, 1)[rank] - param = param.to(device) - dtype = param.dtype - if param.dtype != dtype: - param = param.to(dtype) - param = torch.nn.Parameter(param.flatten(0, 1)) - self.register_parameter(name, param) - - __partition('w1', self.w1) - __partition('v1', self.v1) - __partition('w2', self.w2) - - def _update_model_fn(self): - """update model.""" - ffn_hidden_size = self.w1.size(0) // self.moe_num_experts - gate_up_weights = self.w1.new_empty(self.moe_num_experts, - ffn_hidden_size * 2, - self.w1.size(1)) - gate_up_weights[:, :ffn_hidden_size].copy_( - self.w1.unflatten(0, (self.moe_num_experts, -1))) - gate_up_weights[:, ffn_hidden_size:].copy_( - self.v1.unflatten(0, (self.moe_num_experts, -1))) - delattr(self, 'w1') - delattr(self, 'v1') - down_weights = self.w2.data.unflatten( - 0, (self.moe_num_experts, -1)).transpose(1, 2).contiguous() - delattr(self, 'w2') - torch.cuda.empty_cache() - - self.register_buffer('gate_up_weights', gate_up_weights) - self.register_buffer('down_weights', down_weights) - - -class PatchedDbrxExperts(nn.Module): - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs - - def forward(self, x: torch.Tensor, weights: torch.Tensor, - top_weights: torch.Tensor, - top_experts: torch.LongTensor) -> torch.Tensor: - """moe forward.""" - q_len = x.size(1) - x = x.flatten(0, 1) - out_states = fused_moe(x, - self.mlp.gate_up_weights, - self.mlp.down_weights, - top_weights, - top_experts, - topk=top_weights.size(1), - renormalize=False) - - out_states = out_states.unflatten(0, (-1, q_len)) - return out_states +class DbrxBlock(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + # build attention layer + self.norm_attn_norm = DbrxNormAttentionNorm(config, + layer_idx, + dtype=dtype, + device=device) -class PatchedDbrxModel(nn.Module): + # builf MLP + self.ffn = DbrxFFN(config, dtype=dtype, device=device) - def _continuous_batching_forward( + def forward( self, - input_ids: Optional[torch.LongTensor], - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - """forward impl.""" - output_attentions = False - use_cache = True - output_router_logits = False + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + attn_metadata: Any = None, + ): + # Self Attention + hidden_states, resid_states = self.norm_attn_norm( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) - inputs_embeds = self.wte(input_ids) + # Fully Connected + hidden_states = self.ffn(hidden_states) + hidden_states = resid_states + hidden_states + + return hidden_states + + +class DbrxModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.wte = nn.Embedding(config.vocab_size, + config.d_model, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.blocks = nn.ModuleList([ + DbrxBlock(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.n_layers) + ]) + + # build norm + self.norm_f = nn.LayerNorm(config.d_model, + bias=False, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.d_model // config.n_heads + rope_max_pos_emb = config.max_seq_len + rope_base = config.attn_config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" - # Attention mask is not necessary in continuous batching - attention_mask = None - cache_position = None + # token embedding + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) hidden_states = inputs_embeds - for idx, block in enumerate(self.blocks): + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + for idx, decoder_layer in enumerate(self.blocks): past_key_value = past_key_values[idx] - block_outputs = block( + hidden_states = decoder_layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, + attn_metadata=attn_metadata, ) - hidden_states = block_outputs[0] + # norm hidden_states = self.norm_f(hidden_states) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.wte + + +class DbrxForCausalLM(nn.Module): + """ModelForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.transformer = DbrxModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.transformer( + input_ids=input_ids, + position_ids=position_ids, past_key_values=past_key_values, - hidden_states=None, - attentions=None, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) - def forward( + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.transformer.get_input_embeddings() + + def prepare_inputs_for_generation( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: List[List[torch.Tensor]], inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - """Rewrite of LlamaModel.forward.""" - return self._continuous_batching_forward( - input_ids, - position_ids, - past_key_values, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + config = self.config + attn_config = config.attn_config + hidden_size = config.d_model + num_heads = config.n_heads + num_key_value_heads = attn_config.kv_n_heads + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + qkv_section = [ + head_dim * num_heads, head_dim * num_key_value_heads, + head_dim * num_key_value_heads + ] + + ffn_config = config.ffn_config + num_experts = ffn_config.moe_num_experts + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + if '.experts' in name: + loaded_weight = loaded_weight.unflatten(0, (num_experts, -1)) + if '.w1' in name: + name = name.replace('.w1', '.gate_up_weights') + param = params_dict[name] + for exp_id in range(num_experts): + weight = loaded_weight[exp_id] + load_weight(param, + weight, + expert_id=exp_id, + shard_id='gate') + elif '.v1' in name: + name = name.replace('.v1', '.gate_up_weights') + param = params_dict[name] + for exp_id in range(num_experts): + weight = loaded_weight[exp_id] + load_weight(param, + weight, + expert_id=exp_id, + shard_id='up') + elif '.w2' in name: + name = name.replace('.w2', '.down_weights') + param = params_dict[name] + for exp_id in range(num_experts): + weight = loaded_weight[exp_id].t() + load_weight(param, + weight, + expert_id=exp_id, + shard_id='down') + elif '.Wqkv' in name: + q, k, v = loaded_weight.split(qkv_section) + param = params_dict[name] + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index c692e79f3..086921528 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -113,18 +113,8 @@ # dbrx MODULE_MAP.update({ - 'modeling_dbrx.DbrxAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxAttention', - 'modeling_dbrx.DbrxFlashAttention2': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxAttention', - 'modeling_dbrx.DbrxSdpaAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxAttention', - 'modeling_dbrx.DbrxModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxModel', - 'modeling_dbrx.DbrxExpertGLU': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxExpertGLU', - 'modeling_dbrx.DbrxExperts': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.PatchedDbrxExperts', + 'DbrxForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.dbrx.DbrxForCausalLM', }) # starcoder2 diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py deleted file mode 100644 index ed414c179..000000000 --- a/lmdeploy/pytorch/models/peft.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from dataclasses import dataclass - -import torch -import torch.distributed as dist - -from ..kernels.mbgmm import mbgmm_a, mbgmm_b -from ..kernels.mbgmv import mbgmv_a, mbgmv_b -from ..kernels.rearange_all_gather import rearange_all_gather - - -@dataclass -class PackedLoRAInput: - x: torch.Tensor - a_cache: torch.Tensor - b_cache: torch.Tensor - q_start_loc: torch.Tensor - q_seqlens: torch.Tensor - adapter_ids: torch.Tensor - scaling: torch.Tensor - rank_offset: torch.Tensor - ranks: torch.Tensor - max_seq_len: int - max_rank: int - is_decoding: bool - - -class LoRALinear(torch.nn.Module): - - def _make_packed_lora_input(self, x): - context = self.context.context - adapter_info = context.adapter_params[self.target_name] - - # adapter cache - layer_idx = self.layer_idx - ranks = adapter_info.ranks - scaling = adapter_info.scalings - rank_offset = adapter_info.rank_offsets - max_rank = adapter_info.max_rank - k_cache, v_cache = context.kv_caches[layer_idx] - cache_len = k_cache.size(0) - a_cache = k_cache.view(cache_len, -1) - b_cache = v_cache.view(cache_len, -1) - - return PackedLoRAInput(x=x.flatten(0, -2).contiguous(), - a_cache=a_cache, - b_cache=b_cache, - q_start_loc=context.q_start_loc, - q_seqlens=context.q_seq_length, - adapter_ids=context.local_adapter_ids, - scaling=scaling, - rank_offset=rank_offset, - ranks=ranks, - max_seq_len=context.max_q_seq_length, - max_rank=max_rank, - is_decoding=context.is_decoding) - - def _lora_forward_local(self, x): - """lora forward no tp.""" - - lora_input = self._make_packed_lora_input(x) - - if hasattr(self.base_layer, 'weight'): - out_size = self.base_layer.weight.size(0) - else: - out_size = self.base_layer.scales.size(1) - if not lora_input.is_decoding: - xa = mbgmm_a(lora_input.x, - lora_input.a_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank) - lora_out = mbgmm_b(xa, - lora_input.b_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, - out_size=out_size) - else: - xa = mbgmv_a(lora_input.x, - lora_input.a_cache, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank) - lora_out = mbgmv_b(xa, - lora_input.b_cache, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, - out_size=out_size) - - base_out = self.base_layer(x) - lora_out = lora_out.reshape(base_out.shape) - output = base_out + lora_out - - return output - - def _lora_forward_tp_rowwise(self, x): - """lora forward tp rowwise.""" - - lora_input = self._make_packed_lora_input(x) - rank = dist.get_rank() - world_size = dist.get_world_size() - if hasattr(self.base_layer, 'weight'): - out_size = self.base_layer.weight.size(0) - else: - out_size = self.base_layer.scales.size(1) - out_size = out_size // world_size - if not lora_input.is_decoding: - xa = mbgmm_a(lora_input.x, - lora_input.a_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank) - lora_out = mbgmm_b(xa, - lora_input.b_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, - out_size=out_size) - else: - xa = mbgmv_a(lora_input.x, - lora_input.a_cache, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank) - lora_out = mbgmv_b(xa, - lora_input.b_cache, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, - out_size=out_size) - - base_out = self.base_layer(x) - out_shape = base_out.shape - base_out = base_out.flatten(0, -2) - - slice_start = rank * out_size - slice_end = slice_start + out_size - base_out[:, slice_start:slice_end] += lora_out - base_out = base_out.reshape(out_shape) - - return base_out - - def _lora_forward_tp_colwise(self, x): - """lora forward tp colwise.""" - - def __gather_xa(xa): - """gather xa.""" - gathered_xa = xa.new_empty(world_size, xa.size(0), xa.size(1)) - dist.all_gather_into_tensor(gathered_xa, xa) - # TODO: gather would failed when adapters have different ranks. - gathered_xa = gathered_xa.permute(1, 0, 2).flatten(-2, -1) - return gathered_xa - - lora_input = self._make_packed_lora_input(x) - world_size = dist.get_world_size() - if hasattr(self.base_layer, 'weight'): - out_size = self.base_layer.weight.size(0) - else: - out_size = self.base_layer.scales.size(1) - if not lora_input.is_decoding: - xa = mbgmm_a(lora_input.x, - lora_input.a_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, - rank_step=world_size) - gathered_xa = __gather_xa(xa) - if len(lora_input.ranks) > 1: - gathered_xa = rearange_all_gather( - gathered_xa, - b_start_loc=lora_input.q_start_loc, - b_seq_lens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - ranks=lora_input.ranks, - world_size=world_size, - max_seq_len=lora_input.max_seq_len, - output=gathered_xa) - lora_out = mbgmm_b(gathered_xa, - lora_input.b_cache, - q_start_loc=lora_input.q_start_loc, - q_seqlens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_seq_len=lora_input.max_seq_len, - max_rank=lora_input.max_rank, - out_size=out_size) - else: - xa = mbgmv_a(lora_input.x, - lora_input.a_cache, - adapter_ids=lora_input.adapter_ids, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, - rank_step=world_size) - gathered_xa = __gather_xa(xa) - if len(lora_input.ranks) > 1: - gathered_xa = rearange_all_gather( - gathered_xa, - b_start_loc=lora_input.q_start_loc, - b_seq_lens=lora_input.q_seqlens, - adapter_ids=lora_input.adapter_ids, - ranks=lora_input.ranks, - world_size=world_size, - max_seq_len=lora_input.max_seq_len, - output=gathered_xa) - lora_out = mbgmv_b(gathered_xa, - lora_input.b_cache, - adapter_ids=lora_input.adapter_ids, - scaling=lora_input.scaling, - rank_offset=lora_input.rank_offset, - ranks=lora_input.ranks, - max_rank=lora_input.max_rank, - out_size=out_size) - - base_out = self.base_layer(x) - lora_out = lora_out.reshape(base_out.shape) - output = base_out + lora_out - - return output - - def _lora_forward_tp(self, x): - """lora forward tp.""" - tp_mode = getattr(self, '_tp_mode', None) - if tp_mode == 'rowwise': - return self._lora_forward_tp_rowwise(x) - elif tp_mode == 'colwise': - return self._lora_forward_tp_colwise(x) - else: - assert tp_mode is None, 'tp_mode == None failed.' - return self._lora_forward_local(x) - - def _lora_forward(self, x): - """lora forward.""" - if dist.is_initialized(): - return self._lora_forward_tp(x) - else: - return self._lora_forward_local(x) - - def forward(self, x): - """forward.""" - context = self.context.context - adapter_info = context.adapter_params[self.target_name] - max_rank = adapter_info.max_rank - - if max_rank == 0: - return self.origin_mod.forward(x) - else: - return self._lora_forward(x) From 5607566143e3251c26ca535eccde65b195411615 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 11:40:04 +0800 Subject: [PATCH 59/89] support internvl --- lmdeploy/pytorch/models/internvl.py | 96 ++++++++++++++++++++++----- lmdeploy/pytorch/models/module_map.py | 4 +- 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index d02b3a6c5..d62317275 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -1,23 +1,89 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Tuple import torch from torch import nn -from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -class PatchedInternVLChatModel(nn.Module): +from .patch import build_model_from_hf_config - def forward(self, - pixel_values: torch.FloatTensor = None, - input_ids: torch.LongTensor = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs) -> Union[Tuple, CausalLMOutputWithPast]: - outputs = self.language_model(input_ids=input_ids, - past_key_values=past_key_values, - position_ids=position_ids) - return CausalLMOutputWithPast(logits=outputs.logits) +class InternVLChatModel(nn.Module): + + support_cuda_graph = True + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + llm_config = config.llm_config + self.language_model = build_model_from_hf_config(llm_config, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) + + def get_input_embeddings(self): + """get input embeddings.""" + return self.language_model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + # get inputs from context + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + + prefix_length = len('language_model.') + new_weights = dict() + for key, val in weights: + if not key.startswith('language_model.'): + continue + new_key = key[prefix_length:] + new_weights[new_key] = val + + self.language_model.load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 086921528..b3eec4234 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -145,8 +145,8 @@ # internvl MODULE_MAP.update({ - 'modeling_internvl_chat.InternVLChatModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.PatchedInternVLChatModel' + 'InternVLChatModel': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.InternVLChatModel' }) # phi3 vision From 494649de808215cf8091a9e1b689e61b5d3d3c95 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 12:07:07 +0800 Subject: [PATCH 60/89] support merged awq weight --- lmdeploy/pytorch/models/chatglm2.py | 17 +------- lmdeploy/pytorch/models/dbrx.py | 11 +---- lmdeploy/pytorch/models/falcon.py | 15 +------ lmdeploy/pytorch/models/phi3.py | 14 +------ lmdeploy/pytorch/models/qwen.py | 12 +----- lmdeploy/pytorch/nn/linear.py | 62 +++++++++++++++++++++++++++++ 6 files changed, 69 insertions(+), 62 deletions(-) diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index e32c3c7b6..b66f83a05 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -468,19 +468,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm - config = self.config - num_attention_heads = config.num_attention_heads - projection_size = config.kv_channels * num_attention_heads - num_kv_heads = num_attention_heads - head_size = (projection_size // num_attention_heads) - multi_query_attention = config.multi_query_attention - if multi_query_attention: - num_kv_heads = config.multi_query_group_num - qkv_section = [ - head_size * num_attention_heads, head_size * num_kv_heads, - head_size * num_kv_heads - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_pos_emb.inv_freq' in name: @@ -492,14 +479,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and 'output_layer.weight' in name): continue if '.query_key_value' in name: - q, k, v = loaded_weight.split(qkv_section) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') elif '.dense_h_to_4h' in name: - gate, up = loaded_weight.chunk(2) param = params_dict[name] + gate, up = param.weight_spliter(loaded_weight) load_weight(param, gate, shard_id=0) load_weight(param, up, shard_id=1) else: diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index c9b790a6c..056796fcb 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -494,15 +494,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm config = self.config - attn_config = config.attn_config - hidden_size = config.d_model - num_heads = config.n_heads - num_key_value_heads = attn_config.kv_n_heads - head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - qkv_section = [ - head_dim * num_heads, head_dim * num_key_value_heads, - head_dim * num_key_value_heads - ] ffn_config = config.ffn_config num_experts = ffn_config.moe_num_experts @@ -547,8 +538,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_id=exp_id, shard_id='down') elif '.Wqkv' in name: - q, k, v = loaded_weight.split(qkv_section) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index 145c765d2..d50d97fac 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -432,19 +432,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm - config = self.config - num_attention_heads = config.num_attention_heads - hidden_size = config.hidden_size - num_kv_heads = num_attention_heads - head_size = (hidden_size // num_attention_heads) - multi_query_attention = config.multi_query - if multi_query_attention: - num_kv_heads = 1 - qkv_section = [ - head_size * num_attention_heads, head_size * num_kv_heads, - head_size * num_kv_heads - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_pos_emb.inv_freq' in name: @@ -456,8 +443,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): and 'output_layer.weight' in name): continue if '.query_key_value' in name: - q, k, v = loaded_weight.split(qkv_section) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index 56110908e..c04bb235a 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -405,16 +405,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm - config = self.config - num_heads = config.num_attention_heads - num_key_value_heads = config.num_key_value_heads - hidden_size = config.hidden_size - head_dim = getattr(config, 'head_dim', hidden_size // num_heads) - qkv_section = [ - head_dim * num_heads, head_dim * num_key_value_heads, - head_dim * num_key_value_heads - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_emb.inv_freq' in name: @@ -427,14 +417,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if 'vision_embed_tokens' in name: continue if '.qkv_proj' in name: - q, k, v = loaded_weight.split(qkv_section) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') elif '.gate_up_proj' in name: - gate, up = loaded_weight.chunk(2) param = params_dict[name] + gate, up = param.weight_spliter(loaded_weight) load_weight(param, gate, shard_id=0) load_weight(param, up, shard_id=1) else: diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index 2357251fb..0d33e8268 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -400,16 +400,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ('.gate_up_proj', '.w1', 1), ] - config = self.config - num_attention_heads = config.num_attention_heads - projection_size = config.kv_channels * num_attention_heads - num_kv_heads = num_attention_heads - head_size = (projection_size // num_attention_heads) - qkv_section = [ - head_size * num_attention_heads, head_size * num_kv_heads, - head_size * num_kv_heads - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_pos_emb.inv_freq' in name: @@ -428,8 +418,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: if '.c_attn' in name: - q, k, v = loaded_weight.split(qkv_section) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 53f9f717d..bace20bf2 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -320,6 +320,13 @@ def __init__(self, out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) + + self.split_section_s = all_out_features + elem_per_int = 32 // w_bit + self.split_section_wz = [ + size // elem_per_int for size in all_out_features + ] + all_out_features = self._update_all_out_features( all_out_features, w_bit, group_size, replicate) self.all_out_features = all_out_features @@ -339,13 +346,17 @@ def __init__(self, colwise=True, is_tp=is_tp) self.qweight.weight_loader = self.weight_loader + self.qweight.weight_spliter = self.weight_spliter_wz self.qweight._weight_type = 'qweight' self.scales.weight_loader = self.weight_loader + self.scales.weight_spliter = self.weight_spliter_s self.scales._weight_type = 'scales' self.qzeros.weight_loader = self.weight_loader + self.qzeros.weight_spliter = self.weight_spliter_wz self.qzeros._weight_type = 'qzeros' if self.bias is not None: self.bias.weight_loader = self.weight_loader + self.bias.weight_spliter = self.weight_spliter_s self.bias._weight_type = 'bias' def _get_io_features(self, in_features: int, out_features: int, w_bit: int, @@ -400,6 +411,14 @@ def weight_loader(self, param: torch.nn.Parameter, weight = _chunk_align(loaded_weight, world_size, 1, align)[rank] param_w.copy_(weight) + def weight_spliter_wz(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.split_section_wz, dim=1) + + def weight_spliter_s(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.split_section_s, dim=-1) + class QKVAwqLinear(MergedAwqLinear, QKVMixin): """qkv awq linear.""" @@ -416,6 +435,14 @@ def __init__(self, bias: bool = False, device: Optional[torch.device] = None, is_tp: bool = True): + + self.qkv_split_section_s = self._get_qkv_out_features( + num_q_heads, num_kv_heads, head_size, head_size_v) + elem_per_int = 32 // w_bit + self.qkv_split_section_wz = [ + size // elem_per_int for size in self.qkv_split_section_s + ] + num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, @@ -443,6 +470,14 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, """update all out features.""" return all_out_features + def weight_spliter_wz(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.qkv_split_section_wz, dim=1) + + def weight_spliter_s(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.qkv_split_section_s, dim=-1) + class W8A8Linear(nn.Module): """w8a8 linear.""" @@ -596,6 +631,7 @@ def __init__(self, out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) + self.split_section = all_out_features all_out_features = self._update_all_out_features( all_out_features, replicate) self.all_out_features = all_out_features @@ -615,8 +651,11 @@ def __init__(self, is_tp=is_tp) self.weight.weight_loader = self.weight_loader self.scale.weight_loader = self.weight_loader + self.weight.weight_spliter = self.weight_spliter + self.scale.weight_spliter = self.weight_spliter if self.bias is not None: self.bias.weight_loader = self.weight_loader + self.bias.weight_spliter = self.weight_spliter def _get_io_features(self, in_features: int, out_features: int, colwise: bool): @@ -645,6 +684,10 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.split_section, dim=0) + class QKVW8A8Linear(MergedW8A8Linear, QKVMixin): """qkv w8a8 linear.""" @@ -660,6 +703,8 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True): + self.qkv_split_section = self._get_qkv_out_features( + num_q_heads, num_kv_heads, head_size, head_size_v) num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, @@ -685,6 +730,10 @@ def _update_all_out_features(self, all_out_features: List[int], """update all out features.""" return all_out_features + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.qkv_split_section, dim=0) + class BaseLinear(nn.Module): """linear layer.""" @@ -836,6 +885,7 @@ def __init__(self, out_names: Optional[List[int]] = None): if replicate is None: replicate = tuple(False for _ in all_out_features) + self.split_section = all_out_features all_out_features = self._update_all_out_features( all_out_features, replicate) self.all_out_features = all_out_features @@ -854,8 +904,10 @@ def __init__(self, colwise=True, is_tp=is_tp) self.weight.weight_loader = self.weight_loader + self.weight.weight_spliter = self.weight_spliter if self.bias is not None: self.bias.weight_loader = self.weight_loader + self.bias.weight_spliter = self.weight_spliter def _get_io_features(self, in_features: int, out_features: int, colwise: bool): @@ -884,6 +936,10 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight = loaded_weight.chunk(world_size, 0)[rank] param_w.copy_(loaded_weight) + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.split_section, dim=0) + class QKVBaseLinear(MergedBaseLinear, QKVMixin): """qkv base linear.""" @@ -899,6 +955,8 @@ def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True): + self.qkv_split_section = self._get_qkv_out_features( + num_q_heads, num_kv_heads, head_size, head_size_v) num_q_heads, num_kv_heads = self._update_num_heads( num_q_heads, num_kv_heads, replicate_kv) all_out_features = self._get_qkv_out_features(num_q_heads, @@ -939,6 +997,10 @@ def weight_loader(self, param: torch.nn.Parameter, self.head_size_v)[rank] param_w.copy_(loaded_weight) + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + return loaded_weight.split(self.qkv_split_section, dim=0) + def build_linear(in_features: int, out_features: int, From 5f33ccfc52cea3606da88295bacb1f06a67b1057 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 16:23:30 +0800 Subject: [PATCH 61/89] add cogvlm --- .../pytorch/backends/cuda/graph_runner.py | 19 +- lmdeploy/pytorch/model_inputs.py | 2 + lmdeploy/pytorch/models/cogvlm.py | 799 ++++++++++++------ lmdeploy/pytorch/models/module_map.py | 12 +- 4 files changed, 552 insertions(+), 280 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 131494d35..87738bd2e 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -224,18 +224,7 @@ def check_enable_graph(self): if self.backend_config.eager_mode: return False - if not getattr(self.model, 'support_cuda_graph', False): - return False - - # TODO: should we enable cudagraph awq now? - # hf_config = self.model_config.hf_config - # quantization_config = getattr(hf_config, - # 'quantization_config', dict()) - # quant_method = quantization_config.get('quant_method', None) - # if quant_method == 'awq': - # logger.warning('AWQ model does not support cuda graph yet.') - # return False - return True + return getattr(self.model, 'support_cuda_graph', False) def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, @@ -250,7 +239,11 @@ def get_graph_key(self, input_ids: torch.Tensor, def __call__(self, **kwargs): """call.""" - if not self.enable_graph: + enable_graph = self.enable_graph + if callable(enable_graph): + enable_graph = enable_graph(**kwargs) + + if not enable_graph: return self.model(**kwargs) graph_key = self.get_graph_key(**kwargs) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index dcd805412..27ed28def 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -208,6 +208,7 @@ class StepContext: adapter_params: Dict[str, AdapterInfo] = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None + vision_inputs: VisionModelInputs = None attn_metadata: Any = None _outputs: Dict = field(default_factory=dict) @@ -269,6 +270,7 @@ def new( is_decoding=inputs.is_decoding, world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, + vision_inputs=inputs.vision_inputs, ) ret = get_backend().update_step_context(ret) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 45d7a4d01..9e73c538c 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -1,290 +1,431 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple import torch import torch.distributed as dist from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast - -from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd -from ..weight_loader.dist_utils import (colwise_split_parallelize_linear, - rowwise_parallelize_linear) - -LANGUAGE_TOKEN_TYPE = 0 -VISION_TOKEN_TYPE = 1 - -# flake8: noqa: F821 - - -def get_vision_expert_mask( - token_type_ids: 'torch.LongTensor(B, L)' -) -> '[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]': - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -class PatchedVisionExpertMLP(nn.Module): - - def forward(self, hidden_states: 'torch.Tensor(B, L, D)', - token_type_ids: 'torch.LongTensor(B, L)'): - context = self.context.context - only_has_language = context.is_decoding - if not context.is_decoding: - # for embedding splitting - if hasattr(context, 'vision_token_mask') and hasattr( - context, 'language_token_mask'): - vision_token_mask = context.vision_token_mask - language_token_mask = context.language_token_mask - only_has_language = vision_token_mask.numel() == 0 - else: - only_has_language = True - - if only_has_language: - output = self.language_mlp(hidden_states) - else: - output = torch.empty_like(hidden_states) - output[:, vision_token_mask, :] = self.vision_mlp( - hidden_states[:, vision_token_mask, :]) - output[:, language_token_mask, :] = self.language_mlp( - hidden_states[:, language_token_mask, :]) - return output +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +def get_world_rank(): + """get current world size and rank.""" + import torch.distributed as dist + world_size = 1 + rank = 0 + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + return world_size, rank + + +class VisionExpertAttention(nn.Module): + """Rewrite module of VisionExpertAttention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = num_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + self.hidden_size = hidden_size + + # packed qkv + self.vision_expert_query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + self.language_expert_query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() -class PatchedVisionExpertAttention(nn.Module): + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + ) - def _load_weights(self, loader, rank: int, world_size: int, - device: torch.device): - """load weights.""" - num_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, 'num_multi_query_heads', num_heads) - head_dim = self.config.hidden_size // num_heads - sections = [ - self.config.hidden_size, num_kv_heads * head_dim, - num_kv_heads * head_dim - ] - for name in [ - 'vision_expert_query_key_value', - 'language_expert_query_key_value' - ]: - colwise_split_parallelize_linear(getattr(self, name), - sections, - loader, - rank=rank, - world_size=world_size, - prefix=name) - for name in ['vision_expert_dense', 'language_expert_dense']: - rowwise_parallelize_linear(getattr(self, name), - loader, - rank=rank, - world_size=world_size, - prefix=name) - - @classmethod - def _distribute_output_fn(cls, outputs, **kwargs): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs + # o_proj + self.vision_expert_dense = build_rowwise_linear( + hidden_size, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False) + self.language_expert_dense = build_rowwise_linear( + hidden_size, + hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False) + world_size, _ = get_world_rank() + self.world_size = world_size + self.all_reduce = world_size > 1 - def _contiguous_batching_forward_impl( + def forward( self, hidden_states: torch.Tensor, - token_type_ids: torch.LongTensor = None, - position_ids: Optional[torch.LongTensor] = None, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, - world_size: int = 1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite implementation of Attention.forward. - - Add continuous batching support. Add paged attention support. - """ - context = self.context.context - q_start_loc = context.q_start_loc - q_seq_length = context.q_seq_length - kv_seq_length = context.kv_seq_length - block_offsets = context.block_offsets - max_q_seq_length = context.max_q_seq_length - num_heads = self.config.num_attention_heads // world_size - num_kv_heads = getattr(self.config, 'num_multi_query_heads', - self.config.num_attention_heads) // world_size - - head_dim = self.config.hidden_size // self.config.num_attention_heads - hidden_size = num_heads * head_dim - only_has_language = context.is_decoding - if not context.is_decoding: - # for embedding splitting - if hasattr(context, 'vision_token_mask') and hasattr( - context, 'language_token_mask'): - vision_token_mask = context.vision_token_mask - language_token_mask = context.language_token_mask - only_has_language = vision_token_mask.numel() == 0 - else: - only_has_language = True + attn_metadata: Any = None, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + ): + """Rewrite of LlamaAttention.forward.""" + bsz, seqlen, _ = hidden_states.size() + hidden_size = self.hidden_size // self.world_size + + # qkv proj + if lang_ids is None and vision_ids is None: + qkv_states = self.language_expert_query_key_value(hidden_states) + else: + qkv_states = hidden_states.new_empty(bsz, seqlen, hidden_size * 3) + if lang_ids is not None: + qkv_states[:, lang_ids] = self.language_expert_query_key_value( + hidden_states[:, lang_ids]) + if vision_ids is not None: + qkv_states[:, vision_ids] = self.vision_expert_query_key_value( + hidden_states[:, vision_ids]) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = \ + self.language_expert_query_key_value.split_qkv(qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) - def __qkv_proj(hidden_states): - """qkv_proj.""" - if only_has_language: - mixed_raw_layer = self.language_expert_query_key_value( - hidden_states) - else: - shape = list(hidden_states.shape) - shape[-1] = hidden_size + head_dim * num_kv_heads * 2 - mixed_raw_layer = torch.empty(shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - - mixed_raw_layer[:, - vision_token_mask, :] = self.vision_expert_query_key_value( - hidden_states[:, vision_token_mask, :]) - mixed_raw_layer[:, - language_token_mask, :] = self.language_expert_query_key_value( - hidden_states[:, language_token_mask, :]) - query_states, key_states, value_states = torch.split( - mixed_raw_layer, [ - hidden_size, head_dim * num_kv_heads, - head_dim * num_kv_heads - ], - dim=-1) - return query_states, key_states, value_states - - def __rotary_emb_fn(query_states, key_states, value_states): - """rotary embedding func.""" - scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0) - inv_freq = self.rotary_emb.inv_freq - - query_states, key_states = fused_rotary_emb( - query_states[None], - key_states[None], - position_ids[None], - inv_freq=inv_freq, - scaling_factor=scaling_factor, - out_q=query_states[None], - out_k=key_states[None]) - return query_states[0], key_states[0], value_states - - query_states, key_states, value_states = __qkv_proj(hidden_states) - - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - query_states, key_states, value_states = __rotary_emb_fn( - query_states, key_states, value_states) - - fill_kv_cache( + # attention + attn_output = self.attn_fwd( + query_states, key_states, value_states, past_key_value[0], past_key_value[1], - q_start_loc, - q_seq_length, - kv_seq_length=kv_seq_length, - max_q_seq_length=max_q_seq_length, - block_offsets=block_offsets, + attn_metadata, + inplace=True, ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) - context_layer = query_states - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - context_layer, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_q_seq_length, + # o proj + if lang_ids is None and vision_ids is None: + attn_output = self.language_expert_dense(attn_output) + else: + new_attn_output = torch.empty_like(hidden_states) + if lang_ids is not None: + new_attn_output[:, lang_ids] = self.language_expert_dense( + attn_output[:, lang_ids]) + if vision_ids is not None: + new_attn_output[:, vision_ids] = self.vision_expert_dense( + attn_output[:, vision_ids]) + attn_output = new_attn_output + + if self.all_reduce: + dist.all_reduce(attn_output) + return attn_output + + +class MLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + assert config.hidden_act == 'silu' + + quantization_config = getattr(config, 'quantization_config', None) + + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, ) - context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1) - if only_has_language: - attn_output = self.language_expert_dense(context_layer) + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True, + all_reduce=False) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class VisionExpertMLP(nn.Module): + """vision expert mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.language_mlp = MLP(config, dtype=dtype, device=device) + self.vision_mlp = MLP(config, dtype=dtype, device=device) + world_size, _ = get_world_rank() + self.all_reduce = world_size > 1 + + def forward( + self, + hidden_states: torch.Tensor, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + ): + """forward.""" + if lang_ids is None and vision_ids is None: + output = self.language_mlp(hidden_states) else: - ctx_shape = list(context_layer.shape) - ctx_shape[-1] *= world_size - attn_output = torch.empty(ctx_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) + output = torch.empty_like(hidden_states) + if lang_ids is not None: + output[:, + lang_ids] = self.language_mlp(hidden_states[:, + lang_ids]) + if vision_ids is not None: + output[:, + vision_ids] = self.vision_mlp(hidden_states[:, + vision_ids]) + if self.all_reduce: + dist.all_reduce(output) + return output - attn_output[:, vision_token_mask, :] = self.vision_expert_dense( - context_layer[:, vision_token_mask, :]) - attn_output[:, - language_token_mask, :] = self.language_expert_dense( - context_layer[:, language_token_mask, :]) - return attn_output, None, past_key_value +class CogVLMDecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = VisionExpertAttention(config, + dtype=dtype, + device=device) + + # builf MLP + self.mlp = VisionExpertMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of forward.""" - world_size = 1 - if dist.is_initialized(): - world_size = dist.get_world_size() - return self._contiguous_batching_forward_impl( - hidden_states, - position_ids=position_ids, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, - world_size=world_size, + attn_metadata=attn_metadata, + lang_ids=lang_ids, + vision_ids=vision_ids, ) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp( + hidden_states, + lang_ids=lang_ids, + vision_ids=vision_ids, + ) + + outputs = (hidden_states, residual) + return outputs -class PatchedCogVLMModel(nn.Module): + +class CogVLMModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + CogVLMDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding + emb_type = EmbeddingType.LinearScaling + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = 2048 + rope_base = 10000 + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + emb_type=emb_type, + ) def forward( self, input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - # not allow for inputs_embeds, because we want to process image feature - assert input_ids is not None - context = self.context.context - # get inputs from context - vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + ): + """Rewrite of LlamaModel.forward.""" - inputs_embeds = self.embed_tokens(input_ids) - position_ids = _get_cogvlm_position_ids(context) + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - if vision_embeddings is not None and len(vision_embeddings) > 0: - # multi-modality - inputs_embeds[:, - vision_embedding_indexing, :] = vision_embeddings.to( - inputs_embeds) hidden_states = inputs_embeds + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, - token_type_ids=None, - position_ids=position_ids, + rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + lang_ids=lang_ids, + vision_ids=vision_ids, ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) + # norm + hidden_states, _ = self.norm(hidden_states, residual) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_value, - hidden_states=None, - attentions=None, - ) + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +LANGUAGE_TOKEN_TYPE = 0 +VISION_TOKEN_TYPE = 1 + + +def get_vision_expert_mask(token_type_ids: torch.LongTensor): + vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) + vision_token_mask[:, :-1] = (token_type_ids[:, :-1] + == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] + == VISION_TOKEN_TYPE) + language_token_mask = ~vision_token_mask + return vision_token_mask, language_token_mask -def build_position_ids( - x: 'torch.BoolTensor(B, L)') -> 'torch.LongTensor(B, L)': +def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: tmp = x.clone() # image boi eoi token as LANGUAGE_TOKEN_TYPE is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) @@ -306,23 +447,26 @@ def build_position_ids( def _get_cogvlm_position_ids(context): """get cogvlm position_ids.""" - inputs = context.inputs - q_seq_length = inputs.seq_length - vision_input_info = inputs.vision_inputs - position_id_offsets = vision_input_info.history_image_token_lengths - vision_input_info.history_image_nums * 3 - if inputs.is_decoding: - position_ids = inputs.history_lengths - position_id_offsets + q_seqlens = context.q_seqlens + history_lengths = context.kv_seqlens - q_seqlens + vision_input_info = context.vision_inputs + position_id_offsets = (vision_input_info.history_image_token_lengths - + vision_input_info.history_image_nums * 3) + lang_ids = None + vis_ids = None + if context.is_decoding: + position_ids = history_lengths - position_id_offsets else: if vision_input_info.input_embeddings is not None and len( vision_input_info.input_embeddings) > 0: - starts = inputs.history_lengths - vision_input_info.history_lengths - ends = starts + q_seq_length + starts = history_lengths - vision_input_info.history_lengths + ends = starts + q_seqlens token_type_ids = vision_input_info.input_embedding_indexing.to( torch.int) - history_position_lengths = vision_input_info.history_lengths - position_id_offsets - position_ids_all = history_position_lengths[:, - None] + build_position_ids( - token_type_ids) + history_position_lengths = (vision_input_info.history_lengths - + position_id_offsets) + position_ids_all = (history_position_lengths[:, None] + + build_position_ids(token_type_ids)) position_ids = torch.cat([ pids[s:e] for (pids, s, e) in zip(position_ids_all, starts, ends) @@ -334,20 +478,161 @@ def _get_cogvlm_position_ids(context): ]) mask_indexing = torch.arange(vision_token_mask.shape[-1], device=vision_token_mask.device) - vision_token_mask_new = mask_indexing[vision_token_mask] - language_token_mask_new = mask_indexing[~vision_token_mask] + vis_ids = mask_indexing[vision_token_mask] + lang_ids = mask_indexing[~vision_token_mask] - context.vision_token_mask = vision_token_mask_new - context.language_token_mask = language_token_mask_new else: position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (inputs.history_lengths - + position_ids += (history_lengths - position_id_offsets).unsqueeze(-1) device = position_ids.device position_ids_1d = [ - ids[:l] - for ids, l in zip(position_ids.cpu(), q_seq_length.cpu()) + ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) ] position_ids = torch.cat(position_ids_1d).to(device) - return position_ids + return position_ids, lang_ids, vis_ids + + +class CogVLMForCausalLM(nn.Module): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = CogVLMModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + lang_ids=lang_ids, + vision_ids=vision_ids, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + lang_ids: torch.LongTensor = None, + vision_ids: torch.LongTensor = None, + **kwargs, + ): + """support cudagraph.""" + return inputs_embeds is None + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) + position_ids = position_ids[None] + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + lang_ids=lang_ids, + vision_ids=vis_ids, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'model.vision' in name: + continue + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '_expert_query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index b3eec4234..9840367b6 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -131,16 +131,8 @@ # cogvlm MODULE_MAP.update({ - 'modeling_cogvlm.RMSNorm': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaRMSNorm', - 'modeling_cogvlm.MLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP', - 'modeling_cogvlm.VisionExpertMLP': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertMLP', - 'modeling_cogvlm.VisionExpertAttention': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttention', - 'modeling_cogvlm.CogVLMModel': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedCogVLMModel', + 'CogVLMForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.CogVLMForCausalLM', }) # internvl From 17abf917da406072ddd6c38e33ebd6036b60e9e6 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 17:22:25 +0800 Subject: [PATCH 62/89] update docs --- docs/en/advance/pytorch_new_model.md | 473 +++++++----------------- docs/en/inference/pytorch.md | 22 -- docs/zh_cn/advance/pytorch_new_model.md | 465 +++++++---------------- docs/zh_cn/inference/pytorch.md | 22 -- 4 files changed, 266 insertions(+), 716 deletions(-) diff --git a/docs/en/advance/pytorch_new_model.md b/docs/en/advance/pytorch_new_model.md index d61d96e23..8d6c11dd9 100644 --- a/docs/en/advance/pytorch_new_model.md +++ b/docs/en/advance/pytorch_new_model.md @@ -1,386 +1,181 @@ -# How to support new model in lmdeploy.pytorch +# lmdeploy.pytorch New Model Support -lmdeploy.pytorch is designed to ease new model deployment and prototype verification. If you are willing to use our engine, here is the tutorial. +lmdeploy.pytorch is designed to simplify the support for new models and the development of prototypes. Users can adapt new models according to their own needs. -## Support New Model +## Model Support -Let's begin with Llama. +### Configuration Loading (Optional) -Before delving into the details, it's essential to acquaint ourselves with the input specifications of the model. In order to accommodate new features within our engine, there are some deviations from the typical transformer inputs. - -1. To circumvent the need for batch padding, continuous batching is employed. Consequently, the `input_ids` now represents the concatenation of all input sequences in the batch, followed by a `unsqueeze(0)` operation to align with the original `input_ids` dimension. - -2. In an effort to optimize memory usage for the key/value cache, we implement paged attention. This transforms the `past_key_value` into a substantial tensor with dimensions `[num_blocks, block_size, num_heads, head_dim]`. Here, `num_blocks` denotes the number of page blocks, and `block_size` indicates the size of each block. - -3. Accompanying these changes, additional inputs are imperative to support the modified inputs described above. These include the block table and history length. It's important to note that these supplementary inputs are not explicitly listed as arguments in the original forward method. Instead, a context object is utilized to furnish this essential information. - -Due to the alterations in the input structure mentioned earlier, the forward methods for both `LlamaModel` and `LlamaAttention` modules need to be adjusted. Below are the modified implementations: - -For `LlamaModel`: +lmdeploy.pytorch initializes the engine based on the model's config file. If the parameter naming of the model to be integrated differs from common models in transformers, parsing errors may occur. A custom ConfigBuilder can be added to parse the configuration. ```python -# lmdeploy/pytorch/models/llama.py +# lmdeploy/pytorch/configurations/gemma.py -class LlamaModel(nn.Module): - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - - # decoder layers - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - ) -``` - -For LlamaAttention: - -```python -# lmdeploy/pytorch/models/llama.py -from lmdeploy.pytorch.kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd - -class LlamaAttention(nn.Module): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of LlamaAttention.forward.""" - context = self.context.context - history_lengths = context.history_lengths - position_ids_1d = context.position_ids_1d - block_offsets = context.block_offsets - - # qkv proj - query_states = q_proj(hidden_states) - key_states = k_proj(hidden_states) - value_states = v_proj(hidden_states) - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - # rotary embedding - max_seq_len = position_ids.size(-1) - kv_seq_len = max_seq_len + max(history_lengths) - if kv_seq_len >= self.rotary_emb.max_seq_len_cached: - cos, sin = self.rotary_emb(value_states, - seq_len=kv_seq_len + 128) - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - self.rotary_emb.cos_cached, - self.rotary_emb.sin_cached, - position_ids, - position_ids_1d, - q_embed=query_states, - k_embed=key_states) - - # fill kv cache - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - fill_kv_cache(key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - block_offsets=block_offsets, - history_lengths=history_lengths, - context=context) - - # attention - attn_output = query_states - block_size = past_key_value[0].size(1) - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_seq_len, - ) - hidden_size = num_heads * head_dim - attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size) - - # o proj - attn_output = o_proj(attn_output) - return attn_output, None, past_key_value -``` +from lmdeploy.pytorch.config import ModelConfig -Note: The additional arguments like `history_lengths` and `block_offsets` are accessed from the `context` object, which acts as a container for the necessary inputs required by continuous batching and paged attention. Refer to the [context info](#context-info) for more detail about `context` object. +from .builder import AutoModelConfigBuilder -We have replaced certain operations with our custom Triton kernel for two reasons: -1. The custom Triton kernel allows us to incorporate new features, such as `paged_attention_fwd`. -2. Fused kernels offer superior performance compared to the pure PyTorch implementation. +class GemmaModelConfigBuilder(AutoModelConfigBuilder): -Now that we have the updated implementations for the two modules, let's register them in `lmdeploy/pytorch/models/module_map.py`. - -```python -# lmdeploy/pytorch/models/module_map.py -MODEL_MAP.update({ - 'transformers.models.llama.modeling_llama.LlamaAttention': - 'lmdeploy.pytorch.models.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaModel': - 'lmdeploy.pytorch.models.llama.LlamaModel' -}) -``` - -In this mapping, the revised modules are associated with their original counterparts. When creating an `Engine`, the `ModelAgent` will automatically patch the model. Subsequently, we can conduct inference using these updated implementations. - -## Support Tensor Parallelism - -If we aim to enable tensor parallelism (TP), it is necessary to partition the weights in the model. Let's build upon the previously mentioned modifications to accommodate TP in the Llama model: - -In Llama (as well as in most Language Model models), the weight partition primarily affects the Linear layers. Specifically, for the following components: - -- In `LlamaAttention`: `q_proj`, `k_proj`, `v_proj` require column-wise partitioning, while `o_proj` necessitates row-wise partitioning. -- In `LlamaMLP`: `gate_proj` and `up_proj` require column-wise partitioning, while `down_proj` requires row-wise partitioning. - -We can implement the \_distribution_partition_fn in each of the rewritten modules: - -```python -# lmdeploy/pytorch/models/llama.py -from ..dist_utils import (colwise_parallelize_linear_fn, - rowwise_parallelize_linear_fn) - -class LlamaAttention(nn.Module): @classmethod - def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module, - device_mesh: DeviceMesh): - """Distribution partition callback.""" - if mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - elif mod_name in ['o_proj']: - rowwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) + def condition(cls, hf_config): + # Check if hf_config is suitable for this builder + return hf_config.model_type in ['gemma', 'gemma2'] -class LlamaMLP(nn.Module): @classmethod - def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module, - device_mesh: DeviceMesh): - """Distribution partition callback.""" - if mod_name in ['gate_proj', 'up_proj']: - colwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - elif mod_name in ['down_proj']: - rowwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - + def build(cls, hf_config, model_path: str = None): + # Use the hf_config loaded by transformers + # Construct the ModelConfig for the pytorch engine + return ModelConfig(hidden_size=hf_config.hidden_size, + num_layers=hf_config.num_hidden_layers, + num_attention_heads=hf_config.num_attention_heads, + num_key_value_heads=hf_config.num_key_value_heads, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + head_dim=hf_config.head_dim, + vocab_size=hf_config.vocab_size) ``` -In the process of loading model weights, the `_distribute_partition_fn` is called to distribute the weights of specific modules across different devices. Following the weight partitioning, it becomes necessary to perform `all_reduce` on the output tensors of `o_proj` and `down_proj`. While one option is to include `all_reduce` directly in the forward method, an alternative approach is to introduce the `_distribute_output_fn` call: +The `lmdeploy.pytorch.check_env.check_model` function can be used to verify if the configuration can be parsed correctly. -```python -# lmdeploy/pytorch/models/llama.py -import torch.distributed as dist +### Implementing the Model -class LlamaAttention(nn.Module): - @classmethod - def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - -class LlamaMLP(nn.Module): - @classmethod - def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs -``` - -It is essential to remember to add `LlamaMLP` to the `module_map`: +After ensuring that the configuration can be parsed correctly, you can start implementing the model logic. Taking the implementation of llama as an example, we need to create the model using the configuration file from transformers. ```python -# lmdeploy/pytorch/models/module_map.py -MODEL_MAP.update({ - 'transformers.models.llama.modeling_llama.LlamaMLP': - 'lmdeploy.pytorch.models.llama.LlamaMLP' -}) -``` - -With these adjustments, the model is now capable of utilizing multiple GPUs for deploying Large Language Models (LLM). This enables efficient distribution of computations across different devices in a parallelized manner. - -## Debug Module - -When the output of the model does not meet expectations, we would like to debug a specific module to determine if the added rewrite is correct. `lmdeploy.pytorch` provides some tools to assist with accuracy alignment. Let’s take `LlamaAttention` module as an example. - -First, create an instance of the module that we want to debug: +class LlamaForCausalLM(nn.Module): + + # Constructor, builds the model with the given config + # ctx_mgr is the context manager, which can be used to pass engine configurations or additional parameters + def __init__(self, + config: LlamaConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build LLamaModel + self.model = LlamaModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + # Model inference function + # It is recommended to use the same parameters as below + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) -```python -import torch -from transformers import AutoModelForCausalLM - -# get module -model_path = 'meta-llama/Llama-2-7b-chat-hf' -dtype = torch.float16 -model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.float16).cuda() -self_attn = model.model.layers[0].self_attn + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits ``` -Extract the inputs/outputs with `ModuleIOExtractor`. +In addition to these, the following content needs to be added: ```python -from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor - -# extract module input/output -input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda() -extractor = ModuleIOExtractor(model, self_attn) -attn_args, attn_kwargs, attn_output = extractor.extract(input_ids) -``` - -The inputs of rewrite module are different from the inputs of origin module: +class LlamaForCausalLM(nn.Module): -1. Module requires some special inputs, which are passed through `StepContext`. We can create one with `make_step_context`. -2. `input_ids`, `hidden_states` should be continuous. We can use `continuous_tensor` to do the process. -3. `past_key_value` should be paged to meet the demand of paged attention. + ... -Based on the reason above, the input should be updated: + # Indicates whether the model supports cudagraph + # Can be a callable object, receiving forward inputs + # Dynamically determines if cudagraph is supported + support_cuda_graph = True -```python -from lmdeploy.pytorch.tools.make_inputs import make_step_context -from lmdeploy.pytorch.tools.layout_convert import continuous_tensor - -# create patched input/output -context = make_step_context(input_ids, - kv_cache_dtype=dtype, - num_key_value_heads=32) -seq_length = context.q_seq_length -attn_kwargs['hidden_states'] = continuous_tensor( - attn_kwargs['hidden_states'], - seq_length) -attn_kwargs['past_key_value'] = context.kv_caches[0] + # Builds model inputs + # Returns a dictionary, the keys of which must be inputs to forward + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + ... + + # Loads weights + # The model's inputs are key-value pairs of the state dict + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + ... ``` -Then you can start the rewrite and compare the correctness of the results. +We have encapsulated many fused operators to simplify the model construction. These operators better support various functions such as tensor parallelism and quantization. We encourage developers to use these ops as much as possible. ```python -from lmdeploy.pytorch.models import patch - -# patch and test -patched_self_attn = patch(self_attn, extra_args=['context']) -with torch.inference_mode(): - patched_output = patched_self_attn.patched_forward(*attn_args, - **attn_kwargs, - context=context) -torch.testing.assert_close(patched_output[0], - continuous_tensor(attn_output[0], seq_length)) -``` - -Adjust the rewrite module until the output can be aligned. - -## Appendix +# Using predefined build_merged_colwise_linear, SiluAndMul, build_rowwise_linear +# Helps us build the model faster and without worrying about tensor concurrency, quantization, etc. +class LlamaMLP(nn.Module): -### context info + def __init__(self, + config: LlamaConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=config.mlp_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) -```python -@dataclass -class StepContext: - """context of Model. - """ - inputs: ModelInputs - block_offsets: torch.LongTensor - position_ids: torch.LongTensor - position_ids_1d: torch.LongTensor - q_start_loc: torch.LongTensor - history_lengths: torch.LongTensor - seq_length: torch.LongTensor - max_seq_length: int - kv_seq_length: torch.LongTensor - kv_caches: List - is_decoding: bool - world_size: int = 1 - json_config: Dict = None - local_adapter_ids: torch.LongTensor = None - global_adapter_ids: torch.LongTensor = None - adapter_offsets: torch.LongTensor = None - max_rank: int = 0 + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=config.mlp_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) ``` -### FAQ - -- **How to invoke the original forward method?** +### Model Registration -A common approach is to add hooks to a method rather than performing a complete rewrite. To access the unpatched module, you can utilize self.origin_mod within the rewritten method. - -- **How to register modules in remote code?** - -For modules located in remote code, pinpointing them via `qualname` might be challenging. `lmdeploy.pytorch` facilitates registration using abbreviations for such modules:n: +To ensure that the developed model implementation can be used normally, we also need to register the model in `lmdeploy/pytorch/models/module_map.py` ```python MODULE_MAP.update({ - 'modeling_internlm.InternLMAttention': - 'lmdeploy.pytorch.models.internlm.PatchedInternLMAttention', + 'LlamaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', }) ``` -> \[!NOTE\] -> -> Although abbreviations are supported, they tend to have lower priority. It is advisable to register modules using their complete `qualname` for more robust and accurate mapping. - -- **How to support different modules with the same name?** +If you do not wish to modify the model source code, you can also pass a custom module map from the outside, making it easier to integrate into other projects. -You can accommodate multiple modules with the same name within a single rewrite module by providing distinct implementations based on their attributes. For instance, consider `baichuan2` 7b/13b: - -```python -class BaichuanModel(nn.Module): - def forward(self, ...): - if self.config.num_hidden_layers == 32: - return forward_7b(...) - else: - return forward_default(...) ``` +from lmdeploy import PytorchEngineConfig, pipeline -- **How to perform post-initialization for a rewrite module?** - -To execute tasks after model weight loading, introduce a `_update_model_fn` method in your rewrite module. This method will be automatically called post-initialization: - -```python -class LlamaAttention: - def _update_model_fn(self): - # ADD YOUR CODE HERE +backend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py') +generator = pipeline(model_path, backend_config=backend_config) ``` - -Here, you can include any additional post-initialization steps or configurations needed for your specific use case. diff --git a/docs/en/inference/pytorch.md b/docs/en/inference/pytorch.md index 80323a371..dc724f576 100644 --- a/docs/en/inference/pytorch.md +++ b/docs/en/inference/pytorch.md @@ -47,28 +47,6 @@ ModelAgent consists of two components: 1. \`**patched_model**: : This is the transformer model after patching. In comparison to the original model, the patched model incorporates additional features such as Tensor Parallelism, quantization, and high-performance kernels. 2. **cache_engine**: This component manages the caches. It receives commands from the Scheduler and performs host-device page swaps. Only GPU blocks are utilized for caching key/value pairs and adapters. -## Patching - -In order to facilitate the deployment of a new model, we have developed a tool to patch the modules. - -For example, if we want to reimplement the forward method of `LlamaAttention`: - -```python -class CustomLlamaAttention(nn.Module): - def forward(self, ...): - # custom forward -``` - -We register the implementation above into `lmdeploy.pytorch.models.module_map`: - -```python -MODULE_MAP.update({ -'transformers.models.llama.modeling_llama.LlamaAttention': -'qualname.to.CustomLlamaAttention'}) -``` - -`ModelAgent` would then load and patch `LlamaAttention` with `CustomLlamaAttention` while leaving everything else unchanged. You can perform inference with the new implementation. For more detail about model patching, please refer to [support new model](../advance/pytorch_new_model.md) . - ## Features `lmdeploy.pytorch` supports new features including: diff --git a/docs/zh_cn/advance/pytorch_new_model.md b/docs/zh_cn/advance/pytorch_new_model.md index 26b05aac8..a60801575 100644 --- a/docs/zh_cn/advance/pytorch_new_model.md +++ b/docs/zh_cn/advance/pytorch_new_model.md @@ -1,382 +1,181 @@ # lmdeploy.pytorch 新模型支持 -lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,新模型的支持依赖于 patch 机制,对原模型做修改以及功能添加,以期可以最大程度上复用模型的原始实现,减少工作量。 +lmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发,用户可以根据自己的需求适配新的模型。 ## 模型支持 -我们以 transformers 中的 llama 实现来介绍模型支持的流程 +### 配置加载(可选) -在开始之前,我们首先要了解一下模型的输入。lmdeploy.pytorch 的输入与标准 transformers 模型的输入略有不同,差异主要体现在如下方面: - -1. 由于支持了 continuous batching,一个 batch 的输入 `input_ids` 会被拼接成一维的长序列,然后 `unsqueeze(0)` 来保证输入维度与 transformers 中相同。这样的输入不会影响 MLP 以及 RMSNorm 等模块的计算。 -2. 由于添加了对 paged attention 的支持,`past_key_value` 不再是原来的大小,而是一组形状为 `[num_blocks, block_size, num_heads, head_dim]` 的 cache 块,num_blocks 为总 block 数量,由可用显存大小决定,block_size 为预设的块大小。这样的输入改变会影响到 LlamaModel 和 LlamaAttention 的计算,因此要对这两个模块的实现进行修改。 -3. 由于上述输入的改变,模型中需要一些额外的输入来支持推理,比如 batch 中的序列起始位置和长度,kv cache 的 block table 等。这些输入并不在模块的 forward 参数列表中,我们需要维护一个上下文以获得这些输入。 - -上面的输入改动会影响 LlamaModel 和 LlamaAttention,首先我们来实现新的 LlamaModel,这是对原始实现的简化,我们删除了很多检查代码,以避免由于输入改变造成的断言失败,仅保留了最小程度的代码: - -```python -# lmdeploy/pytorch/models/llama.py - -class LlamaModel(nn.Module): - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - """Rewrite implementation of LlamaModel.forward.""" - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - - # decoder layers - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - hidden_states = self.norm(hidden_states) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=None, - attentions=None, - ) -``` - -然后是对 LlamaAttention 模块的改写。按顺序实现如下操作: - -1. kqv proj -2. rotary embedding -3. 填充 kv cache -4. MHA 计算 -5. o proj - -continuous batching 和 kv cache 的改动对该模块的影响比较大 +lmdeploy.pytorch 会根据模型的参数初始化引擎,如果需要接入的模型的参数命名与 transformers 中常见模型不同,可能存在解析错误的情况。可以添加自定义的 ConfigBuilder 来解析配置 ```python -# lmdeploy/pytorch/models/llama.py -from lmdeploy.pytorch.kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd +# lmdeploy/pytorch/configurations/gemma.py -class LlamaAttention(nn.Module): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Rewrite of LlamaAttention.forward.""" - context = self.context.context - history_lengths = context.history_lengths - position_ids_1d = context.position_ids_1d - block_offsets = context.block_offsets - - # qkv proj - query_states = q_proj(hidden_states) - key_states = k_proj(hidden_states) - value_states = v_proj(hidden_states) - query_states = query_states.view(-1, num_heads, head_dim) - key_states = key_states.view(-1, num_kv_heads, head_dim) - value_states = value_states.view(-1, num_kv_heads, head_dim) - - # rotary embedding - max_seq_len = position_ids.size(-1) - kv_seq_len = max_seq_len + max(history_lengths) - if kv_seq_len >= self.rotary_emb.max_seq_len_cached: - cos, sin = self.rotary_emb(value_states, - seq_len=kv_seq_len + 128) - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - self.rotary_emb.cos_cached, - self.rotary_emb.sin_cached, - position_ids, - position_ids_1d, - q_embed=query_states, - k_embed=key_states) - - # fill kv cache - kv_seq_length = context.kv_seq_length - q_seq_length = context.q_seq_length - q_start_loc = context.q_start_loc - fill_kv_cache(key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - block_offsets=block_offsets, - history_lengths=history_lengths, - context=context) - - # attention - attn_output = query_states - block_size = past_key_value[0].size(1) - paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seq_length, - kv_seqlens=kv_seq_length, - max_seqlen=max_seq_len, - ) - hidden_size = num_heads * head_dim - attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size) - - # o proj - attn_output = o_proj(attn_output) - return attn_output, None, past_key_value -``` +from lmdeploy.pytorch.config import ModelConfig -上面的代码有几处值得注意的地方,首先是 context 对象。我们需要 history_lengths、block_offsets 等参数辅助运算,这些参数无法通过模型的 forward 函数传递进来。因此我们维护了一个 context 对象,把几乎所有可能用到的输入参数都保存在其中,方便在各个模块间共享。context 对象可以通过 `self.context.context` 来访问,结构可以参考 [context-结构](#context-结构)。 +from .builder import AutoModelConfigBuilder -另一个值得注意的地方就是自定义 kernel,由于输入形式的改变,原来的 LlamaAttention 实现变得不再适用,为了保证推理的速度和正确性,我们在 lmdeploy.pytorch.kernels 中实现了许多自定义的 triton kernel,上面的模块中就用到了 `apply_rotary_pos_emb`,`fill_kv_cache` 和 `paged_attention_fwd` ,分别负责实现 rotary embedding,填充 kv cache 还有 attention 的计算。 -有了上述的两个模块后,还需要将他们注册到 `lmdeploy/pytorch/models/module_map.py` 中,进行原模块与 patch 模块的映射 +class GemmaModelConfigBuilder(AutoModelConfigBuilder): -```python -# lmdeploy/pytorch/models/module_map.py -MODEL_MAP.update({ - 'transformers.models.llama.modeling_llama.LlamaAttention': - 'lmdeploy.pytorch.models.llama.LlamaAttention', - 'transformers.models.llama.modeling_llama.LlamaModel': - 'lmdeploy.pytorch.models.llama.LlamaModel' -}) -``` - -完成注册后,Engine 在启动时就会将这两个模块 patch 成新的实现,完成后续的部署任务。 - -## Tensor 并发支持 - -为了支持 Tensor 并发,需要对模型的权重做切分。让我们试着为上面接入的 Llama 模型添加 TP 的支持。 - -Llama 中涉及到 Tensor 并发的模块是 LlamaAttention 中的 qkvo proj 和 LlamaMLP 中的 gate,up 和 down proj。其中 o_proj 和 down_proj 需要按行切分,剩下的按列切分。我们可以在对应的模块中实现 `_distribution_partition_fn` 函数: - -```python -# lmdeploy/pytorch/models/llama.py -from ..dist_utils import (colwise_parallelize_linear_fn, - rowwise_parallelize_linear_fn) - -class LlamaAttention(nn.Module): @classmethod - def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module, - device_mesh: DeviceMesh): - """Distribution partition callback.""" - if mod_name in ['q_proj', 'k_proj', 'v_proj']: - colwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - elif mod_name in ['o_proj']: - rowwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) + def condition(cls, hf_config): + # 判断 hf_config 是否适配该 builder + return hf_config.model_type in ['gemma', 'gemma2'] -class LlamaMLP(nn.Module): @classmethod - def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module, - device_mesh: DeviceMesh): - """Distribution partition callback.""" - if mod_name in ['gate_proj', 'up_proj']: - colwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - elif mod_name in ['down_proj']: - rowwise_parallelize_linear_fn(mod, - device_mesh=device_mesh, - to_local=True) - + def build(cls, hf_config, model_path: str = None): + # 使用 transformers 加载的 hf_config + # 构造 pytorch engine 的 ModelConfig + return ModelConfig(hidden_size=hf_config.hidden_size, + num_layers=hf_config.num_hidden_layers, + num_attention_heads=hf_config.num_attention_heads, + num_key_value_heads=hf_config.num_key_value_heads, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + head_dim=hf_config.head_dim, + vocab_size=hf_config.vocab_size) ``` -`_distribute_partition_fn` 会在加载模型权重时被调用,对应的权重会被按照特定的形式分配到对应的设备中。 - -按照目前的方案切分后的权重,需要对 o_proj 和 down_proj 的结果进行 all_reduce 操作才能得到正确的结果。可以选择将 all_reduce 放在模型的 forward 函数中,也可以选择另一种方案,添加 `_distribute_output_fn` 函数: +可以使用 `lmdeploy.pytorch.check_env.check_model` 函数验证配置是否能够正确解析 -```python -# lmdeploy/pytorch/models/llama.py -import torch.distributed as dist +### 实现模型 -class LlamaAttention(nn.Module): - @classmethod - def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): - """Distribution output hook.""" - dist.all_reduce(outputs[0]) - return outputs - -class LlamaMLP(nn.Module): - @classmethod - def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh): - """Distribution output hook.""" - dist.all_reduce(outputs) - return outputs -``` - -最后别忘了将 LlamaMLP 也注册进 module_map 中 +在确保能够正确解析配置后,就可以开始实现模型逻辑。以 llama 的实现为例,我们需要通过 transformers 的配置文件创建模型 ```python -# lmdeploy/pytorch/models/module_map.py -MODEL_MAP.update({ - 'transformers.models.llama.modeling_llama.LlamaMLP': - 'lmdeploy.pytorch.models.llama.LlamaMLP' -}) -``` - -这样就可以利用多卡的优势,让更大的模型部署成为可能 - -## 模块调试 - -当模型的输出不符合预期时,我们会希望调试某个特定模块以确定添加的重写是否正确。`lmdeploy.pytorch` 提供了一些工具以帮助进行精度对齐。还是以上面提到的 `LlamaAttention` 模块为例。 - -首先,我们通过 transformers 的 API 得到想要调试的子模块的一个实例: +class LlamaForCausalLM(nn.Module): + + # 构造函数,通过传入的 config 搭建模型 + # ctx_mgr 是上下文管理器,可以通过它传入引擎配置或额外参数 + def __init__(self, + config: LlamaConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build LLamaModel + self.model = LlamaModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + # 模型推理函数 + # 推荐尽可能使用与下面相同的参数 + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) -```python -import torch -from transformers import AutoModelForCausalLM - -# get module -model_path = 'meta-llama/Llama-2-7b-chat-hf' -dtype = torch.float16 -model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.float16).cuda() -self_attn = model.model.layers[0].self_attn + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits ``` -然后,使用 `ModuleIOExtractor` 工具可以生成该模块的一组输入输出 +除了这些以外,还有如下内容需要添加 ```python -from lmdeploy.pytorch.tools.make_inputs import ModuleIOExtractor - -# extract module input/output -input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda() -extractor = ModuleIOExtractor(model, self_attn) -attn_args, attn_kwargs, attn_output = extractor.extract(input_ids) -``` - -重写模块的输入与原模块略有不同,主要体现在三方面: +class LlamaForCausalLM(nn.Module): -1. 模型需要一些特殊输入输出,他们以 `StepContext` 的形式传入,可以使用 `make_step_context` 生成。 -2. `input_ids`,`hidden_states` 等数据都被 continuous 化,可以使用 `continuous_tensor` 进行处理。 -3. 由于 paged caching 的需要, `past_key_value` 需要被 page 化处理。 + ... -基于以上原因,我们要对提取的输入进行加工: + # 标注该模型是否支持 cudagraph + # 可以是一个 callable 对象,接收 forward 输入 + # 动态判断是否支持 cudagraph + support_cuda_graph = True -```python -from lmdeploy.pytorch.tools.make_inputs import make_step_context -from lmdeploy.pytorch.tools.layout_convert import continuous_tensor - -# create patched input/output -context = make_step_context(input_ids, - kv_cache_dtype=dtype, - num_key_value_heads=32) -seq_length = context.q_seq_length -attn_kwargs['hidden_states'] = continuous_tensor( - attn_kwargs['hidden_states'], - seq_length) -attn_kwargs['past_key_value'] = context.kv_caches[0] + # 构建模型输入 + # 返回词典,词典的 key 必须是 forward 的输入 + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + ... + + # 加载权重 + # 模型的输入是 state dict 的 key value 对 + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + ... ``` -然后就可以启动重写,并比较结果正确性了。(注意输出也要 continuous 化后进行比较) +我们封装了许多融合算子以简化模型的搭建。这些算子能够更好的支持 tensor 并行、量化等各种功能,我们鼓励开发者尽可能使用这些 op 进行开发。 ```python -from lmdeploy.pytorch.models import patch - -# patch and test -patched_self_attn = patch(self_attn, extra_args=['context']) -with torch.inference_mode(): - patched_output = patched_self_attn.patched_forward(*attn_args, - **attn_kwargs, - context=context) -torch.testing.assert_close(patched_output[0], - continuous_tensor(attn_output[0], seq_length)) -``` - -可以通过上述方法调试重写模块,直到精度满足预期。 - -## 附录 +# 使用预定义的 build_merged_colwise_linear, SiluAndMul, build_rowwise_linear +# 可以帮助我们更快搭建模型,并且不用关心 tensor 并发、量化等细节 +class LlamaMLP(nn.Module): -### context 结构 + def __init__(self, + config: LlamaConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=config.mlp_bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) -```python -@dataclass -class StepContext: - """context of Model. - """ - inputs: ModelInputs - block_offsets: torch.LongTensor - position_ids: torch.LongTensor - position_ids_1d: torch.LongTensor - q_start_loc: torch.LongTensor - history_lengths: torch.LongTensor - seq_length: torch.LongTensor - max_seq_length: int - kv_seq_length: torch.LongTensor - kv_caches: List - is_decoding: bool - world_size: int = 1 - json_config: Dict = None - local_adapter_ids: torch.LongTensor = None - global_adapter_ids: torch.LongTensor = None - adapter_offsets: torch.LongTensor = None - max_rank: int = 0 + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=config.mlp_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) ``` -### FAQ - -- **如何访问 patch 前的模块?** +### 模型注册 -有时我们只希望在函数前后加一个 hook 代码,不希望大段的拷贝函数,可以通过 `self.origin_mod` 访问 patch 前的模块。 - -- **非 transformers 官方的模型该如何注册?** - -一些模型的实现代码可能是以 remote code 的形式添加的,这样的模块无法通过完整的 qualname 来定位。lmdeploy.pytorch 支持使用缩写的模块名进行注册: +为了能够让开发的模型实现可以正常使用,我们还需要在 `lmdeploy/pytorch/models/module_map.py` 中注册该模型 ```python MODULE_MAP.update({ - 'modeling_internlm.InternLMAttention': - 'lmdeploy.pytorch.models.internlm.PatchedInternLMAttention', + 'LlamaForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', }) ``` -> \[!NOTE\] -> -> 缩写的优先级会更低,有条件的话还是鼓励使用完整的 qualname 进行注册。 - -- **模块出现同名但不同实现怎么处理?** - -目前推荐的做法是同名就映射到同一个实现中,然后在实现内部根据模块的固有参数来判断模型该使用的类型,以 baichuan2 7b/13b 为例: +如果你不希望修改模型源码,也可以从外部传入自定义的 module map,方便整合进其他项目中 -```python -class BaichuanModel(nn.Module): - def forward(self, ...): - if self.config.num_hidden_layers == 32: - return forward_7b(...) - else: - return forward_default(...) ``` +from lmdeploy import PytorchEngineConfig, pipeline -- **如果希望在推理前对模块进行初始化?** - -可以实现模块的 `_update_model_fn` 函数,它会在模块的权重都加载完,完成 TP 权重切分后被调用 - -```python -class LlamaAttention: - def _update_model_fn(self): - # ADD YOUR CODE HERE +backend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py') +generator = pipeline(model_path, backend_config=backend_config) ``` diff --git a/docs/zh_cn/inference/pytorch.md b/docs/zh_cn/inference/pytorch.md index 982153f94..410d55301 100644 --- a/docs/zh_cn/inference/pytorch.md +++ b/docs/zh_cn/inference/pytorch.md @@ -47,28 +47,6 @@ ModelAgent 有两个重要组件: 1. patched_model 是更新后的 transformer 模型,更新后的模型添加了各种功能的支持,包括更高性能的子模块实现、TP、量化等等 2. cache_engine 是缓存的分配与交换模块。它接收来自 scheduler 的交换请求,执行 host-device 间显存交换,adapter 加载等工作 -## Patching - -为了降低接入模型的门槛,我们实现了一套简单的 patch 机制来简化实现的替换。 - -以 Llama 模型的 LlamaAttention.forward 为例,我们可以重新写一个 forward 的实现: - -```python -class CustomLlamaAttention(nn.Module): - def forward(self, ...): - # custom forward -``` - -然后在 `lmdeploy.pytorch.models.module_map` 中注册模块的映射 - -```python -MODULE_MAP.update({ -'transformers.models.llama.modeling_llama.LlamaAttention': -'qualname.to.CustomLlamaAttention'}) -``` - -经过 patch 后的模型就会使用新的 forward 实现。TP、量化等功能也依赖 patch 机制,请阅读 [lmdeploy.pytorch 新模型支持](../advance/pytorch_new_model.md) 了解更多细节。 - ## 特性 - **Continuous Batching**: 由于输入序列的长度不一样,batching 通常需要对输入进行 padding,这种 padding 会导致后续运算的计算量增加、影响速度,也会使得显存的占用大幅增加。遵循许多其他成熟框架的方案,lmdeploy.pytorch 采用了 continuous batching 的方式对输入做了连续化处理,避免了多余的资源占用。 From c6824d5d8f4fa628c9c8c2b064b99e9e70fb6c76 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 18:35:06 +0800 Subject: [PATCH 63/89] fused layernorm --- lmdeploy/pytorch/backends/base.py | 1 + .../pytorch/backends/default/layer_backend.py | 3 + lmdeploy/pytorch/backends/default/norm.py | 39 ++++++++++++- lmdeploy/pytorch/backends/norm.py | 23 ++++++++ lmdeploy/pytorch/models/dbrx.py | 42 ++++++++------ lmdeploy/pytorch/models/llama.py | 2 - lmdeploy/pytorch/models/phi3_moe.py | 56 ++++++++++--------- lmdeploy/pytorch/models/starcoder2.py | 50 +++++++++-------- lmdeploy/pytorch/nn/__init__.py | 2 +- lmdeploy/pytorch/nn/norm.py | 55 +++++++++++++++++- 10 files changed, 201 insertions(+), 72 deletions(-) diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index 9e813782d..457f9ecda 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -18,6 +18,7 @@ class LayerType(Enum): ApplyRotaryEmb = auto() SiluAndMul = auto() RMSNorm = auto() + LayerNorm = auto() SLoRA = auto() LinearW8A8 = auto() RMSNormW8A8 = auto() diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index 587c2be32..c090e60f6 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -30,6 +30,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.RMSNorm: from .norm import DefaultRMSNormBuilder return DefaultRMSNormBuilder + elif layer_type == LayerType.LayerNorm: + from .norm import DefaultLayerNormBuilder + return DefaultLayerNormBuilder elif layer_type == LayerType.MultinomialSampling: from .multinomial_sampling import DefaultMultinomialSamplingBuilder return DefaultMultinomialSamplingBuilder diff --git a/lmdeploy/pytorch/backends/default/norm.py b/lmdeploy/pytorch/backends/default/norm.py index 0870f2943..7ae3bd824 100644 --- a/lmdeploy/pytorch/backends/default/norm.py +++ b/lmdeploy/pytorch/backends/default/norm.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from ..norm import RMSNormBuilder, RMSNormImpl +from ..norm import LayerNormBuilder, LayerNormImpl, RMSNormBuilder, RMSNormImpl class DefaultRMSNormImpl(RMSNormImpl): @@ -36,3 +36,40 @@ class DefaultRMSNormBuilder(RMSNormBuilder): def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): """build.""" return DefaultRMSNormImpl(hidden_size, eps) + + +class DefaultLayerNormImpl(LayerNormImpl): + """RMS norm implementation api.""" + + def __init__(self, normalized_shape: int, eps: float = 1e-6): + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape, ) + self.normalized_shape = normalized_shape + self.eps = eps + + def forward(self, + x: torch.Tensor, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + residual: torch.Tensor = None): + """forward.""" + if residual is not None: + x = x + residual + residual = x + x = torch.nn.functional.layer_norm(x, + self.normalized_shape, + weight=weight, + bias=bias, + eps=self.eps) + if residual is None: + return x + return x, residual + + +class DefaultLayerNormBuilder(LayerNormBuilder): + """RMS norm implementation builder.""" + + @staticmethod + def build(normalized_shape: int, eps: float = 1e-6): + """build.""" + return DefaultLayerNormImpl(normalized_shape, eps) diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py index de555c402..494793acd 100644 --- a/lmdeploy/pytorch/backends/norm.py +++ b/lmdeploy/pytorch/backends/norm.py @@ -24,3 +24,26 @@ class RMSNormBuilder(ABC): def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): """build.""" raise NotImplementedError + + +class LayerNormImpl(ABC): + """Layer norm implementation api.""" + + @abstractmethod + def forward(self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + residual: torch.Tensor = None): + """forward.""" + raise NotImplementedError + + +class LayerNormBuilder(ABC): + """layer norm implementation builder.""" + + @staticmethod + @abstractmethod + def build(normalized_shape: int, eps: float = 1e-6): + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index 056796fcb..919d3c552 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - build_rotary_embedding) + LayerNorm, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -232,7 +232,7 @@ def __init__(self, super().__init__() self.layer_idx = layer_idx - self.norm_1 = nn.LayerNorm( + self.norm_1 = LayerNorm( config.d_model, bias=False, dtype=dtype, @@ -243,7 +243,7 @@ def __init__(self, dtype=dtype, device=device, ) - self.norm_2 = nn.LayerNorm( + self.norm_2 = LayerNorm( config.d_model, bias=False, dtype=dtype, @@ -255,11 +255,16 @@ def forward( hidden_states: torch.Tensor, rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[Tuple[torch.Tensor]] = None, + residual_states: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): """forward.""" - residual_states = hidden_states - hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype) + if residual_states is None: + residual_states = hidden_states + hidden_states = self.norm_1(hidden_states) + else: + hidden_states, residual_states = self.norm_1( + hidden_states, residual_states) hidden_states = self.attn( hidden_states, @@ -268,10 +273,8 @@ def forward( attn_metadata, ) - hidden_states = hidden_states + residual_states - - residual_states = hidden_states - hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype) + hidden_states, residual_states = self.norm_2(hidden_states, + residual_states) return hidden_states, residual_states @@ -300,21 +303,22 @@ def forward( hidden_states: torch.Tensor, rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): # Self Attention - hidden_states, resid_states = self.norm_attn_norm( + hidden_states, residual = self.norm_attn_norm( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual_states=residual, attn_metadata=attn_metadata, ) # Fully Connected hidden_states = self.ffn(hidden_states) - hidden_states = resid_states + hidden_states - return hidden_states + return hidden_states, residual class DbrxModel(nn.Module): @@ -341,10 +345,10 @@ def __init__(self, ]) # build norm - self.norm_f = nn.LayerNorm(config.d_model, - bias=False, - dtype=dtype, - device=device) + self.norm_f = LayerNorm(config.d_model, + bias=False, + dtype=dtype, + device=device) # build rotary embedding emb_type = EmbeddingType.LinearScaling @@ -380,17 +384,19 @@ def forward( rotary_pos_emb = (cos, sin) # decoding + residual = None for idx, decoder_layer in enumerate(self.blocks): past_key_value = past_key_values[idx] - hidden_states = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual=residual, attn_metadata=attn_metadata, ) # norm - hidden_states = self.norm_f(hidden_states) + hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 85ef0cac0..fd7d514e6 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -216,7 +216,6 @@ def __init__(self, super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - quantization_config = getattr(config, 'quantization_config', None) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, @@ -233,7 +232,6 @@ def __init__(self, # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, - quant_config=quantization_config, dtype=dtype, device=device) diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py index 4b00a6161..51e0cfe5b 100644 --- a/lmdeploy/pytorch/models/phi3_moe.py +++ b/lmdeploy/pytorch/models/phi3_moe.py @@ -5,7 +5,8 @@ from torch import nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, EmbeddingType +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, + LayerNorm) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.nn.moe import FusedMoE from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, @@ -227,29 +228,31 @@ def __init__(self, device=device) # build input layer norm - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True, - dtype=dtype, - device=device) + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + dtype=dtype, + device=device) # build attention layer norm - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True, - dtype=dtype, - device=device) + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) # Self Attention hidden_states = self.self_attn( @@ -258,15 +261,14 @@ def forward( past_key_value=past_key_value, attn_metadata=attn_metadata, ) - hidden_states = residual + hidden_states # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states, _ = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + outputs = (hidden_states, residual) + return outputs class PhiMoEModel(nn.Module): @@ -290,11 +292,10 @@ def __init__(self, ]) # build norm - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True, - dtype=dtype, - device=device) + self.norm = LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + dtype=dtype, + device=device) # build rotary embedding emb_type = EmbeddingType.LinearScaling @@ -345,17 +346,18 @@ def forward( cos, sin = self.rotary_emb(hidden_states, position_ids) cos, sin = cos[0], sin[0] rotary_pos_emb = (cos, sin) + residual = None for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] - hidden_states = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual=residual, attn_metadata=attn_metadata, ) - hidden_states = self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/lmdeploy/pytorch/models/starcoder2.py b/lmdeploy/pytorch/models/starcoder2.py index ea1a7d526..7cb97be2b 100644 --- a/lmdeploy/pytorch/models/starcoder2.py +++ b/lmdeploy/pytorch/models/starcoder2.py @@ -7,7 +7,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - build_rotary_embedding) + LayerNorm, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -166,27 +166,31 @@ def __init__(self, self.mlp = Starcoder2MLP(config, dtype=dtype, device=device) # build input layer norm - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon, - dtype=dtype, - device=device) + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) # build attention layer norm - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon, - dtype=dtype, - device=device) + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) def forward( self, hidden_states: torch.Tensor, rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, ): - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) # Self Attention hidden_states = self.self_attn( @@ -195,14 +199,14 @@ def forward( past_key_value=past_key_value, attn_metadata=attn_metadata, ) - hidden_states = residual + hidden_states # Fully Connected - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + outputs = (hidden_states, residual) + return outputs class Starcoder2Model(nn.Module): @@ -232,10 +236,10 @@ def __init__(self, ]) # build norm - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon, - dtype=dtype, - device=device) + self.norm = LayerNorm(config.hidden_size, + eps=config.norm_epsilon, + dtype=dtype, + device=device) # build rotary embedding emb_type = EmbeddingType.LinearScaling @@ -271,17 +275,19 @@ def forward( rotary_pos_emb = (cos, sin) # decoding + residual = None for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] - hidden_states = decoder_layer( + hidden_states, residual = decoder_layer( hidden_states, rotary_pos_emb=rotary_pos_emb, past_key_value=past_key_value, + residual=residual, attn_metadata=attn_metadata, ) # norm - hidden_states = self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 8303c27a6..2809e8ec1 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -3,7 +3,7 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import SiluAndMul # noqa: F401 from .attention import Attention # noqa: F401 -from .norm import RMSNorm # noqa: F401 +from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import EmbeddingType # noqa: F401 from .rotary_embedding import YarnParameters # noqa: F401 diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index 5394e1892..f7b40b48a 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -9,7 +9,14 @@ def _is_w8a8(quant_config: Any): """is w8a8.""" - return False + if quant_config is None: + return False + else: + quant_method = quant_config['quant_method'] + if quant_method == 'w8a8': + return True + else: + return False class RMSNorm(nn.Module): @@ -49,3 +56,49 @@ def create_weight(hidden_size: int, def forward(self, x: torch.Tensor, residual: torch.Tensor = None): """forward.""" return self.impl.forward(x, self.weight, residual) + + +class LayerNorm(nn.Module): + """Layer Norm with add residual.""" + + def __init__(self, + hidden_size: int, + eps: float = 1e-6, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.LayerNorm) + weight, bias = self.create_weight(hidden_size, bias, dtype, device) + self.register_parameter('weight', weight) + self.register_parameter('bias', bias) + self.impl = builder.build(hidden_size, eps) + + @staticmethod + def create_weight(hidden_size: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + """create weight.""" + if dtype is None: + dtype = torch.float16 + if device is None: + device = 'cuda' + weight = torch.nn.Parameter(torch.ones(hidden_size, + dtype=dtype, + device=device), + requires_grad=False) + if bias: + bias = torch.nn.Parameter(torch.ones(hidden_size, + dtype=dtype, + device=device), + requires_grad=False) + else: + bias = None + + return weight, bias + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None): + """forward.""" + return self.impl.forward(x, self.weight, self.bias, residual) From 985c769337c683df08c55de541d9dd2f51015d05 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 3 Sep 2024 18:43:59 +0800 Subject: [PATCH 64/89] add gelu and mul --- lmdeploy/pytorch/backends/activation.py | 19 +++++++++++++++ lmdeploy/pytorch/backends/base.py | 1 + .../pytorch/backends/default/activation.py | 24 ++++++++++++++++++- .../pytorch/backends/default/layer_backend.py | 3 +++ lmdeploy/pytorch/models/gemma.py | 7 +++--- lmdeploy/pytorch/nn/__init__.py | 2 +- lmdeploy/pytorch/nn/activation.py | 14 +++++++++++ 7 files changed, 64 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/backends/activation.py b/lmdeploy/pytorch/backends/activation.py index 327dcbfcc..b38bca4a4 100644 --- a/lmdeploy/pytorch/backends/activation.py +++ b/lmdeploy/pytorch/backends/activation.py @@ -19,3 +19,22 @@ class SiluAndMulBuilder(ABC): def build(inplace: bool = False): """build.""" raise NotImplementedError + + +class GeluAndMulImpl(ABC): + """gelu + multiple residual fused implementation.""" + + @abstractmethod + def forward(self, x): + """forward.""" + raise NotImplementedError + + +class GeluAndMulBuilder(ABC): + """gelu and mul implementation builder.""" + + @staticmethod + @abstractmethod + def build(approximate: str = 'none'): + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index 457f9ecda..a690558f8 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -17,6 +17,7 @@ class LayerType(Enum): RotaryEmbedding = auto() ApplyRotaryEmb = auto() SiluAndMul = auto() + GeluAndMul = auto() RMSNorm = auto() LayerNorm = auto() SLoRA = auto() diff --git a/lmdeploy/pytorch/backends/default/activation.py b/lmdeploy/pytorch/backends/default/activation.py index e5f03000e..a0eda376c 100644 --- a/lmdeploy/pytorch/backends/default/activation.py +++ b/lmdeploy/pytorch/backends/default/activation.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from torch import nn -from ..activation import SiluAndMulBuilder, SiluAndMulImpl +from ..activation import (GeluAndMulBuilder, GeluAndMulImpl, SiluAndMulBuilder, + SiluAndMulImpl) class DefaultSiluAndMulImpl(SiluAndMulImpl): @@ -24,3 +25,24 @@ class DefaultSiluAndMulBuilder(SiluAndMulBuilder): def build(inplace: bool = False): """build.""" return DefaultSiluAndMulImpl(inplace) + + +class DefaultGeluAndMulImpl(GeluAndMulImpl): + """gelu + multiple residual fused implementation.""" + + def __init__(self, approximate: str = 'none'): + self.act = nn.GELU(approximate=approximate) + + def forward(self, x): + """forward.""" + gate, up = x.chunk(2, -1) + return self.act(gate) * up + + +class DefaultGeluAndMulBuilder(GeluAndMulBuilder): + """gelu and mul implementation builder.""" + + @staticmethod + def build(approximate: str = 'none'): + """build.""" + return DefaultGeluAndMulImpl(approximate) diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/layer_backend.py index c090e60f6..3bfb377d1 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/layer_backend.py @@ -27,6 +27,9 @@ def get_layer_impl_builder(cls, layer_type: LayerType): elif layer_type == LayerType.SiluAndMul: from .activation import DefaultSiluAndMulBuilder return DefaultSiluAndMulBuilder + elif layer_type == LayerType.GeluAndMul: + from .activation import DefaultGeluAndMulBuilder + return DefaultGeluAndMulBuilder elif layer_type == LayerType.RMSNorm: from .norm import DefaultRMSNormBuilder return DefaultRMSNormBuilder diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index 3c02e8c3c..a7538387f 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -8,7 +8,7 @@ from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, build_rotary_embedding) + GeluAndMul, RMSNorm, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -134,7 +134,7 @@ def __init__(self, if hidden_activation is None: hidden_activation = 'gelu_pytorch_tanh' assert hidden_activation == 'gelu_pytorch_tanh' - self.act_fn = nn.GELU(approximate='tanh') + self.act_fn = GeluAndMul(approximate='tanh') # down self.down_proj = build_rowwise_linear(config.intermediate_size, @@ -148,8 +148,7 @@ def __init__(self, def forward(self, x): """forward.""" gate_up = self.gate_up_proj(x) - gate, up = gate_up.chunk(2, -1) - act = self.act_fn(gate) * up + act = self.act_fn(gate_up) out = self.down_proj(act) return out diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 2809e8ec1..e90899bb8 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. # attention module is modified from: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ -from .activation import SiluAndMul # noqa: F401 +from .activation import GeluAndMul, SiluAndMul # noqa: F401 from .attention import Attention # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 diff --git a/lmdeploy/pytorch/nn/activation.py b/lmdeploy/pytorch/nn/activation.py index 9e278d556..6c73d9752 100644 --- a/lmdeploy/pytorch/nn/activation.py +++ b/lmdeploy/pytorch/nn/activation.py @@ -16,3 +16,17 @@ def __init__(self, inplace: bool = True): def forward(self, x: Tensor): """forward.""" return self.impl.forward(x) + + +class GeluAndMul(nn.Module): + """Gelu and elementwise multiple.""" + + def __init__(self, approximate: str = 'none'): + super().__init__() + backend = get_backend() + builder = backend.get_layer_impl_builder(LayerType.GeluAndMul) + self.impl = builder.build(approximate) + + def forward(self, x: Tensor): + """forward.""" + return self.impl.forward(x) From 5de7cd9308f5ff2ff112ea4ca4bb801dc5f2d9f7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 4 Sep 2024 10:42:12 +0800 Subject: [PATCH 65/89] support triton==3.0.0 --- lmdeploy/pytorch/check_env/__init__.py | 4 +-- .../pytorch/kernels/cuda/pagedattention.py | 30 ++++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 3f7e3d840..aa4423964 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -65,8 +65,8 @@ def check_env_triton(): import triton if version.parse( triton.__version__) > version.parse(MAX_TRITON_VERSION): - logger.warning(f'Install triton<={MAX_TRITON_VERSION}' - ' if you want to get better performance.') + logger.warning( + f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.') from .triton_custom_add import custom_add a = torch.tensor([1, 2], device='cuda') diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index ebfb65cf7..d727841c9 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -16,6 +16,20 @@ assert TRITON_VERSION >= version.parse('2.1.0') +if TRITON_VERSION >= version.parse('3.0.0'): + + @triton.jit + def tanh(x): + """tanh.""" + return 2 * tl.sigmoid(2 * x) - 1 + + fast_expf = tl.math.exp + fast_dividef = tl.math.fdiv +else: + tanh = tl.math.tanh + fast_expf = tl.math.fast_expf + fast_dividef = tl.math.fast_dividef + @triton.autotune(configs=[ triton.Config({}, num_stages=2, num_warps=16), @@ -191,7 +205,7 @@ def _fwd_grouped_split_kernel( qk *= sm_scale if logit_softcapping > 0.0: qk = qk / logit_softcapping - qk = tl.math.tanh(qk) + qk = tanh(qk) qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error if start_n + BLOCK_N > history_len or window_size > 0: @@ -206,8 +220,8 @@ def _fwd_grouped_split_kernel( # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.fast_expf(qk - m_i_new[:, None]) - alpha = tl.math.fast_expf(m_i - m_i_new) + p = fast_expf(qk - m_i_new[:, None]) + alpha = fast_expf(m_i - m_i_new) l_i_new = alpha * l_i + tl.sum(p, 1) # -- update output accumulator -- @@ -280,7 +294,7 @@ def _reduce_split_kernel( l_k = tl.load(Acc + offs_mi + 1) m_max = tl.max(m_k, 0) - alpha = tl.math.fast_expf(m_k - m_max) + alpha = fast_expf(m_k - m_max) acc_k = acc_k * alpha[:, None] l_k = l_k * alpha @@ -446,7 +460,7 @@ def _fwd_kernel( qk *= sm_scale if logit_softcapping > 0.0: qk = qk / logit_softcapping - qk = tl.math.tanh(qk) + qk = tanh(qk) qk = qk * logit_softcapping # NOTE: inf - inf = nan, and nan will leads to error if start_n + BLOCK_N > history_len or window_size > 0: @@ -463,8 +477,8 @@ def _fwd_kernel( # -- compute p, m_i and l_i m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.fast_expf(qk - m_i_new[:, None]) - alpha = tl.math.fast_expf(m_i - m_i_new) + p = fast_expf(qk - m_i_new[:, None]) + alpha = fast_expf(m_i - m_i_new) l_i_new = alpha * l_i + tl.sum(p, 1) # -- update output accumulator -- # scale acc @@ -477,7 +491,7 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new - acc = tl.math.fast_dividef(acc, l_i[:, None]) + acc = fast_dividef(acc, l_i[:, None]) # initialize pointers to output off_o = ((q_start_loc + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_dv[None, :] * stride_od) From 7716147a69bb4b676ece8862b1d3a70f1b4825e7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 4 Sep 2024 16:32:08 +0800 Subject: [PATCH 66/89] update names --- .../pytorch/kernels/cuda/pagedattention.py | 20 +++++++++++++------ lmdeploy/pytorch/nn/linear.py | 6 +++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index d727841c9..c5d4d4db6 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -111,9 +111,9 @@ def _fwd_grouped_split_kernel( BLOCK_DMODEL1: tl.constexpr, ): """first step kernel of split k attention.""" - cur_batch = tl.program_id(0) - cur_kv_head = tl.program_id(1) - split_k_id = tl.program_id(2) + cur_batch = tl.program_id(2) + cur_kv_head = tl.program_id(0) + split_k_id = tl.program_id(1) if BLOCK_H < kv_group_num: HEAD_PER_CTA: tl.constexpr = BLOCK_H @@ -184,9 +184,11 @@ def _fwd_grouped_split_kernel( kv_min_loc = tl.maximum(history_len - window_size, 0) loop_start = start_block_id * BLOCK_N + block_offset_ptrs += start_block_id for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N) + b_offset = tl.load(block_offset_ptrs) + block_offset_ptrs += 1 # -- compute qk ---- k = tl.load(k_ptrs + b_offset * stride_kp) @@ -439,9 +441,11 @@ def _fwd_kernel( start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0) kv_start_loc = start_block_id * BLOCK_N + block_offset_ptrs += start_block_id for start_n in range(kv_start_loc, kv_seqlen, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N) + b_offset = tl.load(block_offset_ptrs) + block_offset_ptrs += 1 # -- compute qk ---- k = tl.load(k_ptrs + b_offset * stride_kp) @@ -623,7 +627,11 @@ def _get_block_d(Lk): p2_kv_group_num = triton.next_power_of_2(kv_group_num) BLOCK_H = max(16, min(BLOCK, p2_kv_group_num)) grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num)) - grid = (batch, grid_1, SPLIT_K) + grid = ( + grid_1, + SPLIT_K, + batch, + ) _fwd_grouped_split_kernel[grid](q, k, v, diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index bace20bf2..b9daf0140 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -1045,7 +1045,7 @@ def build_linear(in_features: int, is_tp=is_tp, all_reduce=all_reduce, ) - if quant_method == 'w8a8': + if quant_method == 'smooth_quant': return W8A8Linear( in_features, out_features, @@ -1142,7 +1142,7 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, ) - if quant_method == 'w8a8': + if quant_method == 'smooth_quant': return MergedW8A8Linear( in_features=in_features, all_out_features=all_out_features, @@ -1206,7 +1206,7 @@ def build_qkv_proj(in_features: int, device=device, is_tp=is_tp, ) - if quant_method == 'w8a8': + if quant_method == 'smooth_quant': return QKVW8A8Linear( in_features=in_features, num_q_heads=num_q_heads, From 9daa0d17f4e8af65214f11f4051ae952d6d131eb Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 4 Sep 2024 18:27:01 +0800 Subject: [PATCH 67/89] fix --- lmdeploy/pytorch/adapter/adapter.py | 2 +- lmdeploy/pytorch/backends/cuda/norm.py | 2 +- lmdeploy/pytorch/backends/default/norm.py | 2 +- lmdeploy/pytorch/backends/norm.py | 2 +- lmdeploy/pytorch/configurations/internvl.py | 1 - lmdeploy/pytorch/models/chatglm2.py | 88 ++++++++++++++++++++- 6 files changed, 90 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index da7149cb9..3e39cea6b 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -59,7 +59,7 @@ def find_all_target(model: torch.nn.Module, target_name: str): def get_max_ranks_per_block(block_numel: int, rank_stride: int): assert block_numel >= rank_stride, ( - 'LoRA Adapter raquires larger block_size.') + 'LoRA Adapter requires larger block_size.') return block_numel // rank_stride diff --git a/lmdeploy/pytorch/backends/cuda/norm.py b/lmdeploy/pytorch/backends/cuda/norm.py index 9d5100978..defab7928 100644 --- a/lmdeploy/pytorch/backends/cuda/norm.py +++ b/lmdeploy/pytorch/backends/cuda/norm.py @@ -30,6 +30,6 @@ class TritonRMSNormBuilder(RMSNormBuilder): """triton RMS norm implementation builder.""" @staticmethod - def build(weight: torch.Tensor, eps: float = 1e-6, inplace: bool = False): + def build(weight: torch.Tensor, eps: float = 1e-6): """build.""" return TritonRMSNormImpl(weight, eps) diff --git a/lmdeploy/pytorch/backends/default/norm.py b/lmdeploy/pytorch/backends/default/norm.py index 7ae3bd824..590ad50f4 100644 --- a/lmdeploy/pytorch/backends/default/norm.py +++ b/lmdeploy/pytorch/backends/default/norm.py @@ -33,7 +33,7 @@ class DefaultRMSNormBuilder(RMSNormBuilder): """RMS norm implementation builder.""" @staticmethod - def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): + def build(hidden_size: int, eps: float = 1e-6): """build.""" return DefaultRMSNormImpl(hidden_size, eps) diff --git a/lmdeploy/pytorch/backends/norm.py b/lmdeploy/pytorch/backends/norm.py index 494793acd..2d717849f 100644 --- a/lmdeploy/pytorch/backends/norm.py +++ b/lmdeploy/pytorch/backends/norm.py @@ -21,7 +21,7 @@ class RMSNormBuilder(ABC): @staticmethod @abstractmethod - def build(hidden_size: int, eps: float = 1e-6, inplace: bool = False): + def build(hidden_size: int, eps: float = 1e-6): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/configurations/internvl.py b/lmdeploy/pytorch/configurations/internvl.py index 929b928dd..76b4187c5 100644 --- a/lmdeploy/pytorch/configurations/internvl.py +++ b/lmdeploy/pytorch/configurations/internvl.py @@ -14,6 +14,5 @@ def condition(cls, hf_config): def build(cls, hf_config, model_path: str = None): """build llava hf.""" cfg = DefaultModelConfigBuilder.build(hf_config.llm_config) - # cfg.unused_modules = ['InternVisionModel'] cfg.hf_config = hf_config return cfg diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index b66f83a05..287a90ffa 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -12,6 +12,87 @@ build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight +LANGUAGE_TOKEN_TYPE = 0 +VISION_TOKEN_TYPE = 1 + + +def get_vision_expert_mask(token_type_ids: torch.LongTensor): + vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) + vision_token_mask[:, :-1] = (token_type_ids[:, :-1] + == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] + == VISION_TOKEN_TYPE) + language_token_mask = ~vision_token_mask + return vision_token_mask, language_token_mask + + +def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: + tmp = x.clone() + # image boi eoi token as LANGUAGE_TOKEN_TYPE + is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) + is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( + tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) + is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) + is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( + tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) + is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) + tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE + # final position ids + y = torch.zeros_like(x, dtype=torch.long) + y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( + (tmp[:, 1:] == VISION_TOKEN_TYPE) & + (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) + y = y.cumsum(dim=-1) + return y + + +def _get_cogvlm_position_ids(context): + """get cogvlm position_ids.""" + q_seqlens = context.q_seqlens + history_lengths = context.kv_seqlens - q_seqlens + vision_input_info = context.vision_inputs + position_id_offsets = (vision_input_info.history_image_token_lengths - + vision_input_info.history_image_nums * 3) + lang_ids = None + vis_ids = None + if context.is_decoding: + position_ids = history_lengths - position_id_offsets + else: + if vision_input_info.input_embeddings is not None and len( + vision_input_info.input_embeddings) > 0: + starts = history_lengths - vision_input_info.history_lengths + ends = starts + q_seqlens + token_type_ids = vision_input_info.input_embedding_indexing.to( + torch.int) + history_position_lengths = (vision_input_info.history_lengths - + position_id_offsets) + position_ids_all = (history_position_lengths[:, None] + + build_position_ids(token_type_ids)) + position_ids = torch.cat([ + pids[s:e] + for (pids, s, e) in zip(position_ids_all, starts, ends) + ]) + vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) + vision_token_mask = torch.cat([ + masks[s:e] + for (masks, s, e) in zip(vision_token_mask_all, starts, ends) + ]) + mask_indexing = torch.arange(vision_token_mask.shape[-1], + device=vision_token_mask.device) + vis_ids = mask_indexing[vision_token_mask] + lang_ids = mask_indexing[~vision_token_mask] + + else: + position_ids = context.attention_mask.long().cumsum(-1) - 1 + position_ids += (history_lengths - + position_id_offsets).unsqueeze(-1) + device = position_ids.device + position_ids_1d = [ + ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) + ] + position_ids = torch.cat(position_ids_1d).to(device) + + return position_ids, lang_ids, vis_ids + class SelfAttention(torch.nn.Module): """Parallel self-attention layer abstract class. @@ -58,8 +139,7 @@ def __init__(self, # o_proj self.dense = build_rowwise_linear(self.projection_size, config.hidden_size, - bias=config.add_bias_linear - or config.add_qkv_bias, + bias=config.add_bias_linear, quant_config=quantization_config, dtype=dtype, device=device, @@ -444,6 +524,8 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata + if context.vision_inputs is not None: + position_ids = _get_cogvlm_position_ids(context)[0][None] # process vision embeddings vision_embeddings = context.input_embeddings @@ -470,6 +552,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + if 'transformer.vision' in name: + continue if 'rotary_pos_emb.inv_freq' in name: continue if ('rotary_pos_emb.cos_cached' in name From 7d8ac693caa018847f45bed6162ddc5c4cffe76b Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 4 Sep 2024 18:53:39 +0800 Subject: [PATCH 68/89] cogvlm2 --- lmdeploy/pytorch/models/cogvlm.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 9e73c538c..76c237fb1 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -35,12 +35,16 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() + is_cogvlm2 = hasattr(config, 'num_multi_query_heads') quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads - num_key_value_heads = num_heads + num_key_value_heads = getattr(config, 'num_multi_query_heads', + num_heads) hidden_size = config.hidden_size head_dim = getattr(config, 'head_dim', hidden_size // num_heads) self.hidden_size = hidden_size + self.num_kv_heads = num_key_value_heads + self.head_dim = head_dim # packed qkv self.vision_expert_query_key_value = build_qkv_proj( @@ -48,7 +52,7 @@ def __init__(self, num_q_heads=num_heads, num_kv_heads=num_key_value_heads, head_size=head_dim, - bias=False, + bias=is_cogvlm2, quant_config=quantization_config, dtype=dtype, device=device, @@ -109,12 +113,14 @@ def forward( """Rewrite of LlamaAttention.forward.""" bsz, seqlen, _ = hidden_states.size() hidden_size = self.hidden_size // self.world_size + kv_size = self.num_kv_heads * self.head_dim // self.world_size # qkv proj if lang_ids is None and vision_ids is None: qkv_states = self.language_expert_query_key_value(hidden_states) else: - qkv_states = hidden_states.new_empty(bsz, seqlen, hidden_size * 3) + qkv_states = hidden_states.new_empty(bsz, seqlen, + hidden_size + kv_size * 2) if lang_ids is not None: qkv_states[:, lang_ids] = self.language_expert_query_key_value( hidden_states[:, lang_ids]) From e5a6c379e10736acae1c88e1b72173efa9155349 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 4 Sep 2024 20:09:47 +0800 Subject: [PATCH 69/89] fix --- lmdeploy/pytorch/backends/default/rotary_embedding.py | 9 +++++---- lmdeploy/pytorch/backends/rotary_embedding.py | 10 ++++++++-- lmdeploy/pytorch/backends/selector.py | 5 +---- lmdeploy/pytorch/models/llama.py | 9 +++++---- lmdeploy/pytorch/nn/rotary_embedding.py | 8 +++----- 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index 1003e678d..f0567e077 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -4,7 +4,8 @@ import torch from torch import nn -from ..rotary_embedding import (EmbeddingType, LongRoPEScalingParameters, +from ..rotary_embedding import (EmbeddingType, Llama3Parameters, + LongRoPEScalingParameters, RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters) @@ -320,10 +321,9 @@ def build( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, - low_freq_factor: float = 1.0, - high_freq_factor: float = 4.0, yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" @@ -334,7 +334,8 @@ def build( dim, base, scaling_factor, max_position_embeddings) elif emb_type == EmbeddingType.Llama3: return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, - low_freq_factor, high_freq_factor, + llama3_params.low_freq_factor, + llama3_params.high_freq_factor, max_position_embeddings) elif emb_type == EmbeddingType.Yarn: return YarnRotaryEmbeddingImpl(dim, diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index ce85b1846..a01dd9696 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -34,6 +34,13 @@ class LongRoPEScalingParameters: short_mscale: float = None +@dataclass +class Llama3Parameters: + """llama3 rope parameters.""" + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + + class RotaryEmbeddingImpl(ABC): """rotary embedding implementation api.""" @@ -53,10 +60,9 @@ def build( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, - low_freq_factor: float = 1.0, - high_freq_factor: float = 4.0, yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, emb_type: EmbeddingType = EmbeddingType.Default, ): """build.""" diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py index e7680048d..44cc5a289 100644 --- a/lmdeploy/pytorch/backends/selector.py +++ b/lmdeploy/pytorch/backends/selector.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from lmdeploy.pytorch.devices import get_device_manager -from lmdeploy.utils import get_logger def get_backend(): @@ -17,6 +16,4 @@ def get_backend(): from .ascend import AscendLayersBackend return AscendLayersBackend else: - logger = get_logger('lmdeploy') - logger.warning(f'Unsupported device type: {device_type}') - return None + raise RuntimeError(f'Unsupported device type: {device_type}') diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index fd7d514e6..0261a6a55 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -10,6 +10,7 @@ RMSNorm, SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -252,20 +253,20 @@ def __init__(self, rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta scaling_factor = 1.0 - low_freq_factor = 1.0 - high_freq_factor = 4.0 + llama3_params = None if rope_scaling is not None: scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) if emb_type == EmbeddingType.Llama3: low_freq_factor = rope_scaling.get('low_freq_factor', 1.0) high_freq_factor = rope_scaling.get('high_freq_factor', 1.0) + llama3_params = Llama3Parameters(low_freq_factor, + high_freq_factor) self.rotary_emb = build_rotary_embedding( rope_dim, rope_max_pos_emb, rope_base, scaling_factor, - low_freq_factor=low_freq_factor, - high_freq_factor=high_freq_factor, + llama3_params=llama3_params, emb_type=emb_type, ) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 9d035c961..911e31ff9 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -2,7 +2,7 @@ from torch import Tensor, nn from ..backends import LayerType, get_backend -from ..backends.rotary_embedding import (EmbeddingType, +from ..backends.rotary_embedding import (EmbeddingType, Llama3Parameters, LongRoPEScalingParameters, YarnParameters) @@ -12,10 +12,9 @@ def build_rotary_embedding( max_position_embeddings: int = 2048, base: int = 10000, scaling_factor: float = 1.0, - low_freq_factor: float = 1.0, - high_freq_factor: float = 4.0, yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: """build rotary embedding op.""" backend = get_backend() @@ -25,10 +24,9 @@ def build_rotary_embedding( max_position_embeddings, base, scaling_factor, - low_freq_factor=low_freq_factor, - high_freq_factor=high_freq_factor, yarn_params=yarn_params, longrope_params=longrope_params, + llama3_params=llama3_params, emb_type=emb_type) From 2236f282ae9c80d5385435d0c3a816a86a23a0be Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 5 Sep 2024 11:02:55 +0800 Subject: [PATCH 70/89] fix --- lmdeploy/pytorch/backends/__init__.py | 2 +- lmdeploy/pytorch/backends/ascend/__init__.py | 2 +- .../{layer_backend.py => op_backend.py} | 6 ++--- lmdeploy/pytorch/backends/base.py | 4 +-- lmdeploy/pytorch/backends/cuda/__init__.py | 2 +- .../cuda/{layer_backend.py => op_backend.py} | 22 ++++++++-------- lmdeploy/pytorch/backends/default/__init__.py | 2 +- .../{layer_backend.py => op_backend.py} | 24 ++++++++--------- lmdeploy/pytorch/models/llama.py | 26 +++++++++++-------- lmdeploy/pytorch/models/qwen2.py | 5 ++++ lmdeploy/pytorch/nn/activation.py | 6 ++--- lmdeploy/pytorch/nn/attention.py | 5 ++-- lmdeploy/pytorch/nn/linear.py | 12 ++++----- lmdeploy/pytorch/nn/moe.py | 7 +++-- lmdeploy/pytorch/nn/multinomial_sampling.py | 4 +-- lmdeploy/pytorch/nn/norm.py | 8 +++--- lmdeploy/pytorch/nn/rotary_embedding.py | 6 ++--- 17 files changed, 74 insertions(+), 69 deletions(-) rename lmdeploy/pytorch/backends/ascend/{layer_backend.py => op_backend.py} (96%) rename lmdeploy/pytorch/backends/cuda/{layer_backend.py => op_backend.py} (87%) rename lmdeploy/pytorch/backends/default/{layer_backend.py => op_backend.py} (77%) diff --git a/lmdeploy/pytorch/backends/__init__.py b/lmdeploy/pytorch/backends/__init__.py index dc8538228..51a89d9e3 100644 --- a/lmdeploy/pytorch/backends/__init__.py +++ b/lmdeploy/pytorch/backends/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import LayerType # noqa: F401 +from .base import OpType # noqa: F401 from .selector import get_backend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/ascend/__init__.py b/lmdeploy/pytorch/backends/ascend/__init__.py index 73444e906..94bad085f 100644 --- a/lmdeploy/pytorch/backends/ascend/__init__.py +++ b/lmdeploy/pytorch/backends/ascend/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .layer_backend import AscendLayersBackend # noqa: F401 +from .op_backend import AscendLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/ascend/layer_backend.py b/lmdeploy/pytorch/backends/ascend/op_backend.py similarity index 96% rename from lmdeploy/pytorch/backends/ascend/layer_backend.py rename to lmdeploy/pytorch/backends/ascend/op_backend.py index 0e2ef4663..c11fa06f6 100644 --- a/lmdeploy/pytorch/backends/ascend/layer_backend.py +++ b/lmdeploy/pytorch/backends/ascend/op_backend.py @@ -3,7 +3,7 @@ import torch -from ..base import LayerType +from ..base import OpType from ..default import DefaultLayersBackend @@ -14,8 +14,8 @@ def get_name() -> str: raise 'ascend' @classmethod - def get_layer_impl_builder(cls, layer_type: LayerType): - if layer_type == LayerType.Attention: + def get_layer_impl_builder(cls, layer_type: OpType): + if layer_type == OpType.Attention: from .attention import AscendAttentionImpl return AscendAttentionImpl else: diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index a690558f8..6b2dd6ff7 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -10,7 +10,7 @@ from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig -class LayerType(Enum): +class OpType(Enum): """Layer type enumerate.""" Attention = auto() Linear = auto() @@ -40,7 +40,7 @@ def get_name() -> str: @classmethod @abstractmethod - def get_layer_impl_builder(cls, layer_type: LayerType): + def get_layer_impl_builder(cls, layer_type: OpType): """get builder of given layer type.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/cuda/__init__.py b/lmdeploy/pytorch/backends/cuda/__init__.py index 4f598c098..8abaf35f6 100644 --- a/lmdeploy/pytorch/backends/cuda/__init__.py +++ b/lmdeploy/pytorch/backends/cuda/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .layer_backend import CudaLayersBackend # noqa: F401 +from .op_backend import CudaLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/cuda/layer_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py similarity index 87% rename from lmdeploy/pytorch/backends/cuda/layer_backend.py rename to lmdeploy/pytorch/backends/cuda/op_backend.py index c14b9ffa1..5833d42d0 100644 --- a/lmdeploy/pytorch/backends/cuda/layer_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -6,7 +6,7 @@ from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.utils import get_logger -from ..base import LayerType +from ..base import OpType from ..default import DefaultLayersBackend logger = get_logger('lmdeploy') @@ -21,30 +21,30 @@ def get_name() -> str: raise 'cuda' @classmethod - def get_layer_impl_builder(cls, layer_type: LayerType): + def get_layer_impl_builder(cls, layer_type: OpType): """get cuda layer builder.""" - if layer_type == LayerType.Attention: + if layer_type == OpType.Attention: from .attention import TritonAttentionBuilder return TritonAttentionBuilder - elif layer_type == LayerType.ApplyRotaryEmb: + elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import TritonApplyRotaryEmbBuilder return TritonApplyRotaryEmbBuilder - elif layer_type == LayerType.RMSNorm: + elif layer_type == OpType.RMSNorm: from .norm import TritonRMSNormBuilder return TritonRMSNormBuilder - elif layer_type == LayerType.SLoRA: + elif layer_type == OpType.SLoRA: from .slora import TritonSLoRABuilder return TritonSLoRABuilder - elif layer_type == LayerType.LinearW8A8: + elif layer_type == OpType.LinearW8A8: from .qmodules import TritonLinearW8A8Builder return TritonLinearW8A8Builder - elif layer_type == LayerType.RMSNormW8A8: + elif layer_type == OpType.RMSNormW8A8: from .qmodules import TritonRMSNormBuilder return TritonRMSNormBuilder - elif layer_type == LayerType.MultinomialSampling: + elif layer_type == OpType.MultinomialSampling: from .multinomial_sampling import TritonMultinomialSamplingBuilder return TritonMultinomialSamplingBuilder - elif layer_type == LayerType.LinearW4A16: + elif layer_type == OpType.LinearW4A16: from awq.modules.linear.gemm import AWQ_INSTALLED if AWQ_INSTALLED: from .awq_modules import AwqLinearW4A16Builder @@ -53,7 +53,7 @@ def get_layer_impl_builder(cls, layer_type: LayerType): logger.debug( f'Op {layer_type} fallback to default implementation.') return super().get_layer_impl_builder(layer_type) - elif layer_type == LayerType.FusedMoE: + elif layer_type == OpType.FusedMoE: from .moe import TritonFusedMoEBuilder return TritonFusedMoEBuilder else: diff --git a/lmdeploy/pytorch/backends/default/__init__.py b/lmdeploy/pytorch/backends/default/__init__.py index 00aa5e8c3..8ec91a15d 100644 --- a/lmdeploy/pytorch/backends/default/__init__.py +++ b/lmdeploy/pytorch/backends/default/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .layer_backend import DefaultLayersBackend # noqa: F401 +from .op_backend import DefaultLayersBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/default/layer_backend.py b/lmdeploy/pytorch/backends/default/op_backend.py similarity index 77% rename from lmdeploy/pytorch/backends/default/layer_backend.py rename to lmdeploy/pytorch/backends/default/op_backend.py index 3bfb377d1..8e01e9aaf 100644 --- a/lmdeploy/pytorch/backends/default/layer_backend.py +++ b/lmdeploy/pytorch/backends/default/op_backend.py @@ -3,7 +3,7 @@ import torch -from ..base import LayersBackend, LayerType +from ..base import LayersBackend, OpType class DefaultLayersBackend(LayersBackend): @@ -13,36 +13,36 @@ def get_name() -> str: raise 'default' @classmethod - def get_layer_impl_builder(cls, layer_type: LayerType): + def get_layer_impl_builder(cls, layer_type: OpType): """get builder of given layer type.""" - if layer_type == LayerType.Linear: + if layer_type == OpType.Linear: from .linear import DefaultLinearBuilder return DefaultLinearBuilder - elif layer_type == LayerType.RotaryEmbedding: + elif layer_type == OpType.RotaryEmbedding: from .rotary_embedding import DefaultRotaryEmbeddingBuilder return DefaultRotaryEmbeddingBuilder - elif layer_type == LayerType.ApplyRotaryEmb: + elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import DefaultApplyRotaryEmbBuilder return DefaultApplyRotaryEmbBuilder - elif layer_type == LayerType.SiluAndMul: + elif layer_type == OpType.SiluAndMul: from .activation import DefaultSiluAndMulBuilder return DefaultSiluAndMulBuilder - elif layer_type == LayerType.GeluAndMul: + elif layer_type == OpType.GeluAndMul: from .activation import DefaultGeluAndMulBuilder return DefaultGeluAndMulBuilder - elif layer_type == LayerType.RMSNorm: + elif layer_type == OpType.RMSNorm: from .norm import DefaultRMSNormBuilder return DefaultRMSNormBuilder - elif layer_type == LayerType.LayerNorm: + elif layer_type == OpType.LayerNorm: from .norm import DefaultLayerNormBuilder return DefaultLayerNormBuilder - elif layer_type == LayerType.MultinomialSampling: + elif layer_type == OpType.MultinomialSampling: from .multinomial_sampling import DefaultMultinomialSamplingBuilder return DefaultMultinomialSamplingBuilder - elif layer_type == LayerType.LinearW4A16: + elif layer_type == OpType.LinearW4A16: from .awq_modules import DefaultLinearW4A16Builder return DefaultLinearW4A16Builder - elif layer_type == LayerType.SoftmaxTopK: + elif layer_type == OpType.SoftmaxTopK: from .moe import DefaultSoftmaxTopKBuilder return DefaultSoftmaxTopKBuilder else: diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 0261a6a55..5c015c8f3 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -237,30 +237,34 @@ def __init__(self, device=device) # build rotary embedding in LlamaModel + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + scaling_factor = 1.0 + llama3_params = None rope_scaling = config.rope_scaling if rope_scaling is None: emb_type = EmbeddingType.LinearScaling else: + if 'scaling_factor' in rope_scaling: + scaling_factor = rope_scaling['scaling_factor'] + elif 'factor' in rope_scaling: + scaling_factor = rope_scaling['factor'] + rope_type = rope_scaling['rope_type'] if rope_type == 'dynamic': emb_type = EmbeddingType.DynamicNTKScaling + if rope_type == 'linear': + emb_type = EmbeddingType.LinearScaling elif rope_type == 'llama3': emb_type = EmbeddingType.Llama3 - else: - raise RuntimeError(f'Unsupported rope type: {rope_type}') - - rope_dim = config.hidden_size // config.num_attention_heads - rope_max_pos_emb = config.max_position_embeddings - rope_base = config.rope_theta - scaling_factor = 1.0 - llama3_params = None - if rope_scaling is not None: - scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) - if emb_type == EmbeddingType.Llama3: low_freq_factor = rope_scaling.get('low_freq_factor', 1.0) high_freq_factor = rope_scaling.get('high_freq_factor', 1.0) llama3_params = Llama3Parameters(low_freq_factor, high_freq_factor) + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + self.rotary_emb = build_rotary_embedding( rope_dim, rope_max_pos_emb, diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 2001c365d..68678e02b 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -349,6 +349,11 @@ def forward( logits = logits.float() return logits + def update_weights(self): + """update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + def get_input_embeddings(self): """get input embeddings.""" return self.model.get_input_embeddings() diff --git a/lmdeploy/pytorch/nn/activation.py b/lmdeploy/pytorch/nn/activation.py index 6c73d9752..bda606f50 100644 --- a/lmdeploy/pytorch/nn/activation.py +++ b/lmdeploy/pytorch/nn/activation.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from torch import Tensor, nn -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend class SiluAndMul(nn.Module): @@ -10,7 +10,7 @@ class SiluAndMul(nn.Module): def __init__(self, inplace: bool = True): super().__init__() backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.SiluAndMul) + builder = backend.get_layer_impl_builder(OpType.SiluAndMul) self.impl = builder.build(inplace) def forward(self, x: Tensor): @@ -24,7 +24,7 @@ class GeluAndMul(nn.Module): def __init__(self, approximate: str = 'none'): super().__init__() backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.GeluAndMul) + builder = backend.get_layer_impl_builder(OpType.GeluAndMul) self.impl = builder.build(approximate) def forward(self, x: Tensor): diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index c290bc358..8f18da4dc 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -2,7 +2,7 @@ import torch from torch import nn -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend from ..backends.attention import AttentionMetadata from .utils import get_distribute_size, get_world_rank @@ -28,8 +28,7 @@ def __init__( num_heads, num_kv_heads, replicate_kv) layer_backend = get_backend() - impl_builder = layer_backend.get_layer_impl_builder( - LayerType.Attention) + impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention) self.impl = impl_builder.build( num_heads=num_heads, diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index b9daf0140..94a1f978c 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -9,7 +9,7 @@ default_weight_loader from lmdeploy.utils import get_logger -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend from ..backends.slora import AdapterInfo from .utils import div_up, get_distribute_size, get_world_rank @@ -93,7 +93,7 @@ def __init__(self, base_slice=base_slice, max_rank=max_rank, ) - impl_builder = get_backend().get_layer_impl_builder(LayerType.SLoRA) + impl_builder = get_backend().get_layer_impl_builder(OpType.SLoRA) self.impl = impl_builder.build() self.is_tp = is_tp self.ctx_mgr = ctx_mgr @@ -133,8 +133,7 @@ def __init__( in_features, out_features, w_bit, group_size, colwise) qweight, scales, qzeros, bias = self.create_weights( in_features, out_features, w_bit, group_size, bias, dtype, device) - impl_builder = get_backend().get_layer_impl_builder( - LayerType.LinearW4A16) + impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW4A16) self.impl = impl_builder.build(in_features, out_features, w_bit, @@ -501,8 +500,7 @@ def __init__( if is_tp: in_features, out_features = self._get_io_features( in_features, out_features, colwise) - impl_builder = get_backend().get_layer_impl_builder( - LayerType.LinearW8A8) + impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8) self.impl = impl_builder.build(in_features, out_features, bias is not None, @@ -759,7 +757,7 @@ def __init__( if is_tp: in_features, out_features = self._get_io_features( in_features, out_features, colwise) - impl_builder = get_backend().get_layer_impl_builder(LayerType.Linear) + impl_builder = get_backend().get_layer_impl_builder(OpType.Linear) self.impl = impl_builder.build(in_features, out_features, bias is not None, diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index ffd68ddea..378e8561d 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch import nn -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend from .utils import get_world_rank @@ -15,8 +15,7 @@ class SoftmaxTopK(nn.Module): def __init__(self, top_k: int, dim: int = -1): super().__init__() self.top_k = top_k - impl_builder = get_backend().get_layer_impl_builder( - LayerType.SoftmaxTopK) + impl_builder = get_backend().get_layer_impl_builder(OpType.SoftmaxTopK) self.impl = impl_builder.build(top_k, dim) def forward(self, x: torch.Tensor): @@ -43,7 +42,7 @@ def __init__(self, dtype = torch.float16 hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim) - impl_builder = get_backend().get_layer_impl_builder(LayerType.FusedMoE) + impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) self.impl = impl_builder.build(top_k, renormalize) gate_up_weights, down_weights = self.create_weights(hidden_dim, diff --git a/lmdeploy/pytorch/nn/multinomial_sampling.py b/lmdeploy/pytorch/nn/multinomial_sampling.py index 562636bee..f73350cca 100644 --- a/lmdeploy/pytorch/nn/multinomial_sampling.py +++ b/lmdeploy/pytorch/nn/multinomial_sampling.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend def multinomial_sampling(scores: torch.Tensor, @@ -10,5 +10,5 @@ def multinomial_sampling(scores: torch.Tensor, indices: torch.Tensor = None): """multinomial sampling op.""" impl_builder = get_backend().get_layer_impl_builder( - LayerType.MultinomialSampling) + OpType.MultinomialSampling) return impl_builder.build()(scores, seeds, offsets, indices) diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index f7b40b48a..ef244ff73 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -4,7 +4,7 @@ import torch from torch import nn -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend def _is_w8a8(quant_config: Any): @@ -31,9 +31,9 @@ def __init__(self, super().__init__() backend = get_backend() if _is_w8a8(quant_config): - builder = backend.get_layer_impl_builder(LayerType.RMSNormW8A8) + builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8) else: - builder = backend.get_layer_impl_builder(LayerType.RMSNorm) + builder = backend.get_layer_impl_builder(OpType.RMSNorm) self.register_parameter('weight', self.create_weight(hidden_size, dtype, device)) self.impl = builder.build(hidden_size, eps) @@ -69,7 +69,7 @@ def __init__(self, device: torch.device = None): super().__init__() backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.LayerNorm) + builder = backend.get_layer_impl_builder(OpType.LayerNorm) weight, bias = self.create_weight(hidden_size, bias, dtype, device) self.register_parameter('weight', weight) self.register_parameter('bias', bias) diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index 911e31ff9..f7891c600 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from torch import Tensor, nn -from ..backends import LayerType, get_backend +from ..backends import OpType, get_backend from ..backends.rotary_embedding import (EmbeddingType, Llama3Parameters, LongRoPEScalingParameters, YarnParameters) @@ -19,7 +19,7 @@ def build_rotary_embedding( """build rotary embedding op.""" backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.RotaryEmbedding) + builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding) return builder.build(dim, max_position_embeddings, base, @@ -36,7 +36,7 @@ class ApplyRotaryEmb(nn.Module): def __init__(self): super().__init__() backend = get_backend() - builder = backend.get_layer_impl_builder(LayerType.ApplyRotaryEmb) + builder = backend.get_layer_impl_builder(OpType.ApplyRotaryEmb) self.impl = builder.build() def forward(self, From 3327d6dcf5d0505988ae3bc83cc4e650fd4fc61e Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 5 Sep 2024 11:46:12 +0800 Subject: [PATCH 71/89] fix internlm2 awq --- lmdeploy/pytorch/models/internlm2.py | 13 +--- lmdeploy/pytorch/nn/linear.py | 88 +++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 2526f6709..398463187 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -399,13 +399,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ('.gate_up_proj', '.w3', 1), ] - config = self.config - num_heads = config.num_attention_heads - num_key_value_heads = config.num_key_value_heads - hidden_size = config.hidden_size - head_dim = hidden_size // num_heads - num_key_value_groups = num_heads // num_key_value_heads - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'rotary_emb.inv_freq' in name: @@ -422,12 +415,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: if '.wqkv' in name: - loaded_weight = loaded_weight.unflatten( - 0, (-1, num_key_value_groups + 2, head_dim)) - q = loaded_weight[:, :num_key_value_groups].flatten(0, 2) - k = loaded_weight[:, -2].flatten(0, 1) - v = loaded_weight[:, -1].flatten(0, 1) param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight, layout='hgd') load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 94a1f978c..407479a87 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -15,6 +15,14 @@ logger = get_logger('lmdeploy') +QKV_SPLIT_LAYOUTS = ['default', 'hgd'] + + +def _check_qkv_split_layout(layout: str): + if layout not in QKV_SPLIT_LAYOUTS: + raise RuntimeError(f'Expect qkv split layout in {QKV_SPLIT_LAYOUTS}, ' + f'but get: {layout}') + def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int): """chunk aligned.""" @@ -469,13 +477,45 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, """update all out features.""" return all_out_features - def weight_spliter_wz(self, loaded_weight: torch.Tensor): + def weight_spliter_wz(self, + loaded_weight: torch.Tensor, + layout: str = 'default'): """weight spliter.""" - return loaded_weight.split(self.qkv_split_section_wz, dim=1) + _check_qkv_split_layout(layout) + if layout == 'default': + return loaded_weight.split(self.qkv_split_section_wz, dim=1) + elif layout == 'hgd': + assert self.head_size == self.head_size_v + heads = [sec // self.head_size for sec in self.qkv_split_section_s] + kv_heads = heads[-1] + loaded_weight = loaded_weight.unflatten( + 1, (kv_heads, -1, self.head_size // self.elem_per_int)) + q = loaded_weight[:, :, :-2].flatten(1, 3) + k = loaded_weight[:, :, -2].flatten(1, 2) + v = loaded_weight[:, :, -1].flatten(1, 2) + return q, k, v + else: + raise RuntimeError(f'Unsupported layout: {layout}') - def weight_spliter_s(self, loaded_weight: torch.Tensor): + def weight_spliter_s(self, + loaded_weight: torch.Tensor, + layout: str = 'default'): """weight spliter.""" - return loaded_weight.split(self.qkv_split_section_s, dim=-1) + _check_qkv_split_layout(layout) + if layout == 'default': + return loaded_weight.split(self.qkv_split_section_s, dim=-1) + elif layout == 'hgd': + assert self.head_size == self.head_size_v + heads = [sec // self.head_size for sec in self.qkv_split_section_s] + kv_heads = heads[-1] + loaded_weight = loaded_weight.unflatten( + 1, (kv_heads, -1, self.head_size)) + q = loaded_weight[:, :, :-2].flatten(1, 3) + k = loaded_weight[:, :, -2].flatten(1, 2) + v = loaded_weight[:, :, -1].flatten(1, 2) + return q, k, v + else: + raise RuntimeError(f'Unsupported layout: {layout}') class W8A8Linear(nn.Module): @@ -728,9 +768,25 @@ def _update_all_out_features(self, all_out_features: List[int], """update all out features.""" return all_out_features - def weight_spliter(self, loaded_weight: torch.Tensor): + def weight_spliter(self, + loaded_weight: torch.Tensor, + layout: str = 'default'): """weight spliter.""" - return loaded_weight.split(self.qkv_split_section, dim=0) + _check_qkv_split_layout(layout) + if layout == 'default': + return loaded_weight.split(self.qkv_split_section, dim=0) + elif layout == 'hgd': + assert self.head_size == self.head_size_v + heads = [sec // self.head_size for sec in self.qkv_split_section] + kv_heads = heads[-1] + loaded_weight = loaded_weight.unflatten( + 0, (kv_heads, -1, self.head_size)) + q = loaded_weight[:, :-2].flatten(0, 2) + k = loaded_weight[:, -2].flatten(0, 1) + v = loaded_weight[:, -1].flatten(0, 1) + return q, k, v + else: + raise RuntimeError(f'Unsupported layout: {layout}') class BaseLinear(nn.Module): @@ -995,9 +1051,25 @@ def weight_loader(self, param: torch.nn.Parameter, self.head_size_v)[rank] param_w.copy_(loaded_weight) - def weight_spliter(self, loaded_weight: torch.Tensor): + def weight_spliter(self, + loaded_weight: torch.Tensor, + layout: str = 'default'): """weight spliter.""" - return loaded_weight.split(self.qkv_split_section, dim=0) + _check_qkv_split_layout(layout) + if layout == 'default': + return loaded_weight.split(self.qkv_split_section, dim=0) + elif layout == 'hgd': + assert self.head_size == self.head_size_v + heads = [sec // self.head_size for sec in self.qkv_split_section] + kv_heads = heads[-1] + loaded_weight = loaded_weight.unflatten( + 0, (kv_heads, -1, self.head_size)) + q = loaded_weight[:, :-2].flatten(0, 2) + k = loaded_weight[:, -2].flatten(0, 1) + v = loaded_weight[:, -1].flatten(0, 1) + return q, k, v + else: + raise RuntimeError(f'Unsupported layout: {layout}') def build_linear(in_features: int, From fe18df45e6b6eda91225abec5027937a5d091017 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 5 Sep 2024 16:34:50 +0800 Subject: [PATCH 72/89] rename --- lmdeploy/pytorch/backends/ascend/__init__.py | 2 +- .../pytorch/backends/ascend/op_backend.py | 4 ++-- lmdeploy/pytorch/backends/base.py | 2 +- lmdeploy/pytorch/backends/cuda/__init__.py | 2 +- lmdeploy/pytorch/backends/cuda/op_backend.py | 4 ++-- lmdeploy/pytorch/backends/default/__init__.py | 2 +- .../pytorch/backends/default/op_backend.py | 4 ++-- .../backends/default/rotary_embedding.py | 19 ++++++++--------- lmdeploy/pytorch/backends/rotary_embedding.py | 4 ++-- lmdeploy/pytorch/backends/selector.py | 8 +++---- lmdeploy/pytorch/models/baichuan.py | 6 +++--- lmdeploy/pytorch/models/chatglm2.py | 6 +++--- lmdeploy/pytorch/models/cogvlm.py | 6 +++--- lmdeploy/pytorch/models/dbrx.py | 6 +++--- lmdeploy/pytorch/models/deepseek.py | 10 ++++----- lmdeploy/pytorch/models/deepseek_v2.py | 10 ++++----- lmdeploy/pytorch/models/falcon.py | 8 +++---- lmdeploy/pytorch/models/gemma.py | 10 ++++----- lmdeploy/pytorch/models/internlm2.py | 10 ++++----- lmdeploy/pytorch/models/llama.py | 12 +++++------ lmdeploy/pytorch/models/mistral.py | 6 +++--- lmdeploy/pytorch/models/mixtral.py | 6 +++--- lmdeploy/pytorch/models/phi3.py | 8 +++---- lmdeploy/pytorch/models/phi3_moe.py | 7 +++---- lmdeploy/pytorch/models/qwen.py | 6 +++--- lmdeploy/pytorch/models/qwen2.py | 6 +++--- lmdeploy/pytorch/models/qwen2_moe.py | 6 +++--- lmdeploy/pytorch/models/starcoder2.py | 6 +++--- lmdeploy/pytorch/nn/__init__.py | 2 +- lmdeploy/pytorch/nn/rotary_embedding.py | 21 +++++++++---------- 30 files changed, 103 insertions(+), 106 deletions(-) diff --git a/lmdeploy/pytorch/backends/ascend/__init__.py b/lmdeploy/pytorch/backends/ascend/__init__.py index 94bad085f..29c756ae5 100644 --- a/lmdeploy/pytorch/backends/ascend/__init__.py +++ b/lmdeploy/pytorch/backends/ascend/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .op_backend import AscendLayersBackend # noqa: F401 +from .op_backend import AscendOpsBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/ascend/op_backend.py b/lmdeploy/pytorch/backends/ascend/op_backend.py index c11fa06f6..d03153187 100644 --- a/lmdeploy/pytorch/backends/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/ascend/op_backend.py @@ -4,10 +4,10 @@ import torch from ..base import OpType -from ..default import DefaultLayersBackend +from ..default import DefaultOpsBackend -class AscendLayersBackend(DefaultLayersBackend): +class AscendOpsBackend(DefaultOpsBackend): @staticmethod def get_name() -> str: diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index 6b2dd6ff7..b76e3594c 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -29,7 +29,7 @@ class OpType(Enum): FusedMoE = auto() -class LayersBackend(ABC): +class OpsBackend(ABC): """Layer backend abstract.""" @staticmethod diff --git a/lmdeploy/pytorch/backends/cuda/__init__.py b/lmdeploy/pytorch/backends/cuda/__init__.py index 8abaf35f6..2703c174b 100644 --- a/lmdeploy/pytorch/backends/cuda/__init__.py +++ b/lmdeploy/pytorch/backends/cuda/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .op_backend import CudaLayersBackend # noqa: F401 +from .op_backend import CudaOpsBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 5833d42d0..2d79b4c59 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -7,12 +7,12 @@ from lmdeploy.utils import get_logger from ..base import OpType -from ..default import DefaultLayersBackend +from ..default import DefaultOpsBackend logger = get_logger('lmdeploy') -class CudaLayersBackend(DefaultLayersBackend): +class CudaOpsBackend(DefaultOpsBackend): """cuda layer backend.""" @staticmethod diff --git a/lmdeploy/pytorch/backends/default/__init__.py b/lmdeploy/pytorch/backends/default/__init__.py index 8ec91a15d..69acf9be9 100644 --- a/lmdeploy/pytorch/backends/default/__init__.py +++ b/lmdeploy/pytorch/backends/default/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .op_backend import DefaultLayersBackend # noqa: F401 +from .op_backend import DefaultOpsBackend # noqa: F401 diff --git a/lmdeploy/pytorch/backends/default/op_backend.py b/lmdeploy/pytorch/backends/default/op_backend.py index 8e01e9aaf..b9d268f8f 100644 --- a/lmdeploy/pytorch/backends/default/op_backend.py +++ b/lmdeploy/pytorch/backends/default/op_backend.py @@ -3,10 +3,10 @@ import torch -from ..base import LayersBackend, OpType +from ..base import OpsBackend, OpType -class DefaultLayersBackend(LayersBackend): +class DefaultOpsBackend(OpsBackend): @staticmethod def get_name() -> str: diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index f0567e077..1fced2568 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -4,10 +4,9 @@ import torch from torch import nn -from ..rotary_embedding import (EmbeddingType, Llama3Parameters, - LongRoPEScalingParameters, - RotaryEmbeddingBuilder, RotaryEmbeddingImpl, - YarnParameters) +from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, + RopeType, RotaryEmbeddingBuilder, + RotaryEmbeddingImpl, YarnParameters) def _rotary_embedding_fwd(position_ids: torch.Tensor, @@ -324,26 +323,26 @@ def build( yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, - emb_type: EmbeddingType = EmbeddingType.Default, + emb_type: RopeType = RopeType.Default, ): """build.""" - if emb_type in (EmbeddingType.Default, EmbeddingType.LinearScaling): + if emb_type in (RopeType.Default, RopeType.LinearScaling): return RotaryEmbeddingImpl(dim, base, scaling_factor) - elif emb_type == EmbeddingType.DynamicNTKScaling: + elif emb_type == RopeType.DynamicNTKScaling: return LlamaDynamicNTKScalingRotaryEmbedding( dim, base, scaling_factor, max_position_embeddings) - elif emb_type == EmbeddingType.Llama3: + elif emb_type == RopeType.Llama3: return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor, llama3_params.high_freq_factor, max_position_embeddings) - elif emb_type == EmbeddingType.Yarn: + elif emb_type == RopeType.Yarn: return YarnRotaryEmbeddingImpl(dim, base, scaling_factor, max_position_embeddings, yarn_params=yarn_params) - elif emb_type == EmbeddingType.LongRoPEScaling: + elif emb_type == RopeType.LongRoPEScaling: return LongRoPEScalingRotaryEmbeddingImpl( dim, base, diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py index a01dd9696..6fa6abbdf 100644 --- a/lmdeploy/pytorch/backends/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/rotary_embedding.py @@ -5,7 +5,7 @@ from typing import List -class EmbeddingType(Enum): +class RopeType(Enum): """rotary embedding type.""" Default = auto() LinearScaling = auto() @@ -63,7 +63,7 @@ def build( yarn_params: YarnParameters = None, longrope_params: LongRoPEScalingParameters = None, llama3_params: Llama3Parameters = None, - emb_type: EmbeddingType = EmbeddingType.Default, + emb_type: RopeType = RopeType.Default, ): """build.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py index 44cc5a289..181dd1535 100644 --- a/lmdeploy/pytorch/backends/selector.py +++ b/lmdeploy/pytorch/backends/selector.py @@ -10,10 +10,10 @@ def get_backend(): device_type = device_ctx.device_type if device_type == 'cuda': - from .cuda import CudaLayersBackend - return CudaLayersBackend + from .cuda import CudaOpsBackend + return CudaOpsBackend if device_type == 'ascend': - from .ascend import AscendLayersBackend - return AscendLayersBackend + from .ascend import AscendOpsBackend + return AscendOpsBackend else: raise RuntimeError(f'Unsupported device type: {device_type}') diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 575a82c83..7e8727dcc 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -5,8 +5,8 @@ from torch import nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -246,7 +246,7 @@ def __init__(self, self.is_13b = _is_baichuan_13b(config) if not self.is_13b: # build rotary embedding in LlamaModel - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = 10000 diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 287a90ffa..5ab967c4b 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -414,7 +414,7 @@ def __init__(self, self.embedding = Embedding(config, dtype=dtype, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rotary_dim = (config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels) rope_max_pos_emb = 1 << 20 diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 76c237fb1..f4a038863 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -7,8 +7,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -360,7 +360,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = 2048 rope_base = 10000 diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index 919d3c552..fde661eb7 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - LayerNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, + RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -351,7 +351,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.d_model // config.n_heads rope_max_pos_emb = config.max_seq_len rope_base = config.attn_config.rope_theta diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py index 187d081e9..c605ee8f8 100644 --- a/lmdeploy/pytorch/models/deepseek.py +++ b/lmdeploy/pytorch/models/deepseek.py @@ -7,8 +7,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK @@ -344,14 +344,14 @@ def __init__(self, # build rotary embedding rope_scaling = getattr(config, 'rope_scaling', None) - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling scaling_factor = 1.0 if rope_scaling is not None: rope_type = rope_scaling['type'] if rope_type == 'linear': - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling if rope_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling else: raise RuntimeError(f'Unsupported rope type: {rope_type}') scaling_factor = rope_scaling.get('factor', scaling_factor) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index 5a6bfb9f9..cceac2828 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -7,8 +7,8 @@ from torch import nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_rowwise_linear) @@ -529,7 +529,7 @@ def __init__(self, dtype=dtype, device=device) - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.qk_rope_head_dim rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta @@ -539,9 +539,9 @@ def __init__(self, scaling_type = config.rope_scaling['type'] scaling_factor = config.rope_scaling['factor'] if scaling_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling elif scaling_type == 'yarn': - emb_type = EmbeddingType.Yarn + emb_type = RopeType.Yarn rope_max_pos_emb = config.rope_scaling.get( 'original_max_position_embeddings', 4096) kwargs = { diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index d50d97fac..ae8b9f705 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, build_rowwise_linear) @@ -285,14 +285,14 @@ def __init__(self, scaling_factor = 1.0 if not hasattr(config, 'rope_scaling'): - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling else: rope_scaling = config.rope_scaling rope_type = rope_scaling['rope_type'] if rope_type == 'linear': - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling elif rope_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling else: raise RuntimeError(f'Unsupported rope type: {rope_type}') scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index a7538387f..e67feca51 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -7,8 +7,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - GeluAndMul, RMSNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, GeluAndMul, + RMSNorm, RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -279,14 +279,14 @@ def __init__(self, # build rotary embedding rope_scaling = getattr(config, 'rope_scaling', None) - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling scaling_factor = 1.0 if rope_scaling is not None: rope_type = rope_scaling['rope_type'] if rope_type == 'linear': - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling if rope_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling else: raise RuntimeError(f'Unsupported rope type: {rope_type}') scaling_factor = rope_scaling.get('scaling_factor', scaling_factor) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 398463187..e386ef283 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -240,14 +240,14 @@ def __init__(self, # build rotary embedding in Model rope_scaling = config.rope_scaling scaling_factor = 1.0 - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling if rope_scaling is not None: scaling_factor = rope_scaling.get('factor', scaling_factor) rope_type = rope_scaling['type'] if rope_type == 'linear': - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling if rope_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling else: raise RuntimeError(f'Unsupported rope type: {rope_type}') rope_dim = config.hidden_size // config.num_attention_heads diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 5c015c8f3..286e89f3e 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -6,8 +6,8 @@ from transformers.models.llama import LlamaConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters @@ -244,7 +244,7 @@ def __init__(self, llama3_params = None rope_scaling = config.rope_scaling if rope_scaling is None: - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling else: if 'scaling_factor' in rope_scaling: scaling_factor = rope_scaling['scaling_factor'] @@ -253,11 +253,11 @@ def __init__(self, rope_type = rope_scaling['rope_type'] if rope_type == 'dynamic': - emb_type = EmbeddingType.DynamicNTKScaling + emb_type = RopeType.DynamicNTKScaling if rope_type == 'linear': - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling elif rope_type == 'llama3': - emb_type = EmbeddingType.Llama3 + emb_type = RopeType.Llama3 low_freq_factor = rope_scaling.get('low_freq_factor', 1.0) high_freq_factor = rope_scaling.get('high_freq_factor', 1.0) llama3_params = Llama3Parameters(low_freq_factor, diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 03240b7b9..97ca327b9 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -239,7 +239,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 365731019..2f98d2a32 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -5,8 +5,8 @@ from torch import nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + build_rotary_embedding) from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -243,7 +243,7 @@ def __init__(self, dtype=dtype, device=device) - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index c04bb235a..0e3b38be8 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, @@ -242,7 +242,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta @@ -250,7 +250,7 @@ def __init__(self, if rope_scaling is not None: scaling_type = rope_scaling['type'] assert scaling_type in ['longrope', 'su'] - emb_type = EmbeddingType.LongRoPEScaling + emb_type = RopeType.LongRoPEScaling ori_pos_emb = getattr(config, 'original_max_position_embeddings', rope_max_pos_emb) longrope_params = LongRoPEScalingParameters( diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py index 51e0cfe5b..4fee7bfc5 100644 --- a/lmdeploy/pytorch/models/phi3_moe.py +++ b/lmdeploy/pytorch/models/phi3_moe.py @@ -5,8 +5,7 @@ from torch import nn from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - LayerNorm) +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear from lmdeploy.pytorch.nn.moe import FusedMoE from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, @@ -298,7 +297,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta @@ -306,7 +305,7 @@ def __init__(self, if rope_scaling is not None: scaling_type = rope_scaling['type'] assert scaling_type in ['longrope', 'su'] - emb_type = EmbeddingType.LongRoPEScaling + emb_type = RopeType.LongRoPEScaling ori_pos_emb = getattr(config, 'original_max_position_embeddings', rope_max_pos_emb) longrope_params = LongRoPEScalingParameters( diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index 0d33e8268..e1268b16c 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -238,7 +238,7 @@ def __init__(self, ]) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling if config.rotary_pct == 1.0: self.rotary_ndims = None else: diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 68678e02b..fabc76605 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -239,7 +239,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index c46104f5e..e6dcd8504 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -8,8 +8,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - RMSNorm, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK @@ -358,7 +358,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta diff --git a/lmdeploy/pytorch/models/starcoder2.py b/lmdeploy/pytorch/models/starcoder2.py index 7cb97be2b..766683a9f 100644 --- a/lmdeploy/pytorch/models/starcoder2.py +++ b/lmdeploy/pytorch/models/starcoder2.py @@ -6,8 +6,8 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, EmbeddingType, - LayerNorm, build_rotary_embedding) +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, + RopeType, build_rotary_embedding) from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight @@ -242,7 +242,7 @@ def __init__(self, device=device) # build rotary embedding - emb_type = EmbeddingType.LinearScaling + emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads rope_max_pos_emb = config.max_position_embeddings rope_base = config.rope_theta diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index e90899bb8..2b90f4029 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -5,6 +5,6 @@ from .attention import Attention # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 -from .rotary_embedding import EmbeddingType # noqa: F401 +from .rotary_embedding import RopeType # noqa: F401 from .rotary_embedding import YarnParameters # noqa: F401 from .rotary_embedding import build_rotary_embedding # noqa: F401 diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py index f7891c600..35a7de714 100644 --- a/lmdeploy/pytorch/nn/rotary_embedding.py +++ b/lmdeploy/pytorch/nn/rotary_embedding.py @@ -2,20 +2,19 @@ from torch import Tensor, nn from ..backends import OpType, get_backend -from ..backends.rotary_embedding import (EmbeddingType, Llama3Parameters, - LongRoPEScalingParameters, +from ..backends.rotary_embedding import (Llama3Parameters, + LongRoPEScalingParameters, RopeType, YarnParameters) -def build_rotary_embedding( - dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - scaling_factor: float = 1.0, - yarn_params: YarnParameters = None, - longrope_params: LongRoPEScalingParameters = None, - llama3_params: Llama3Parameters = None, - emb_type: EmbeddingType = EmbeddingType.Default) -> nn.Module: +def build_rotary_embedding(dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + scaling_factor: float = 1.0, + yarn_params: YarnParameters = None, + longrope_params: LongRoPEScalingParameters = None, + llama3_params: Llama3Parameters = None, + emb_type: RopeType = RopeType.Default) -> nn.Module: """build rotary embedding op.""" backend = get_backend() From b74a22dc8a1f9f6ddc296b5aa4e691db5828dd71 Mon Sep 17 00:00:00 2001 From: chenchiyu Date: Thu, 5 Sep 2024 16:59:45 +0800 Subject: [PATCH 73/89] fix a hanging problem when using cli serve mode and device ascend on exit --- lmdeploy/pytorch/engine/model_agent.py | 31 +++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 266a532b7..46d17b5c8 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -676,8 +676,6 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, _check_context_alive(self.mp_context) rank = 0 - # Please see [Note Exit By Sending Exit Flag] - atexit.register(_exit_by_sending_exit_flag, rank, self) try: dist.init_process_group('nccl', rank=rank, @@ -690,6 +688,8 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig, if dist.is_initialized(): dist.destroy_process_group() raise e + # Please see Note [Exit By Sending Exit Flag] + atexit.register(_exit_by_sending_exit_flag, rank, self) @torch.inference_mode() def _build_model( @@ -786,10 +786,26 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, return output -# [Note] Exit By Sending Exit Flag -# the registration of this function in atexit should be called -# after importing torch.multiprocessing def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent): + """[Note] Exit By Sending Exit Flag: the registration to `atexit` of this + function should be called after importing torch.multiprocessing and the + initialization of distributed process group.""" + if not hasattr(agent, 'stream'): + # agent is not initialized, just exits normally + if hasattr(agent, 'patched_model'): + del agent.patched_model + return + + import sys + if agent.backend_config.device_type == 'ascend' \ + and 'uvicorn.server' in sys.modules: + # Workaround for CLI serve mode with device_type ascend: + # using uvicorn server causes ascend low-level backend of subprocesses + # corrupted, and using _broadcast_inputs in this case leads to + # main process hanging, just exits normally + del agent.patched_model + return + # send exit_flag to all subprocess relying on all subprocess are alive # and wait at _broadcast_inputs exit_flag = True @@ -798,11 +814,6 @@ def _exit_by_sending_exit_flag(rank: int, agent: TPModelAgent): del agent.patched_model - # Tricky, extra sleep for subprocess releasing resources - import time - time.sleep(1) - return - def build_model_agent(model_path: str, cache_config: CacheConfig, From d81012869ddda5b274504205de9a581e76212dc5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 5 Sep 2024 18:34:15 +0800 Subject: [PATCH 74/89] raise -> return --- lmdeploy/pytorch/backends/ascend/op_backend.py | 2 +- lmdeploy/pytorch/backends/cuda/op_backend.py | 2 +- lmdeploy/pytorch/backends/default/op_backend.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/backends/ascend/op_backend.py b/lmdeploy/pytorch/backends/ascend/op_backend.py index d03153187..90ea3efea 100644 --- a/lmdeploy/pytorch/backends/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/ascend/op_backend.py @@ -11,7 +11,7 @@ class AscendOpsBackend(DefaultOpsBackend): @staticmethod def get_name() -> str: - raise 'ascend' + return 'ascend' @classmethod def get_layer_impl_builder(cls, layer_type: OpType): diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 2d79b4c59..c7da9c40e 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -18,7 +18,7 @@ class CudaOpsBackend(DefaultOpsBackend): @staticmethod def get_name() -> str: """backend name.""" - raise 'cuda' + return 'cuda' @classmethod def get_layer_impl_builder(cls, layer_type: OpType): diff --git a/lmdeploy/pytorch/backends/default/op_backend.py b/lmdeploy/pytorch/backends/default/op_backend.py index b9d268f8f..4db018a2d 100644 --- a/lmdeploy/pytorch/backends/default/op_backend.py +++ b/lmdeploy/pytorch/backends/default/op_backend.py @@ -10,7 +10,7 @@ class DefaultOpsBackend(OpsBackend): @staticmethod def get_name() -> str: - raise 'default' + return 'default' @classmethod def get_layer_impl_builder(cls, layer_type: OpType): From 27ec376508761cb70da7d3d5c521113d5d183a28 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 5 Sep 2024 19:27:36 +0800 Subject: [PATCH 75/89] optimize moe --- lmdeploy/pytorch/backends/cuda/moe.py | 9 ++++ lmdeploy/pytorch/kernels/cuda/fused_moe.py | 52 ++-------------------- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index da06e6725..e5ae92d8b 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -14,6 +14,15 @@ def __init__(self, top_k: int, renormalize: bool = False): self.top_k = top_k self.renormalize = renormalize + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor): + gate_up_weights = gate_up_weights.transpose(1, + 2).contiguous().transpose( + 1, 2) + down_weights = down_weights.transpose(1, + 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, down_weights: torch.Tensor): diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index 852799d2c..9d1b224ef 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -19,30 +19,12 @@ def get_cuda_autotune_config(): }, num_stages=3, num_warps=8), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, - }, - num_stages=4, - num_warps=4), triton.Config( { 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, + 'GROUP_SIZE_M': 1, }, num_stages=4, num_warps=4), @@ -50,38 +32,11 @@ def get_cuda_autotune_config(): { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1, }, num_stages=4, num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, - }, - num_stages=5, - num_warps=2), - triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8, - }, - num_stages=5, - num_warps=2), ] @@ -234,6 +189,7 @@ def fused_moe_kernel_launcher( if num_tokens is None: num_tokens = A.size(0) M_NP2 = triton.next_power_of_2(num_tokens) + M_NP2 = max(32, M_NP2) E, N, K = B.shape def _grid_fn(META): From b98153e60b5650765c94bc8a27b1526d6de77c57 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 11:07:17 +0800 Subject: [PATCH 76/89] fix linear awq bias, default awq kernel --- lmdeploy/pytorch/backends/cuda/op_backend.py | 18 +++--- .../pytorch/backends/default/awq_modules.py | 61 ++++++++++++++++++- lmdeploy/pytorch/nn/linear.py | 6 +- 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index c7da9c40e..3f8fc43b3 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -44,15 +44,15 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.MultinomialSampling: from .multinomial_sampling import TritonMultinomialSamplingBuilder return TritonMultinomialSamplingBuilder - elif layer_type == OpType.LinearW4A16: - from awq.modules.linear.gemm import AWQ_INSTALLED - if AWQ_INSTALLED: - from .awq_modules import AwqLinearW4A16Builder - return AwqLinearW4A16Builder - else: - logger.debug( - f'Op {layer_type} fallback to default implementation.') - return super().get_layer_impl_builder(layer_type) + # elif layer_type == OpType.LinearW4A16: + # from awq.modules.linear.gemm import AWQ_INSTALLED + # if AWQ_INSTALLED: + # from .awq_modules import AwqLinearW4A16Builder + # return AwqLinearW4A16Builder + # else: + # logger.debug( + # f'Op {layer_type} fallback to default implementation.') + # return super().get_layer_impl_builder(layer_type) elif layer_type == OpType.FusedMoE: from .moe import TritonFusedMoEBuilder return TritonFusedMoEBuilder diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index c0a961aad..cb22ded5f 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -6,6 +6,66 @@ from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl +AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + # unpacking columnwise + iweights = torch.bitwise_right_shift( + qweight[:, :, None], + shifts[None, None, :]).to(torch.int8 # smallest dtype available + ) + iweights = iweights.view(iweights.shape[0], -1) + + # unpacking columnwise + izeros = torch.bitwise_right_shift( + qzeros[:, :, None], + shifts[None, None, :]).to(torch.int8 # smallest dtype available + ) + izeros = izeros.view(izeros.shape[0], -1) + + return iweights, izeros + + +def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): + reverse_order_tensor = torch.arange( + izeros.shape[-1], + dtype=torch.int32, + device=izeros.device, + ) + # (-1, 8) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + # (-1, 2, 4) + reverse_order_tensor = reverse_order_tensor.unflatten(-1, (2, 4)) + # (-1, 4, 2) + reverse_order_tensor = reverse_order_tensor.transpose(-1, -2) + reverse_order_tensor = reverse_order_tensor.flatten() + + izeros = izeros[:, reverse_order_tensor] + iweights = iweights[:, reverse_order_tensor] + + return iweights, izeros + + +def dequantize_gemm(qweight, qzeros, scales, bits, group_size): + # Unpack the qweight and qzeros tensors + iweight, izeros = unpack_awq(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + # fp16 weights + scales = scales.repeat_interleave(group_size, dim=0) + izeros = izeros.repeat_interleave(group_size, dim=0) + iweight = (iweight - izeros) * scales + + return iweight + class DefaultLinearW4A16Impl(LinearW4A16Impl): """w4a16 linear implementation.""" @@ -25,7 +85,6 @@ def forward(self, bias: Optional[torch.Tensor] = None, all_reduce: bool = False): """forward.""" - from awq.utils.packing_utils import dequantize_gemm out_shape = x.shape[:-1] + (self.out_features, ) input_dtype = x.dtype if input_dtype != torch.float16: diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index 407479a87..e0c12a4d8 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -401,10 +401,10 @@ def weight_loader(self, param: torch.nn.Parameter, align)[rank] param_w.copy_(weight) - if param._weight_type == 'scales': + if param._weight_type in ['scales', 'bias']: # scales align = max(self.elem_per_int, self.group_size) - param_w = param.data.split(self.all_out_features, 1)[shard_idx] + param_w = param.data.split(self.all_out_features, -1)[shard_idx] else: # qweight or qzeros align = max(self.elem_per_int, @@ -415,7 +415,7 @@ def weight_loader(self, param: torch.nn.Parameter, param_w = param.data.split(quanted_out_feats, 1)[shard_idx] if not self.replicate[shard_idx]: - weight = _chunk_align(loaded_weight, world_size, 1, align)[rank] + weight = _chunk_align(loaded_weight, world_size, -1, align)[rank] param_w.copy_(weight) def weight_spliter_wz(self, loaded_weight: torch.Tensor): From 77616aaab5b22ffb150f5d0908f20034d04e049d Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 11:07:48 +0800 Subject: [PATCH 77/89] fix --- lmdeploy/pytorch/backends/cuda/op_backend.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 3f8fc43b3..c7da9c40e 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -44,15 +44,15 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.MultinomialSampling: from .multinomial_sampling import TritonMultinomialSamplingBuilder return TritonMultinomialSamplingBuilder - # elif layer_type == OpType.LinearW4A16: - # from awq.modules.linear.gemm import AWQ_INSTALLED - # if AWQ_INSTALLED: - # from .awq_modules import AwqLinearW4A16Builder - # return AwqLinearW4A16Builder - # else: - # logger.debug( - # f'Op {layer_type} fallback to default implementation.') - # return super().get_layer_impl_builder(layer_type) + elif layer_type == OpType.LinearW4A16: + from awq.modules.linear.gemm import AWQ_INSTALLED + if AWQ_INSTALLED: + from .awq_modules import AwqLinearW4A16Builder + return AwqLinearW4A16Builder + else: + logger.debug( + f'Op {layer_type} fallback to default implementation.') + return super().get_layer_impl_builder(layer_type) elif layer_type == OpType.FusedMoE: from .moe import TritonFusedMoEBuilder return TritonFusedMoEBuilder From 3bfcae2f0447d5ad3e508ef10f46b62e6438c159 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 11:58:23 +0800 Subject: [PATCH 78/89] optimize default awq --- .../pytorch/backends/default/awq_modules.py | 57 +++++++------------ 1 file changed, 19 insertions(+), 38 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index cb22ded5f..98e3b2f9f 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from functools import lru_cache from typing import Optional import torch @@ -6,45 +7,31 @@ from ..awq_modules import LinearW4A16Builder, LinearW4A16Impl -AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + +@lru_cache +def get_shifts(bits: int, device: torch.device): + """get awq shifts.""" + shifts = torch.arange(0, 32, bits, device=device) + shifts = shifts.view(2, 4).t().flatten() + return shifts def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int): - shifts = torch.arange(0, 32, bits, device=qzeros.device) + shifts = get_shifts(bits, qzeros.device) # unpacking columnwise - iweights = torch.bitwise_right_shift( - qweight[:, :, None], - shifts[None, None, :]).to(torch.int8 # smallest dtype available - ) + iweights = torch.bitwise_right_shift(qweight[:, :, None], + shifts[None, None, :]).to(torch.int8) iweights = iweights.view(iweights.shape[0], -1) # unpacking columnwise - izeros = torch.bitwise_right_shift( - qzeros[:, :, None], - shifts[None, None, :]).to(torch.int8 # smallest dtype available - ) + izeros = torch.bitwise_right_shift(qzeros[:, :, None], + shifts[None, None, :]).to(torch.int8) izeros = izeros.view(izeros.shape[0], -1) - return iweights, izeros - - -def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): - reverse_order_tensor = torch.arange( - izeros.shape[-1], - dtype=torch.int32, - device=izeros.device, - ) - # (-1, 8) - reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) - # (-1, 2, 4) - reverse_order_tensor = reverse_order_tensor.unflatten(-1, (2, 4)) - # (-1, 4, 2) - reverse_order_tensor = reverse_order_tensor.transpose(-1, -2) - reverse_order_tensor = reverse_order_tensor.flatten() - - izeros = izeros[:, reverse_order_tensor] - iweights = iweights[:, reverse_order_tensor] + # overflow checks + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) return iweights, izeros @@ -52,17 +39,11 @@ def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): def dequantize_gemm(qweight, qzeros, scales, bits, group_size): # Unpack the qweight and qzeros tensors iweight, izeros = unpack_awq(qweight, qzeros, bits) - # Reverse the order of the iweight and izeros tensors - iweight, izeros = reverse_awq_order(iweight, izeros, bits) - - # overflow checks - iweight = torch.bitwise_and(iweight, (2**bits) - 1) - izeros = torch.bitwise_and(izeros, (2**bits) - 1) # fp16 weights - scales = scales.repeat_interleave(group_size, dim=0) - izeros = izeros.repeat_interleave(group_size, dim=0) - iweight = (iweight - izeros) * scales + iweight = iweight.unflatten(0, (-1, group_size)) + iweight = (iweight - izeros[:, None]) * scales[:, None] + iweight = iweight.flatten(0, 1) return iweight From 50f5b3ce13123f507447d60ae9ff8c7cdeb7da41 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 13:40:55 +0800 Subject: [PATCH 79/89] fix llama rope, add internlm --- lmdeploy/pytorch/models/internlm.py | 426 ++++++++++++++++++++++++++ lmdeploy/pytorch/models/llama.py | 2 +- lmdeploy/pytorch/models/module_map.py | 2 +- 3 files changed, 428 insertions(+), 2 deletions(-) create mode 100644 lmdeploy/pytorch/models/internlm.py diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py new file mode 100644 index 000000000..8b57aedc8 --- /dev/null +++ b/lmdeploy/pytorch/models/internlm.py @@ -0,0 +1,426 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, + SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, + build_qkv_proj, build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + + +class InternLMAttention(nn.Module): + """Rewrite module of LlamaAttention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # packed qkv + self.qkv_proj = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + ) + + # o_proj + self.o_proj = build_rowwise_linear(num_heads * head_dim, + hidden_size, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv( + qkv_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + inplace=True, + ) + + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class InternLMMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_merged_colwise_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=config.bias, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=config.bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class InternLMDecoderLayer(nn.Module): + """decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = InternLMAttention(config, dtype=dtype, device=device) + + # builf MLP + self.mlp = InternLMMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class InternLMModel(nn.Module): + """model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + InternLMDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + dtype=dtype, + device=device) + + # build rotary embedding in LlamaModel + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + scaling_factor = 1.0 + rope_scaling = config.rotary + rope_base = rope_scaling["base"] + rope_type = rope_scaling['type'] + if rope_type == 'dynamic': + emb_type = RopeType.DynamicNTKScaling + scaling_factor=rope_scaling.get("scaling_factor", 1.0) + elif rope_type == 'origin': + emb_type = RopeType.LinearScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.embed_tokens + + +class InternLMForCausalLM(nn.Module): + """rewrote model of LlamaForCausalLM.""" + + support_cuda_graph = True + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build LLamaModel + self.model = InternLMModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 286e89f3e..6a63e57fe 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -254,7 +254,7 @@ def __init__(self, rope_type = rope_scaling['rope_type'] if rope_type == 'dynamic': emb_type = RopeType.DynamicNTKScaling - if rope_type == 'linear': + elif rope_type == 'linear': emb_type = RopeType.LinearScaling elif rope_type == 'llama3': emb_type = RopeType.Llama3 diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 9840367b6..75ba2f514 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -35,7 +35,7 @@ # internlm MODULE_MAP.update({ 'InternLMForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.InternLMForCausalLM', }) # internlm2 From e61ddcf72da4d8db5cd017a9c4c9f76bf6626289 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 15:09:02 +0800 Subject: [PATCH 80/89] optimize decoding --- lmdeploy/pytorch/kernels/cuda/pagedattention.py | 5 ++++- lmdeploy/pytorch/models/internlm.py | 5 ++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index c5d4d4db6..27891a9be 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -240,7 +240,10 @@ def _fwd_grouped_split_kernel( # initialize pointers to output off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) - tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) + if loop_end > loop_start: + tl.store(Acc_out + off_acc, + acc, + mask=mask_h[:, None] & mask_dv[None, :]) off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v) diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py index 8b57aedc8..90eb89991 100644 --- a/lmdeploy/pytorch/models/internlm.py +++ b/lmdeploy/pytorch/models/internlm.py @@ -240,11 +240,11 @@ def __init__(self, rope_max_pos_emb = config.max_position_embeddings scaling_factor = 1.0 rope_scaling = config.rotary - rope_base = rope_scaling["base"] + rope_base = rope_scaling['base'] rope_type = rope_scaling['type'] if rope_type == 'dynamic': emb_type = RopeType.DynamicNTKScaling - scaling_factor=rope_scaling.get("scaling_factor", 1.0) + scaling_factor = rope_scaling.get('scaling_factor', 1.0) elif rope_type == 'origin': emb_type = RopeType.LinearScaling else: @@ -423,4 +423,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - From a27bf512bdc7ec2288dc8a2f34d308e6e15769ab Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 17:12:31 +0800 Subject: [PATCH 81/89] recovery attention --- lmdeploy/pytorch/kernels/cuda/pagedattention.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 27891a9be..c5d4d4db6 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -240,10 +240,7 @@ def _fwd_grouped_split_kernel( # initialize pointers to output off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) - if loop_end > loop_start: - tl.store(Acc_out + off_acc, - acc, - mask=mask_h[:, None] & mask_dv[None, :]) + tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v) From 038773093c000ab9664e699d4e836a13891e25d3 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 6 Sep 2024 18:42:00 +0800 Subject: [PATCH 82/89] fix fill kv cache --- lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 05ae279c4..69c18f523 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -55,6 +55,7 @@ def _fill_kv_cache_kernel( BlockOffsets, num_heads: tl.constexpr, head_dim: tl.constexpr, + head_dim_v: tl.constexpr, stride_kss, stride_ksh, stride_ksd, @@ -122,7 +123,8 @@ def _fill_kv_cache_kernel( if BLOCK_DV > 0: dv_off = tl.arange(0, BLOCK_DV) - maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim) + maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < + head_dim_v) v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh + dv_off[None, :] * stride_vsd, @@ -162,6 +164,7 @@ def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor, block_offsets, num_heads=num_heads, head_dim=head_dim, + head_dim_v=head_dim_v, stride_kss=k_states.stride(-3), stride_ksh=k_states.stride(-2), stride_ksd=k_states.stride(-1), From 195ed837906f03ecd7dee50cb6340fd9d2a1dc9c Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Sep 2024 15:01:39 +0800 Subject: [PATCH 83/89] fix internlm oom --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 4 ++-- lmdeploy/pytorch/engine/logits_process.py | 7 ++++--- lmdeploy/pytorch/engine/model_agent.py | 2 +- lmdeploy/pytorch/models/internlm2.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 87738bd2e..d0242e933 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -172,7 +172,7 @@ def capture(self, **kwargs): current_stream = torch.cuda.current_stream() # warmup - output = self.model(**padded_kwargs) + self.model(**padded_kwargs) self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, @@ -190,7 +190,7 @@ def forward(self, **kwargs): self._fill_inputs(**kwargs) self._graph.replay() - output = self.output_buffers['logits'][:, :num_tokens].clone() + output = self.output_buffers['logits'][:, :num_tokens] return output def __del__(self): diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index dd8eb49d7..ec0117800 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -323,9 +323,10 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor): if max_topk <= 0: scores, indices = logits.sort(1, descending=True) else: - scores = torch.zeros_like(logits) - indices = torch.zeros_like(logits, dtype=torch.int64) + # scores = torch.zeros_like(logits) + # indices = torch.zeros_like(logits, dtype=torch.int64) topk_scores, topk_indices = logits.topk(max_topk, dim=1) + scores = logits.fill_(0) # (seq, vocalb_size) scores[..., :max_topk] = topk_scores - indices[..., :max_topk] = topk_indices + indices = topk_indices # (seq, max_topk) return __random_sampling(scores, indices) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 46d17b5c8..5fec0f83c 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -46,7 +46,7 @@ def __get_runtime_size(num_free_gpu_mem: int, cache_block_size: int, runtime_cache_size = 0 while max_prefill_token_num > 0: # lm_head output(2) + to float(4) + estimated misc(1) = 7 - runtime_cache_size = int(max_prefill_token_num * vocal_size * 7) + runtime_cache_size = int(max_prefill_token_num * vocal_size * 10) num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count if int(num_available) // cache_block_size >= 16: diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index e386ef283..00b4a8643 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -352,7 +352,7 @@ def forward( ) logits = self.output(hidden_states) - logits = logits.float() + # logits = logits.float() return logits def get_input_embeddings(self): From 9b5bc4359034d3648a7be929749a505dd05ede09 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Sep 2024 17:05:04 +0800 Subject: [PATCH 84/89] fix llama3 memory usage --- lmdeploy/pytorch/engine/logits_process.py | 2 -- lmdeploy/pytorch/engine/model_agent.py | 2 +- lmdeploy/pytorch/models/internlm2.py | 1 - lmdeploy/pytorch/models/llama.py | 17 ++++++++++++++--- lmdeploy/pytorch/models/mixtral.py | 1 - 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index ec0117800..fce5be9b5 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -323,8 +323,6 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor): if max_topk <= 0: scores, indices = logits.sort(1, descending=True) else: - # scores = torch.zeros_like(logits) - # indices = torch.zeros_like(logits, dtype=torch.int64) topk_scores, topk_indices = logits.topk(max_topk, dim=1) scores = logits.fill_(0) # (seq, vocalb_size) scores[..., :max_topk] = topk_scores diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 5fec0f83c..46d17b5c8 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -46,7 +46,7 @@ def __get_runtime_size(num_free_gpu_mem: int, cache_block_size: int, runtime_cache_size = 0 while max_prefill_token_num > 0: # lm_head output(2) + to float(4) + estimated misc(1) = 7 - runtime_cache_size = int(max_prefill_token_num * vocal_size * 10) + runtime_cache_size = int(max_prefill_token_num * vocal_size * 7) num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count if int(num_available) // cache_block_size >= 16: diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 00b4a8643..08c0f85d5 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -352,7 +352,6 @@ def forward( ) logits = self.output(hidden_states) - # logits = logits.float() return logits def get_input_embeddings(self): diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 6a63e57fe..7952c8899 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -320,8 +320,6 @@ def get_input_embeddings(self): class LlamaForCausalLM(nn.Module): """rewrote model of LlamaForCausalLM.""" - support_cuda_graph = True - packed_modules_mapping = { 'qkv_proj': [ 'q_proj', @@ -370,9 +368,22 @@ def forward( ) logits = self.lm_head(hidden_states) - logits = logits.float() return logits + def support_cuda_graph( + self, + input_ids: torch.Tensor, + **kwargs, + ): + """support cudagraph.""" + seq_lens = input_ids.size(1) + if seq_lens <= 512: + return True + + # prevent oom on llama-3 70b + if self.config.num_hidden_layers >= 40: + return False + def get_input_embeddings(self): """get input embeddings.""" return self.model.get_input_embeddings() diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 2f98d2a32..18467f2d9 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -332,7 +332,6 @@ def forward( ) logits = self.lm_head(hidden_states) - logits = logits.float() return logits def get_input_embeddings(self): From 3020ada176c3ad0362c44035edd054ef4cd6cc51 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Sep 2024 17:28:19 +0800 Subject: [PATCH 85/89] remove float deepseekv2 --- lmdeploy/pytorch/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index cceac2828..6a6dd8bf8 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -638,7 +638,7 @@ def forward( ) logits = self.lm_head(hidden_states) - logits = logits.float() + # logits = logits.float() return logits def get_input_embeddings(self): From 331e2c0d567473099bfba699acd4be0c71f69577 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Sep 2024 17:37:25 +0800 Subject: [PATCH 86/89] fix llama3 --- lmdeploy/pytorch/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 7952c8899..17af8da54 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -384,6 +384,8 @@ def support_cuda_graph( if self.config.num_hidden_layers >= 40: return False + return False + def get_input_embeddings(self): """get input embeddings.""" return self.model.get_input_embeddings() From aa9c722d887415b3b5cb23bc4d8550e4da4f7579 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Sep 2024 17:50:34 +0800 Subject: [PATCH 87/89] update smooth quant flag --- lmdeploy/lite/apis/smooth_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index f805fc6ea..100aa5672 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -158,6 +158,8 @@ def smooth_quant(model: str, model.save_pretrained(work_dir, max_shard_size='2GB', safe_serialization=False) + model.config.update( + dict(quantization_config=dict(quant_method='smooth_quant'))) tokenizer.save_pretrained(work_dir) shutil.copy(MODEL_PATH_MAP[type(model).__name__], work_dir) From adbc53154b30ff4708ea7fad4d1678393aff3c28 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 10 Sep 2024 10:46:32 +0800 Subject: [PATCH 88/89] fix w8a8 --- lmdeploy/lite/apis/smooth_quant.py | 4 ++-- lmdeploy/pytorch/models/internlm2.py | 13 +++++++++++-- lmdeploy/pytorch/models/mixtral.py | 13 +++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py index 100aa5672..45684602b 100644 --- a/lmdeploy/lite/apis/smooth_quant.py +++ b/lmdeploy/lite/apis/smooth_quant.py @@ -155,11 +155,11 @@ def smooth_quant(model: str, from .auto_awq import save_vl_model save_vl_model(vl_model, model_path, work_dir) else: + model.config.update( + dict(quantization_config=dict(quant_method='smooth_quant'))) model.save_pretrained(work_dir, max_shard_size='2GB', safe_serialization=False) - model.config.update( - dict(quantization_config=dict(quant_method='smooth_quant'))) tokenizer.save_pretrained(work_dir) shutil.copy(MODEL_PATH_MAP[type(model).__name__], work_dir) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 08c0f85d5..574195c5c 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -307,8 +307,6 @@ def get_input_embeddings(self): class InternLM2ForCausalLM(nn.Module): """rewrote model of InternLM2ForCausalLM.""" - support_cuda_graph = True - packed_modules_mapping = { 'gate_up_proj': [ 'w1', @@ -354,6 +352,17 @@ def forward( logits = self.output(hidden_states) return logits + def support_cuda_graph( + self, + input_ids: torch.Tensor, + **kwargs, + ): + """support cudagraph.""" + seq_lens = input_ids.size(1) + if seq_lens <= 512: + return True + return False + def get_input_embeddings(self): """get input embeddings.""" return self.model.get_input_embeddings() diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 18467f2d9..75bad285e 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -296,8 +296,6 @@ def get_input_embeddings(self): class MixtralForCausalLM(nn.Module): """mixture model for causalLM.""" - support_cuda_graph = True - def __init__(self, config: Any, ctx_mgr: StepContextManager, @@ -334,6 +332,17 @@ def forward( logits = self.lm_head(hidden_states) return logits + def support_cuda_graph( + self, + input_ids: torch.Tensor, + **kwargs, + ): + """support cudagraph.""" + seq_lens = input_ids.size(1) + if seq_lens <= 512: + return True + return False + def get_input_embeddings(self): """get input embeddings.""" return self.model.get_input_embeddings() From 1fae365d1baa8842f6c7cdf9191b43ee103690fd Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 10 Sep 2024 12:12:39 +0800 Subject: [PATCH 89/89] fix w8a8 tp --- lmdeploy/pytorch/nn/linear.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index e0c12a4d8..c1927df62 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -589,9 +589,12 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, if loaded_weight.dim() == 2 and param.dtype == torch.int8: weight = loaded_weight.chunk(world_size, 1)[rank] return default_weight_loader(param, weight) + elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1: + # scaling + return default_weight_loader(param, loaded_weight) else: # bias - if rank == 0: + if rank != 0: loaded_weight = torch.zeros_like(loaded_weight) return default_weight_loader(param, loaded_weight)