diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009..b599bedfc 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -28,7 +28,7 @@ class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num - self.head_dim = head_dim + self.head_dim = head_dim * 2 # neo kv 是[k, k_h, k_w]拼在一起的 self.layer_num = layer_num self.always_copy = always_copy self.dtype = dtype @@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.size, dtype, head_num, - head_dim, + self.head_dim, layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003..9a51d9512 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -38,4 +38,6 @@ Tarsier2LlamaTpPartModel, ) from lightllm.models.gpt_oss.model import GptOssTpPartModel +from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel +from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index a228e0025..36b5d79b5 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -110,6 +110,8 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() return if "rope_type" in rope_scaling: @@ -132,6 +134,8 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() return def _init_weights(self): @@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - + print(f"base is {base}") if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000): self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return + def _init_to_get_hw_rotary(self, default_base=10000): + partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) + if self.config.get("rope_scaling", {}) is None: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) + + base = self.config.get("rope_theta_hw", float(default_base)) + print(f"hw_base is {base}") + if "max_sequence_length" in self.config: + max_seq_len = self.config["max_sequence_length"] + else: + max_position_embeddings = self.config.get( + "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 + ) + max_seq_len = max_position_embeddings * rope_scaling_factor + + # NTK + try: + ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula + except: + pass + + inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) 
/ partial_head_dim) + ) + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return + def _init_to_get_dynamic_ntk_rotary(self): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) max_position_embeddings = self.config.get("max_position_embeddings", 2048) diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index eb5af6fec..02bd277ad 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lk ** 0.5) + Lk_scale = Lk // 2 + sm_scale = 1.0 / (Lk_scale ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..1cf13c413 --- /dev/null +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -0,0 +1,159 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward + + +class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, 
self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + return o3.view(o3.shape[0], -1) + + def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + k_3d = 
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + token_att_fwd( + q_3d, + k_3d, + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + + token_softmax_reducev_fwd( + att_m_tensor, + v_3d, + o_3d, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..c1f0638ac --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..e5e769a76 --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -0,0 +1,51 @@ +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) + + +class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj 
= ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py new file mode 100644 index 000000000..14d9f96dc --- /dev/null +++ b/lightllm/models/neo_chat/model.py @@ -0,0 +1,53 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) +class NeoTpPartModel(Qwen3TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatTransformerLayerInfer + + pre_and_post_weight_class = NeoChatPreAndPostLayerWeight + transformer_weight_class = NeoChatTransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", 
"n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py new file mode 100644 index 000000000..0c7d9372e --- /dev/null +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -0,0 +1,99 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor): + LlamaInferStateInfo.init_some_extra_state(self, model, input_ids) + if self.is_prefill: + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
+ batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + + for _, p in enumerate(multimodal_params): + images = p.get("images", []) + for img in images: + b_image_start_idx.append(img["start_idx"]) + a = img["start_idx"] + print(f"img start_idx: {a}") + b_image_len.append(img["token_num"]) + b_image_thwd.append(img["grid_thwd"]) + b_image_nums.append(len(images)) + b_image_start_num.append(image_start_num) + image_start_num += len(images) + + # 没有任何图片 + if image_start_num == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids.contiguous() + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + + get_neo_position_triton( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_start_loc=self.b_start_loc, + ) + return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..ed48a9c6f --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -0,0 +1,159 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = 
self._token_decode_attention_normal + self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) # [T, Hq*D] + + q_hw = layer_weight.q_hw_proj.mm(input) + q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_hw = layer_weight.k_hw_proj.mm(input) + k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_) + k_h, k_w = k_hw.chunk(2, dim=-1) + + cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D] + + qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_) + + q_h_2d = q_h.reshape(q.shape[0], -1) + q_w_2d = q_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_) + q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + + qk_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + weight=layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ) + + k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)] + k_w_2d = k_w.reshape(q.shape[0], -1) + qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_) + qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_) + k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_) + q3 = torch.cat([q3, q_h, q_w], dim=-1) + q = q3.reshape(q3.shape[0], -1) + + k = cache_kv[:, : self.tp_k_head_num_, :] + k = torch.cat([k, k_h, k_w], dim=-1) + + v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] + v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype) + v = torch.cat([v, v_pad], dim=-1) + + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2) + o3 = o3[:, :, : self.head_dim_].contiguous() + 
return o3.view(o3.shape[0], -1) + + def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None): + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + + q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2) + + att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) + + k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + token_att_fwd( + q_3d, + k_3d, + att_m_tensor, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd + + v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_ + ] + + o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out + + token_softmax_reducev_fwd( + att_m_tensor, + v_3d, + o_3d, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o_3d.view(batch_size, -1) diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..7766a5d29 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." 
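+    # Example mapping (illustrative key name): "language_model.model.embed_tokens.weight"
+    # -> "model.embed_tokens.weight"; any checkpoint key containing the prefix is
+    # rewritten in place so the inherited Qwen2 pre/post layer loader can find it.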
+ keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..bc38f1adc --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,51 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + NormWeight, + ROWMMWeight, +) + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight" + self._q_bias_hw_name = None + self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight" + self._k_bias_hw_name = None + + self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight" + self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight" + + self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight" + self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight" + + def _init_qkv(self): + super()._init_qkv() + self.q_hw_proj = ROWMMWeight( + weight_names=self._q_weight_hw_name, + data_type=self.data_type_, + bias_names=self._q_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="q_hw_proj", + ) + self.k_hw_proj = ROWMMWeight( + weight_names=self._k_weight_hw_name, + data_type=self.data_type_, + bias_names=self._k_bias_hw_name, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="k_hw_proj", + ) + + def _init_norm(self): + super()._init_norm() + + self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_) + self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_) + self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_) + self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py new file mode 100644 index 000000000..e4123d109 --- /dev/null +++ b/lightllm/models/neo_chat_moe/model.py @@ -0,0 +1,139 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from 
lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_TOKEN = "" +AUDIO_START_TOKEN = "" + + +class NeoChatTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, **kwargs): + super().__init__(tokenizer) + self.tokenizer = tokenizer + self.min_pixel = model_cfg.get("vision_config").get("min_pixels") + self.max_pixel = model_cfg.get("vision_config").get("max_pixels") + self.patch_size = model_cfg.get("vision_config").get("patch_size") + self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") + + self.image_token_id = model_cfg.get("image_token_id") + self.image_start_tag = IMG_START_TOKEN + self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) + self.image_end_tag = IMG_END_TOKEN + self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=int(self.patch_size // self.downsample_ratio), + min_pixels=self.min_pixel, + max_pixels=self.max_pixel, + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2)) + # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 + img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) + return token_num + + # only change the impl of the encode func: + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + # TEXTTEXTTEXT --> TEXTTEXTTEXT + image_tokens = IMG_START_TOKEN + IMG_END_TOKEN + if multimodal_params is None: + add_special_tokens = kwargs.get("add_special_tokens", True) + return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + image_count = len(multimodal_params.images) + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + 
while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids) + return input_ids + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe")) +class NeoTpMOEPartModel(Qwen3MOEModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + + pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight + transformer_weight_class = NeoChatMOETransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py new file mode 100644 index 000000000..852ddc095 --- /dev/null +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -0,0 +1,279 @@ +import os +import torch +import torch.nn.functional as F +from PIL import Image +from typing import List +from io import BytesIO +import torch.nn as nn +from transformers.activations import ACT2FN +from safetensors import safe_open +from lightllm.server.multimodal_params import ImageItem +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from lightllm.models.neo_chat_moe.vision_process import load_image_native +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """对输入张量的一部分应用1D RoPE。""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + cos_cached = cos_cached.to(device=positions.device) + sin_cached = sin_cached.to(device=positions.device) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor, +): + 
"""应用2D RoPE到输入张量x。""" + dim = x.shape[-1] + dim_half = dim // 2 + + # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 + # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # 将与 abs_positions_x 相关的旋转应用于 x_part_1 + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) + # 将与 abs_positions_y 相关的旋转应用于 x_part_2 + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) + + # 将它们重新拼接起来。确保顺序与你分割时一致。 + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in 
os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." 
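+        # Per-image downsampling below: each (h, w) patch grid is reshaped to NCHW
+        # and passed through dense_embedding (Conv2d with kernel = stride =
+        # downsample_factor), so h * w patches become h * w * downsample_ratio ** 2
+        # LLM tokens. Illustrative numbers (not from the config): h = w = 32 with
+        # downsample_ratio = 0.5 -> 1024 patches -> 256 tokens.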
+ + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + print(f"cur_num is {cur_num}") + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 000000000..f5dae493c --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,452 @@ +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # NEW len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + 
block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # Q pointers + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # online softmax state + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = total_len + + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] + + # q_gid from packed position_ids (aligned with Q rows) + q_gid = tl.load( + position_ids + cur_batch_in_all_start_index + offs_m, + mask=q_valid, + other=-2147483648, + ).to(tl.int32) + + BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + # map logical pos -> mem_index (for K/V) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + + qk = tl.dot(q, k) + + # k_gid: + # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false + # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len) + k_in_new = k_pos >= prompt_cache_len + k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new + k_gid_new = tl.load( + position_ids + cur_batch_in_all_start_index + k_new_idx, + mask=k_valid & k_in_new, + other=-2147483647, + ).to(tl.int32) + + k_gid = tl.where( + k_in_new, + k_gid_new, + (k_pos.to(tl.int32) + BIG), + ) + + # mask: causal OR same gid (only possible inside NEW part) + mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :]) + mask = mask & q_valid[:, None] & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, +): + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = 
q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + base_head_dim = Lq // 2 + sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634 + + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids_q, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + + # gather K/V for full request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # positions + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + new_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + 
dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=50000, +): + torch.manual_seed(seed) + + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + # packed q start + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # global KV space large, indices not small + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) + + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + e = min(M, s + 3) + gid[s:e] = gid[s] + + position_ids_q[start : start + M] = gid + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + ref = reference_attention( + q, + k, + v, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2)) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) diff --git 
a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py new file mode 100644 index 000000000..955f48bd8 --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -0,0 +1,174 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_thwd_stride0: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + position_ids_stride0: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + BLOCK_SIZE: tl.constexpr, +) -> torch.Tensor: + cur_batch = tl.program_id(0) + cache_len = tl.load(b_ready_cache_len + cur_batch) + q_seq_len = tl.load(b_q_seq_len + cur_batch) + image_num = tl.load(b_image_nums + cur_batch) + image_start_num = tl.load(b_image_start_num + cur_batch) + start_loc = tl.load(b_start_loc + cur_batch) + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_start_idx = start_loc + local_image_start_idx - cache_len + image_len = tl.load(b_image_len + image_start_num + i) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + # video is not handled yet, so the t offset within an image is always 0 + t_pos = local_image_start_idx + off * 0 + h_pos = off // image_w + w_pos = off % image_w + tl.store( + position_ids + off + image_start_idx, + t_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 + off + image_start_idx, + h_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 * 2 + off + image_start_idx, + w_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_len = tl.load(b_image_len + image_start_num + i) + image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) + image_end = local_image_start_idx + image_len - cache_len + text_start = tl.maximum(0, image_end) + for j in range(text_start, q_seq_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len:
torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + ) + + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py new file mode 100644 index 000000000..aa008e18f --- /dev/null +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return 
round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. 
+ """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425..3563739f7 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -30,6 +30,7 @@ from ..models.qwen2_vl.model import QWen2VLTokenizer from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer +from ..models.neo_chat_moe.model import NeoChatTokenizer from ..models.gemma3.model import Gemma3Tokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. @@ -104,5 +105,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "neo_chat": + tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..df5d66bcb 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -19,6 +19,7 @@ from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -78,6 +79,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "neo_chat": + self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now")