diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
index bf3b210e8..b3d103534 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
@@ -7,6 +7,8 @@
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
 from typing import Tuple
 
+import os
+
 
 class TransformerLayerInferTpl(TransformerLayerInfer):
     """ """
@@ -21,6 +23,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
         self.tp_o_head_num_ = -1
         self.head_dim_ = -1
         self.embed_dim_ = -1
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
         return
 
     def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +82,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_w
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -88,7 +91,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -102,7 +105,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_wei
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -111,7 +114,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -125,7 +128,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateIn
         o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -134,7 +137,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, l
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
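Note: the change above establishes the pattern repeated throughout this patch: tp_split_ is derived once from EDP_MODE_ENABLED at construction time, and the tensor-parallel all-reduce is skipped whenever splitting is disabled. A minimal sketch of that gating with hypothetical helper names (edp_mode_enabled and maybe_all_reduce are not part of the patch):

    import os

    import torch
    import torch.distributed as dist


    def edp_mode_enabled() -> bool:
        # The patch treats anything other than the literal string "true" as disabled.
        return os.environ.get("EDP_MODE_ENABLED") == "true"


    def maybe_all_reduce(t: torch.Tensor, world_size: int, tp_split: bool) -> torch.Tensor:
        # Under EDP (tp_split is False) each rank already holds the full activation,
        # so the cross-rank SUM is skipped. Assumes the default process group is
        # initialized whenever world_size > 1.
        if world_size > 1 and tp_split:
            dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=False)
        return t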
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
index c6b1ab500..973de48da 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -11,6 +11,10 @@
     MultiCOLMMWeight,
     ROWBMMWeight,
     COLBMMWeight,
+    MultiCOLMMWeightNoTp,
+    ROWBMMWeightNoTp,
+    COLBMMWeightNoTp,
+    COLMMWeightNoTp,
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight import FusedMoeWeight
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
index ba6d7d028..5526eb383 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
@@ -153,7 +153,7 @@ def _load_hf_weights_etp(self, weights):
             self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
 
     def load_hf_weights(self, weights):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             self._load_hf_weights_etp(weights)
         else:
             for i_experts in range(self.n_routed_experts):
@@ -184,7 +184,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             return True
         else:
             return self.w1 is not None and self.w2 is not None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
index c91dd8acb..dd812304d 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -319,3 +319,73 @@ def __init__(
 
     def _post_load_weights(self):
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
+
+class COLMMWeightNoTp(MMWeight):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def load_hf_weights(self, weights):
+        weight = None
+        if self.weight_name in weights:
+            weight = weights[self.weight_name].to(self.data_type_)
+            self.weight = weight[:, self.start : self.end]
+        if self.bias_name in weights:
+            bias = weights[self.bias_name]
+            self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
+        if weight is None:
+            return
+        self._post_load_weights()
+        return
+
+
+class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
+
+    def load_hf_weights(self, weights):
+        weight = None
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]].to(self.data_type_)
+                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
+            if self.has_bias and self.bias_names[i] in weights:
+                bias = weights[self.bias_names[i]].to(self.data_type_)
+                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+        self._fuse()
+        return
+
+
+class ROWBMMWeightNoTp(BMMWeight):
+    load_hf_weights = ROWMMWeight.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+
+class COLBMMWeightNoTp(BMMWeight):
+    load_hf_weights = COLMMWeightNoTp.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def _post_load_weights(self):
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py
index 8df5a61e2..3dab29851 100644
--- a/lightllm/common/deepseek2_mem_manager.py
+++ b/lightllm/common/deepseek2_mem_manager.py
@@ -14,9 +14,23 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
         # todo, etp or edp use the same work buffer here
         # also it can be used for any kernels for work buffer witout save info only
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
             self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
             self.work_buffer.share_memory_()
+            import lightllm_moe_etp_kernel
+            import torch.distributed as dist
+
+            rank_id = dist.get_rank()
+            world_size = dist.get_world_size()
+
+            # lightllm_moe_etp_kernel.enableP2P(world_size, rank_id)
+
+            handle = lightllm_moe_etp_kernel.get_handle(self.work_buffer.contiguous(), rank_id)
+            handles = [None] * world_size
+            dist.all_gather_object(handles, handle)
+            self.handles_work_buffer = handles
+
+            lightllm_moe_etp_kernel.init_system(world_size, rank_id, self.work_buffer.contiguous(), handles)
 
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
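Note: the buffer setup added above publishes each rank's work-buffer IPC handle to every other rank. The sketch below isolates that exchange using only torch.distributed; get_handle and init_system belong to the lightllm_moe_etp_kernel extension, so the handle-producing function is passed in rather than imported here:

    import torch
    import torch.distributed as dist


    def exchange_work_buffer_handles(work_buffer: torch.Tensor, get_handle):
        # Assumes the default process group is already initialized.
        rank_id = dist.get_rank()
        world_size = dist.get_world_size()
        # Each rank exports a handle for its own shared buffer ...
        handle = get_handle(work_buffer.contiguous(), rank_id)
        # ... and gathers the handles published by all other ranks.
        handles = [None] * world_size
        dist.all_gather_object(handles, handle)
        return handles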
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 4b686c602..67da6576b 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -69,6 +69,8 @@ def __init__(
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         self.mla_type = "ACCM"
 
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         return
 
     def _bind_attention(self):
@@ -78,8 +80,8 @@ def _bind_attention(self):
         )
         self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
-            if os.environ.get("ETP_MODE_ENABLED") == "true":
-                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp, self)
+            if os.environ.get("ETP_MODE_ENABLED") == "true" or os.environ.get("EDP_MODE_ENABLED") == "true":
+                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp_edp, self)
             else:
                 self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
         else:
@@ -155,7 +157,7 @@ def _CC_method(
     ):
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             num_local_heads //= self.world_size_
             num_local_kv_heads //= self.world_size_
         if infer_state.use_dynamic_prompt_cache:
@@ -187,7 +189,7 @@ def _ACC_method(
         q_nope, q_rope = q
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             num_local_heads //= self.world_size_
             num_local_kv_heads //= self.world_size_
         # ACC
@@ -275,6 +277,10 @@ def _context_attention_kernel_origin(
         self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
+
+        # not support edp yet
+        # assert self.tp_split_ == True
+
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
 
         if infer_state.use_dynamic_prompt_cache:
@@ -440,7 +446,7 @@ def _splitfuse_attention_kernel_with_CC(
             torch.cuda.default_stream().wait_event(infer_state.end_event)
         return o_tensor
 
-    def _moe_ffn_etp(
+    def _moe_ffn_etp_edp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         world_size_ = self.world_size_
@@ -460,17 +466,25 @@ def _moe_ffn_etp(
         final_hidden_states = torch.empty(
             num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
         )
-
-        # router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1]
-        router_logits = layer_weight.moe_gate.mm(hidden_states)
+
         # now some parameter is not supported yet
         # assert gating_normalize_prob is False
         # assert num_expert_groups<=1
-
-        import lightllm_moe_etp_kernel
-
-        lightllm_moe_etp_kernel.moe_fused_all(
+        is_etp = True
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            router_logits = layer_weight.moe_gate.mm(hidden_states)
+        elif os.environ.get("EDP_MODE_ENABLED") == "true":
+            router_logits = infer_state.mem_manager.work_buffer[-(num_tokens * num_experts_per_token + hidden_states.nelement()) : -hidden_states.nelement()].view(num_tokens, num_experts_per_token)
+            router_logits = layer_weight.moe_gate.mm(hidden_states, out=router_logits)
+            is_etp = False
+
+        # print(" hid state addr ", infer_state.mem_manager.work_buffer.data_ptr(),
+        #       hidden_states.data_ptr(),
+        #       hidden_states.shape()
+        #       )
+
+        moe_fused_all(
             router_logits.contiguous(),
             hidden_states.contiguous(),
             layer_weight.gate_up_proj.weight.contiguous(),  # transpose
@@ -490,8 +504,10 @@ def _moe_ffn_etp(
             layer_weight.gate_up_proj.weight.size(1) // 2,
             layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
             self.n_shared_experts is not None,
+            is_etp,
         )
-        router_logits = None
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            router_logits = None
 
         return final_hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 995ad1f11..b7024ec1b 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -6,15 +6,22 @@
     ROWMMWeight,
     ROWMMWeightNoTP,
     MultiROWMMWeight,
+    MultiROWMMWeightNoTP,
     COLMMWeight,
+    COLMMWeightNoTp,
     MultiCOLMMWeight,
+    MultiCOLMMWeightNoTp,
     NormWeight,
     FusedMoeWeight,
     ROWBMMWeight,
+    ROWBMMWeightNoTp,
     COLBMMWeight,
+    COLBMMWeightNoTp,
 )
 from functools import partial
 
+import os
+
 
 def fuse_q_kb(self, layer_weight):
     if not (self.weight is None and all(w is not None for w in self.weights)):
@@ -74,6 +81,8 @@ def __init__(
     ):
         self.disable_qk_absorb = disable_qk_absorb
         self.disable_vo_absorb = disable_vo_absorb
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         # mla_type = "ACCM", "MIX"
         # MIX: prefill uses CC, decoding uses ACC
@@ -89,7 +98,9 @@ def _parse_config(self):
             and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
             and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
         )
-        self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.world_size_
+        self.tp_q_head_num_ = self.network_config_["num_attention_heads"]
+        if self.tp_split_:
+            self.tp_q_head_num_ //= self.world_size_
         self.n_routed_experts = self.network_config_["n_routed_experts"]
         self.q_lora_rank = self.network_config_["q_lora_rank"]
         self.qk_nope_head_dim = self.network_config_["qk_nope_head_dim"]
@@ -104,7 +115,10 @@ def _init_weight_names(self):
             self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight"
 
     def _init_weight(self):
-        self._init_qkvo()
+        if self.tp_split_:
+            self._init_qkvo()
+        else:
+            self._init_qkvo_dp()
         if self.is_moe:
             self._init_moe()
         else:
@@ -112,12 +126,13 @@
             self._init_norm()
 
     def _load_q_rope(self, q_weight_):
-        q_split_n_embed_with_rope = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_
-        )
-        q_weight_ = q_weight_[
-            q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
-        ]
+        if self.tp_split_:
+            q_split_n_embed_with_rope = (
+                (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads // self.world_size_
+            )
+            q_weight_ = q_weight_[
+                q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
+            ]
         q_weight_ = q_weight_.transpose(0, 1).contiguous()
         q_nope_proj_, q_rope_proj_ = torch.split(
             q_weight_.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
@@ -239,11 +254,100 @@ def _init_qkvo(self):
                 q_split_n_embed,
             )
 
-    def _load_mlp(self, mlp_prefix, split_inter_size):
-        self.gate_up_proj = MultiROWMMWeight(
-            [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
+    def _init_qkvo_dp(self):
+        q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_
+        q_split_n_embed_with_rope = (self.qk_nope_head_dim + self.qk_rope_head_dim) * self.num_attention_heads
+        if self.q_lora_rank is None:
+            if not self.disable_qk_absorb:  # acc
+                self.fuse_qk_weight_ = MultiROWMMWeightNoTP(
+                    [
+                        f"model.layers.{self.layer_num_}.self_attn.q_proj.weight",
+                        f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                    ],
+                    self.data_type_,
+                    [q_split_n_embed_with_rope, self.tp_q_head_num_],
+                )
+                self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self)
+            else:  # cc
+                self.q_weight_ = ROWMMWeightNoTP(
+                    f"model.layers.{self.layer_num_}.self_attn.q_proj.weight",
+                    self.data_type_,
+                    q_split_n_embed_with_rope,
+                )
+        else:
+            self.q_a_proj_ = ROWMMWeightNoTP(
+                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+                self.data_type_,
+                self.q_lora_rank,
+            )
+            if not self.disable_qk_absorb:
+                self.fuse_qk_weight_ = MultiROWMMWeightNoTP(
+                    [
+                        f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
+                        f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                    ],
+                    self.data_type_,
+                    [q_split_n_embed_with_rope, self.tp_q_head_num_],
+                )
+                self.fuse_qk_weight_._fuse = partial(fuse_q_kb, self.fuse_qk_weight_, self)
+            else:
+                self.q_b_proj_ = ROWMMWeightNoTP(
+                    f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
+                    self.data_type_,
+                    q_split_n_embed_with_rope,
+                )
+
+        self.q_rope_proj_ = ROWMMWeightNoTP(
+            f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight",
+            self.data_type_,
+            self.qk_rope_head_dim * self.tp_q_head_num_,
         )
-        self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
+
+        self.kv_a_proj_with_mqa_ = ROWMMWeightNoTP(
+            f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
+            self.data_type_,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
+        if self.disable_qk_absorb:
+            self.k_b_proj_ = ROWBMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_,
+            )
+        if not self.disable_vo_absorb:
+            self.fuse_vo_weight_ = MultiCOLMMWeightNoTp(
+                [
+                    f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
+                    f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+                ],
+                self.data_type_,
+                [self.tp_q_head_num_, q_split_n_embed],
+            )
+            self.fuse_vo_weight_._fuse = partial(fuse_vb_o, self.fuse_vo_weight_, self)
+        else:
+            self.v_b_proj_ = COLBMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_,
+            )
+        if self.disable_vo_absorb:
+            self.o_weight_ = COLMMWeightNoTp(
+                f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+                self.data_type_,
+                q_split_n_embed,
+            )
+
+    def _load_mlp(self, mlp_prefix, split_inter_size):
+        if self.tp_split_:
+            self.gate_up_proj = MultiROWMMWeight(
+                [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
+            )
+            self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
+        else:
+            self.gate_up_proj = MultiROWMMWeightNoTP(
+                [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size
+            )
+            self.down_proj = COLMMWeightNoTp(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
 
     def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
@@ -251,8 +355,9 @@
             f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size
         )
         shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"]
-        shared_split_inter_size = shared_intermediate_size // self.world_size_
-        self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_split_inter_size)
+
+        num_shards = self.world_size_ if self.tp_split_ else 1
+        self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_intermediate_size // num_shards)
 
         self.experts = FusedMoeWeight(
             gate_proj_name="gate_proj",
@@ -266,8 +371,8 @@ def _init_ffn(self):
         inter_size = self.network_config_["intermediate_size"]
-        split_inter_size = inter_size // self.world_size_
-        self._load_mlp(f"model.layers.{self.layer_num_}.mlp", split_inter_size)
+        num_shards = self.world_size_ if self.tp_split_ else 1
+        self._load_mlp(f"model.layers.{self.layer_num_}.mlp", inter_size // num_shards)
 
     def _init_norm(self):
         self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_)
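Note: the per-rank sizing rule used in _parse_config, _init_moe and _init_ffn above reduces to a single shard-count choice. The helper below only illustrates that rule; it is not code from the patch:

    def per_rank_sizes(num_attention_heads: int, inter_size: int, world_size: int, tp_split: bool):
        # With EDP enabled (tp_split is False) every rank keeps the full head count
        # and the full FFN intermediate size; with classic TP both are sharded.
        num_shards = world_size if tp_split else 1
        return num_attention_heads // num_shards, inter_size // num_shards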
diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py
index a642a0fe0..f31761788 100644
--- a/lightllm/models/llama/layer_infer/post_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -12,6 +12,8 @@
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 
+import os
+
 
 class LlamaPostLayerInfer(PostLayerInferTpl):
     """ """
@@ -21,6 +23,8 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         self.eps_ = network_config["rms_norm_eps"]
         self.vocab_size_ = network_config["vocab_size"]
         self.embed_dim_ = network_config["n_embed"]
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
         return
 
     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
@@ -89,7 +93,7 @@ def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_
         torch.mm(layer_weight.lm_head_weight_, last_input, out=logic_batch)
 
         last_input = None
-        if self.world_size_ == 1:
+        if self.world_size_ == 1 or not self.tp_split_:
             gather_data = logic_batch
         else:
             gather_data = self.alloc_tensor((self.vocab_size_, token_num), dtype=input_embdings_dtype)
diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py
index f60fa6127..cee3c9b1a 100644
--- a/lightllm/models/llama/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -9,14 +9,22 @@
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.models.llama.triton_kernel.embedding import embedding
 
+import os
+
 
 class LlamaPreLayerInfer(PreLayerInferTpl):
     """ """
 
     def __init__(self, tp_rank, world_size, network_config, mode):
         super().__init__(tp_rank, world_size, network_config, mode)
-        tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
-        self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
+
+        if self.tp_split_:
+            tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
+            self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
+        else:
+            self.vob_start_id_, self.vob_end_id_ = 0, network_config["vocab_size"]
+
         return
 
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -24,7 +32,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
             (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         return input_embdings
 
@@ -33,7 +41,7 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh
             (input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_
         )
         embedding(input_ids, layer_weight.wte_weight_, self.vob_start_id_, self.vob_end_id_, input_embdings)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
         return input_embdings
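Note: the vocabulary handling in LlamaPreLayerInfer above amounts to choosing a token-id range per rank. A small sketch of that choice follows; the helper name is illustrative, not from the patch:

    import numpy as np


    def vocab_range(vocab_size: int, world_size: int, tp_rank: int, tp_split: bool):
        if tp_split:
            # Classic TP: each rank owns a contiguous slice of the vocabulary and the
            # partial embedding outputs are summed across ranks afterwards.
            bounds = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
            return int(bounds[tp_rank]), int(bounds[tp_rank + 1])
        # EDP: every rank holds the full table, so no all-reduce is needed after lookup.
        return 0, vocab_size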
diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index bc8ab44fb..4ec2fd00f 100755
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -25,6 +25,8 @@
 from lightllm.common.basemodel import TransformerLayerInferTpl
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
 
+import os
+
 
 class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     """ """
@@ -32,9 +34,15 @@ class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
     def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.eps_ = network_config["rms_norm_eps"]
-        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
-        self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_
-        self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_
+
+        self.tp_q_head_num_ = network_config["num_attention_heads"]
+        self.tp_k_head_num_ = network_config["num_key_value_heads"]
+        self.tp_v_head_num_ = network_config["num_key_value_heads"]
+        if not os.environ.get("EDP_MODE_ENABLED") == "true":
+            self.tp_q_head_num_ //= world_size
+            self.tp_k_head_num_ //= world_size
+            self.tp_v_head_num_ //= world_size
+
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
         self.embed_dim_ = network_config["hidden_size"]
@@ -119,7 +127,11 @@ def _att_norm(
 
     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        out = self.alloc_tensor(input.shape, input.dtype)
+        if not os.environ.get("EDP_MODE_ENABLED") == "true":
+            out = self.alloc_tensor(input.shape, input.dtype)
+        else:
+            num_ele = input.nelement()
+            out = infer_state.mem_manager.work_buffer[-num_ele:].view(input.shape)
         rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
         return out
diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
index 25e9bd10c..06b2277a1 100644
--- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
+++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py
@@ -2,17 +2,26 @@
 import numpy as np
 from lightllm.common.basemodel import PreAndPostLayerWeight
 
+import os
+
 
 class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight):
     def __init__(self, tp_rank, world_size, data_type, network_config, mode):
         super().__init__(tp_rank, world_size, data_type, network_config, mode)
+        self.tp_split_ = not os.environ.get("EDP_MODE_ENABLED") == "true"
         return
 
     def load_hf_weights(self, weights):
         vob_size = self.network_config_["vocab_size"]
-        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
-        split_start = split_indexes[self.tp_rank_]
-        split_end = split_indexes[self.tp_rank_ + 1]
+
+        if self.tp_split_:
+            split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
+            split_start = split_indexes[self.tp_rank_]
+            split_end = split_indexes[self.tp_rank_ + 1]
+        else:
+            split_start = 0
+            split_end = vob_size
+
         if "model.embed_tokens.weight" in weights:
             self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
         tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
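Note: every check above reads the environment at object-construction time, so EDP_MODE_ENABLED must already be set to the exact lowercase string "true" in each rank's environment before the model and weight objects are built. It would normally be exported by the launcher; the in-process assignment below is only for illustration:

    import os

    # Must be set on every rank before the classes above are instantiated.
    os.environ["EDP_MODE_ENABLED"] = "true"
    # The pre-existing expert-tensor-parallel path is selected the same way:
    # os.environ["ETP_MODE_ENABLED"] = "true"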