diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py
index d8fd93009..b599bedfc 100755
--- a/lightllm/common/kv_cache_mem_manager/mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -28,7 +28,7 @@ class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
self.size = size
self.head_num = head_num
- self.head_dim = head_dim
+ self.head_dim = head_dim * 2 # NEO kv stores [k, k_h, k_w] concatenated per head, so the stored head_dim is doubled
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
@@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
self.size,
dtype,
head_num,
- head_dim,
+ self.head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 4ee02f003..9a51d9512 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -38,4 +38,6 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
+from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel
+from lightllm.models.neo_chat.model import NeoTpPartModel
from .registry import get_model, get_model_class
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index a228e0025..36b5d79b5 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -110,6 +110,8 @@ def _init_custom(self):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
+ if "rope_theta_hw" in self.config:
+ self._init_to_get_hw_rotary()
return
if "rope_type" in rope_scaling:
@@ -132,6 +134,8 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+ if "rope_theta_hw" in self.config:
+ self._init_to_get_hw_rotary()
return
def _init_weights(self):
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000):
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
base = self.config.get("rope_theta", float(default_base))
-
+ print(f"base is {base}")
if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000):
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
+ def _init_to_get_hw_rotary(self, default_base=10000):
+ partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
+ if self.config.get("rope_scaling", {}) is None:
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+ base = self.config.get("rope_theta_hw", float(default_base))
+ print(f"hw_base is {base}")
+ if "max_sequence_length" in self.config:
+ max_seq_len = self.config["max_sequence_length"]
+ else:
+ max_position_embeddings = self.config.get(
+ "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
+ )
+ max_seq_len = max_position_embeddings * rope_scaling_factor
+
+ # NTK
+ try:
+ ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+ base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
+ except Exception:
+ pass
+
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+ )
+ t = (
+ torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+ / rope_scaling_factor
+ )
+ freqs = torch.outer(t, inv_freq)
+
+ self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+ return
+
def _init_to_get_dynamic_ntk_rotary(self):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
index eb5af6fec..02bd277ad 100644
--- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
- sm_scale = 1.0 / (Lk ** 0.5)
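+ # the packed head dim is doubled ([k | k_h | k_w]), so use half of Lk as the effective dim for scaling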
+ Lk_scale = Lk // 2
+ sm_scale = 1.0 / (Lk_scale ** 0.5)
batch, head_num = B_req_idx.shape[0], q.shape[1]
diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..1cf13c413
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
+
+
+class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def _bind_attention(self):
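+ # rebind to the plain NEO kernels defined in this class, overriding the parent's attention bindings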
+ self._context_attention_kernel = self._context_attention_kernel
+ self._token_attention_kernel = self._token_decode_attention_normal
+ self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
+ return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight):
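+ # Layout (per token): q -> [Hq, D] gets the standard rotary embedding; q_h / q_w -> [Hq, D/2] each get the
+ # H / W rotary embeddings; the cached k is stored as [k | k_h | k_w] (2*D per head) and v is zero-padded to
+ # 2*D so both fit the doubled head_dim of the kv mem_manager buffer.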
+ input = input.view(-1, self.embed_dim_)
+ q = layer_weight.q_proj.mm(input) # [T, Hq*D]
+
+ q_hw = layer_weight.q_hw_proj.mm(input)
+ q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_hw = layer_weight.k_hw_proj.mm(input)
+ k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]
+
+ qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)
+
+ q_h_2d = q_h.reshape(q.shape[0], -1)
+ q_w_2d = q_w.reshape(q.shape[0], -1)
+ qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
+ qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
+ q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+
+ qk_rmsnorm_forward(
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ weight=layer_weight.k_norm_weight_.weight,
+ eps=self.eps_,
+ )
+
+ k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
+ k_w_2d = k_w.reshape(q.shape[0], -1)
+ qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
+ qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
+ k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
+ q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q3 = torch.cat([q3, q_h, q_w], dim=-1)
+ q = q3.reshape(q3.shape[0], -1)
+
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ k = torch.cat([k, k_h, k_w], dim=-1)
+
+ v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+ v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
+ v = torch.cat([v, v_pad], dim=-1)
+
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_len_in_batch,
+ infer_state.req_manager.req_to_token_indexs,
+ )
+ o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
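+ # the upper head_dim dims correspond to the zero-padded half of v, so only the first head_dim dims are kept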
+ o3 = o3[:, :, : self.head_dim_].contiguous()
+ return o3.view(o3.shape[0], -1)
+
+ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
+ total_token_num = infer_state.total_token_num
+ batch_size = infer_state.batch_size
+
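+ # q and cached k both pack [base | h | w] per head (2*D); token_att_fwd compensates by scaling with Lk // 2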
+ q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)
+
+ att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)
+
+ k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+ token_att_fwd(
+ q_3d,
+ k_3d,
+ att_m_tensor,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+ from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd
+
+ v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
+ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
+ ]
+
+ o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out
+
+ token_softmax_reducev_fwd(
+ att_m_tensor,
+ v_3d,
+ o_3d,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ )
+ return o_3d.view(batch_size, -1)
diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..c1f0638ac
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# rename keys: language_model.xxx -> xxx
+# only the keys loaded by PreAndPostLayerWeight need renaming; TransformerLayerWeight keys are already correct
+def rename_weight_keys(weights):
+ prefix = "language_model."
+ keys = list(weights.keys())
+ for k in keys:
+ if prefix in k:
+ weights[k.replace(prefix, "")] = weights.pop(k)
+
+
+class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
new file mode 100644
index 000000000..e5e769a76
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,51 @@
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ NormWeight,
+ ROWMMWeight,
+)
+
+
+class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight"
+ self._q_bias_hw_name = None
+ self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight"
+ self._k_bias_hw_name = None
+
+ self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight"
+ self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight"
+
+ self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight"
+ self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight"
+
+ def _init_qkv(self):
+ super()._init_qkv()
+ self.q_hw_proj = ROWMMWeight(
+ weight_names=self._q_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._q_bias_hw_name,
+ quant_cfg=self.quant_cfg,
+ layer_num=self.layer_num_,
+ name="q_hw_proj",
+ )
+ self.k_hw_proj = ROWMMWeight(
+ weight_names=self._k_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._k_bias_hw_name,
+ quant_cfg=self.quant_cfg,
+ layer_num=self.layer_num_,
+ name="k_hw_proj",
+ )
+
+ def _init_norm(self):
+ super()._init_norm()
+
+ self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_)
+ self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_)
+ self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_)
+ self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_)
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
new file mode 100644
index 000000000..14d9f96dc
--- /dev/null
+++ b/lightllm/models/neo_chat/model.py
@@ -0,0 +1,53 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3"))
+class NeoTpPartModel(Qwen3TpPartModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatTransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatPreAndPostLayerWeight
+ transformer_weight_class = NeoChatTransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
new file mode 100644
index 000000000..0c7d9372e
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -0,0 +1,99 @@
+from typing import Optional, List
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.req_manager import ReqManager
+from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton
+from lightllm.models.llama.model import LlamaTpPartModel
+
+
+class NeoChatInferStateInfo(LlamaInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.position_cos = None
+ self.position_sin = None
+ self.position_cos_h = None
+ self.position_sin_h = None
+ self.position_cos_w = None
+ self.position_sin_w = None
+
+ def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor):
+ LlamaInferStateInfo.init_some_extra_state(self, model, input_ids)
+ if self.is_prefill:
+ self.position_ids = self.get_neo_position(self.multimodal_params)
+ else:
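+ # decode step: shift each request's position by the accumulated per-image delta (grid_thwd[3] == 1 - image token_num)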
+ b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
+ for batch_idx, p in enumerate(self.multimodal_params):
+ position_delta = 0
+ for image in p["images"]:
+ position_delta += image["grid_thwd"][3]
+ b_position_delta[batch_idx] = position_delta
+ position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
+ self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
+ self.position_ids[1:].zero_()
+
+ self.position_ids = self.position_ids.contiguous()
+ self.position_cos = model._cos_cached[self.position_ids[0]]
+ self.position_sin = model._sin_cached[self.position_ids[0]]
+ self.position_cos_h = model._hw_cos_cached[self.position_ids[1]]
+ self.position_sin_h = model._hw_sin_cached[self.position_ids[1]]
+ self.position_cos_w = model._hw_cos_cached[self.position_ids[2]]
+ self.position_sin_w = model._hw_sin_cached[self.position_ids[2]]
+ return
+
+ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
+ if len(multimodal_params) == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids
+ b_image_start_idx = []
+ b_image_nums = []
+ b_image_start_num = []
+ b_image_len = []
+ image_start_num = 0
+ b_image_thwd = []
+
+ # pad multimodal_params to batch size.
+ batch_size = self.b_q_seq_len.shape[0]
+ multimodal_params = multimodal_params + [
+ {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
+ ]
+
+ for _, p in enumerate(multimodal_params):
+ images = p.get("images", [])
+ for img in images:
+ b_image_start_idx.append(img["start_idx"])
+ b_image_len.append(img["token_num"])
+ b_image_thwd.append(img["grid_thwd"])
+ b_image_nums.append(len(images))
+ b_image_start_num.append(image_start_num)
+ image_start_num += len(images)
+
+ # no images in this batch
+ if image_start_num == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids.contiguous()
+ b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+ b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4
+ b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+ b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+ b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
+
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+
+ get_neo_position_triton(
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ b_ready_cache_len=self.b_ready_cache_len,
+ b_q_seq_len=self.b_q_seq_len,
+ b_start_loc=self.b_start_loc,
+ )
+ return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..ed48a9c6f
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
+
+
+class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def _bind_attention(self):
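+ # rebind to the plain NEO kernels defined in this class, overriding the parent's attention bindings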
+ self._context_attention_kernel = self._context_attention_kernel
+ self._token_attention_kernel = self._token_decode_attention_normal
+ self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
+ return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight):
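+ # Layout (per token): q -> [Hq, D] gets the standard rotary embedding; q_h / q_w -> [Hq, D/2] each get the
+ # H / W rotary embeddings; the cached k is stored as [k | k_h | k_w] (2*D per head) and v is zero-padded to
+ # 2*D so both fit the doubled head_dim of the kv mem_manager buffer.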
+ input = input.view(-1, self.embed_dim_)
+ q = layer_weight.q_proj.mm(input) # [T, Hq*D]
+
+ q_hw = layer_weight.q_hw_proj.mm(input)
+ q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_hw = layer_weight.k_hw_proj.mm(input)
+ k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]
+
+ qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)
+
+ q_h_2d = q_h.reshape(q.shape[0], -1)
+ q_w_2d = q_w.reshape(q.shape[0], -1)
+ qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
+ qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
+ q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+
+ qk_rmsnorm_forward(
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ weight=layer_weight.k_norm_weight_.weight,
+ eps=self.eps_,
+ )
+
+ k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
+ k_w_2d = k_w.reshape(q.shape[0], -1)
+ qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
+ qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
+ k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
+ q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ q3 = torch.cat([q3, q_h, q_w], dim=-1)
+ q = q3.reshape(q3.shape[0], -1)
+
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ k = torch.cat([k, k_h, k_w], dim=-1)
+
+ v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
+ v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
+ v = torch.cat([v, v_pad], dim=-1)
+
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_len_in_batch,
+ infer_state.req_manager.req_to_token_indexs,
+ )
+ o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
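+ # the upper head_dim dims correspond to the zero-padded half of v, so only the first head_dim dims are kept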
+ o3 = o3[:, :, : self.head_dim_].contiguous()
+ return o3.view(o3.shape[0], -1)
+
+ def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
+ total_token_num = infer_state.total_token_num
+ batch_size = infer_state.batch_size
+
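+ # q and cached k both pack [base | h | w] per head (2*D); token_att_fwd compensates by scaling with Lk // 2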
+ q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)
+
+ att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)
+
+ k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+ token_att_fwd(
+ q_3d,
+ k_3d,
+ att_m_tensor,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+ from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd
+
+ v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
+ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
+ ]
+
+ o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out
+
+ token_softmax_reducev_fwd(
+ att_m_tensor,
+ v_3d,
+ o_3d,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_req_idx,
+ infer_state.b_start_loc,
+ infer_state.b_seq_len,
+ )
+ return o_3d.view(batch_size, -1)
diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..7766a5d29
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# rename keys: language_model.xxx -> xxx
+# only the keys loaded by PreAndPostLayerWeight need renaming; TransformerLayerWeight keys are already correct
+def rename_weight_keys(weights):
+ prefix = "language_model."
+ keys = list(weights.keys())
+ for k in keys:
+ if prefix in k:
+ weights[k.replace(prefix, "")] = weights.pop(k)
+
+
+class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
new file mode 100644
index 000000000..bc38f1adc
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,51 @@
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ NormWeight,
+ ROWMMWeight,
+)
+
+
+class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_proj_hw.weight"
+ self._q_bias_hw_name = None
+ self._k_weight_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_proj_hw.weight"
+ self._k_bias_hw_name = None
+
+ self._q_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_h.weight"
+ self._q_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_w.weight"
+
+ self._k_norm_h_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_h.weight"
+ self._k_norm_w_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_w.weight"
+
+ def _init_qkv(self):
+ super()._init_qkv()
+ self.q_hw_proj = ROWMMWeight(
+ weight_names=self._q_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._q_bias_hw_name,
+ quant_cfg=self.quant_cfg,
+ layer_num=self.layer_num_,
+ name="q_hw_proj",
+ )
+ self.k_hw_proj = ROWMMWeight(
+ weight_names=self._k_weight_hw_name,
+ data_type=self.data_type_,
+ bias_names=self._k_bias_hw_name,
+ quant_cfg=self.quant_cfg,
+ layer_num=self.layer_num_,
+ name="k_hw_proj",
+ )
+
+ def _init_norm(self):
+ super()._init_norm()
+
+ self.q_norm_h_weight_ = NormWeight(weight_name=self._q_norm_h_name, data_type=self.data_type_)
+ self.q_norm_w_weight_ = NormWeight(weight_name=self._q_norm_w_name, data_type=self.data_type_)
+ self.k_norm_h_weight_ = NormWeight(weight_name=self._k_norm_h_name, data_type=self.data_type_)
+ self.k_norm_w_weight_ = NormWeight(weight_name=self._k_norm_w_name, data_type=self.data_type_)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
new file mode 100644
index 000000000..e4123d109
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -0,0 +1,139 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+# multimodal special tags (InternVL-style angle-bracket tokens)
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
+IMG_TOKEN = "<image>"
+AUDIO_START_TOKEN = "<audio>"
+
+
+class NeoChatTokenizer(BaseMultiModalTokenizer):
+ def __init__(self, tokenizer, model_cfg, **kwargs):
+ super().__init__(tokenizer)
+ self.tokenizer = tokenizer
+ self.min_pixel = model_cfg.get("vision_config").get("min_pixels")
+ self.max_pixel = model_cfg.get("vision_config").get("max_pixels")
+ self.patch_size = model_cfg.get("vision_config").get("patch_size")
+ self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio")
+
+ self.image_token_id = model_cfg.get("image_token_id")
+ self.image_start_tag = IMG_START_TOKEN
+ self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+ self.image_end_tag = IMG_END_TOKEN
+ self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+ def init_imageitem_extral_params(
+ self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ return
+
+ def init_audioitem_extral_params(
+ self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ raise NotImplementedError
+
+ def get_audio_token_length(self, audio: AudioItem):
+ raise NotImplementedError
+
+ def get_image_token_length(self, img: ImageItem):
+ width, height = img.image_w, img.image_h
+ resized_height, resized_width = smart_resize(
+ height=height,
+ width=width,
+ factor=int(self.patch_size // self.downsample_ratio),
+ min_pixels=self.min_pixel,
+ max_pixels=self.max_pixel,
+ )
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2))
+ # TODO: double-check whether grid_h and grid_w here should be multiplied by self.downsample_ratio
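+ # grid_thwd = (t, h, w, delta) with delta = 1 - token_num; the delta is used at decode time to shift position ids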
+ img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num)
+ return token_num
+
+ # only change the impl of the encode func:
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+ # TEXT<image>TEXT<image>TEXT --> TEXT<img></img>TEXT<img></img>TEXT
+ image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+ if multimodal_params is None:
+ add_special_tokens = kwargs.get("add_special_tokens", True)
+ return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+ image_count = len(multimodal_params.images)
+ prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+
+ origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"])
+ # <img></img> --> <img>id,id+1...id+num</img>
+ input_ids = []
+ image_id = 0
+ start_idx = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ multimodal_params.images[image_id].start_idx = len(input_ids)
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ input_ids.extend(origin_ids)
+ return input_ids
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class NeoTpMOEPartModel(Qwen3MOEModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatMOETransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
+ transformer_weight_class = NeoChatMOETransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py
new file mode 100644
index 000000000..852ddc095
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/neo_visual.py
@@ -0,0 +1,279 @@
+import os
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import List
+from io import BytesIO
+import torch.nn as nn
+from transformers.activations import ACT2FN
+from safetensors import safe_open
+from lightllm.server.multimodal_params import ImageItem
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from lightllm.models.neo_chat_moe.vision_process import load_image_native
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+
+
+def apply_rotary_emb_1d(
+ x: torch.Tensor,
+ cos_cached: torch.Tensor,
+ sin_cached: torch.Tensor,
+ positions: torch.Tensor,
+):
+ """对输入张量的一部分应用1D RoPE。"""
+ # x: (..., seq_len, dim_part)
+ # positions: (..., seq_len)
+ # cos_cached: (max_pos, dim_part / 2)
+ cos_cached = cos_cached.to(device=positions.device)
+ sin_cached = sin_cached.to(device=positions.device)
+
+ cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
+
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+
+ rotated_x1 = x1 * cos - x2 * sin
+ rotated_x2 = x1 * sin + x2 * cos
+
+ x_rotated = torch.empty_like(x)
+ x_rotated[..., 0::2] = rotated_x1
+ x_rotated[..., 1::2] = rotated_x2
+ return x_rotated
+
+
+def apply_2d_rotary_pos_emb(
+ x: torch.Tensor,
+ cos_cached_x: torch.Tensor,
+ sin_cached_x: torch.Tensor,
+ cos_cached_y: torch.Tensor,
+ sin_cached_y: torch.Tensor,
+ abs_positions_x: torch.Tensor,
+ abs_positions_y: torch.Tensor,
+):
+ """应用2D RoPE到输入张量x。"""
+ dim = x.shape[-1]
+ dim_half = dim // 2
+
+ # Split the embedding in half: one half gets RoPE along one axis, the other half along the other axis,
+ # e.g. the first half for the X coordinate and the second half for Y (or the reverse, as long as it is consistent).
+ x_part_1 = x[..., :dim_half]
+ x_part_2 = x[..., dim_half:]
+
+ # apply the rotation driven by abs_positions_x to x_part_1
+ rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
+ # apply the rotation driven by abs_positions_y to x_part_2
+ rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
+
+ # concatenate the halves back together, in the same order as they were split
+ return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+ """
+ Compute patch coordinates (x, y)
+
+ Args:
+ grid_hw: (B, 2) tensor representing (H, W) per image
+ """
+ device = grid_hw.device
+ B = grid_hw.shape[0]
+
+ # Get the number of patches per image
+ H = grid_hw[:, 0]
+ W = grid_hw[:, 1]
+ N = H * W
+ N_total = N.sum()
+
+ # Create the batch index for each patch (B x patch count)
+ patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,)
+
+ # Generate intra-image patch index (row-major order)
+ patch_id_within_image = torch.arange(N_total, device=device)
+ patch_id_within_image = (
+ patch_id_within_image
+ - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+ )
+
+ # Get H/W for each patch according to its image
+ W_per_patch = W[patch_to_sample]
+ abs_x = patch_id_within_image % W_per_patch
+ abs_y = patch_id_within_image // W_per_patch
+
+ return abs_x, abs_y
+
+
+class NeoVisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ hidden_size: int = 1024,
+ llm_hidden_size: int = 2048,
+ downsample_ratio: float = 0.5,
+ patch_size: int = 16,
+ num_channels: int = 3,
+ max_position_embeddings_vision: int = 10000,
+ rope_theta_vision: float = 10000.0,
+ min_pixels: int = 65536,
+ max_pixels: int = 2408448,
+ **kwargs,
+ ):
+ super().__init__()
+ self.weight_dir = kvargs["weight_dir"]
+ self.data_type = kvargs.get("data_type", "bfloat16")
+ self.embed_dim = hidden_size
+ self.llm_hidden_size = llm_hidden_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.downsample_ratio = downsample_ratio
+ self.downsample_factor = int(1 / downsample_ratio)
+ self.max_position_embeddings_vision = max_position_embeddings_vision
+ self.rope_theta_vision = rope_theta_vision
+ self.rope_dim_part = self.embed_dim // 2
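+ # per-axis rotary dim: half of the embedding carries the X rotation, the other half the Y rotation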
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ self.dense_embedding = nn.Conv2d(
+ in_channels=self.embed_dim,
+ out_channels=self.llm_hidden_size,
+ kernel_size=self.downsample_factor,
+ stride=self.downsample_factor,
+ )
+ self.gelu = nn.GELU()
+
+ self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos()
+ self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos()
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+ raise ValueError(f"Unsupport datatype {self.data_type}!")
+ return
+
+ def load_model(self, weight_dir):
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "vision_model" in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "vision_model" in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k)
+ self.load_state_dict(weight_dict)
+
+ def precompute_rope_freqs_sincos(self):
+ inv_freq = 1.0 / (
+ self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part)
+ )
+ t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq)
+ freqs = torch.outer(t, inv_freq)
+ return torch.cos(freqs), torch.sin(freqs)
+
+ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
+ """
+ Apply 2D Rotary Position Embedding to the patch embeddings.
+ """
+ abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
+ embeddings = apply_2d_rotary_pos_emb(
+ patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
+ self.cos_x,
+ self.sin_x,
+ self.cos_y,
+ self.sin_y,
+ abs_pos_x,
+ abs_pos_y,
+ ).to(self.patch_embedding.weight.dtype)
+ return embeddings
+
+ def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
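+ # pixel_values arrive flattened as N patches of shape (3, patch_size, patch_size)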
+ pixel_values = pixel_values.view(
+ -1,
+ 3,
+ self.patch_size,
+ self.patch_size,
+ )
+ patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
+ patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)
+ assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[
+ 0
+ ], "Grid size and patch embeds size mismatch."
+
+ patches_list = []
+ cur_position = 0
+ for i in range(grid_hw.shape[0]):
+ h, w = grid_hw[i]
+ patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0)
+ patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
+ patches_per_img = patches_per_img.permute(0, 2, 3, 1)
+ patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
+ cur_position += h * w
+
+ embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C)
+ assert cur_position == patch_embeds.shape[0]
+ assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2)
+
+ return embeddings
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ pixel_values, image_grid_hw = load_image_native(
+ image_data,
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio,
+ min_pixels=self.min_pixels,
+ max_pixels=self.max_pixels,
+ )
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_hw)
+ else:
+ raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+
+ # the per-image patch count must be divisible by the merge length ((1 / downsample_ratio) ** 2)
+ cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2))
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_hw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_hw = grid_hw.to("cuda", non_blocking=True)
+
+ all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw)
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
new file mode 100644
index 000000000..f5dae493c
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -0,0 +1,452 @@
+import math
+import torch
+import triton
+import triton.language as tl
+
+from lightllm.utils.device_utils import is_tesla
+
+
+@triton.jit
+def _fwd_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0]
+ B_Start_Loc,
+ B_Seqlen,
+ Req_to_tokens,
+ B_req_idx,
+ stride_qbs,
+ stride_qh,
+ stride_qd,
+ stride_kbs,
+ stride_kh,
+ stride_kd,
+ stride_vbs,
+ stride_vh,
+ stride_vd,
+ stride_obs,
+ stride_oh,
+ stride_od,
+ stride_req_to_tokens_b,
+ stride_req_to_tokens_s,
+ kv_group_num,
+ b_prompt_cache_len,
+ H: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ start_m = tl.program_id(0)
+ cur_bh = tl.program_id(1)
+ cur_batch = cur_bh // H
+ cur_head = cur_bh % H
+
+ cur_kv_head = cur_head // kv_group_num
+
+ cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+ prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
+ total_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_seq_len = total_len - prompt_cache_len # NEW len
+ cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+ block_start_loc = BLOCK_M * start_m
+
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = block_start_loc + tl.arange(0, BLOCK_M)
+
+ # Q pointers
+ off_q = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ + cur_head * stride_qh
+ + offs_d[None, :] * stride_qd
+ )
+
+ q_valid = offs_m < cur_batch_seq_len
+ q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0)
+
+ # online softmax state
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+ block_end_loc = total_len
+
+ # absolute q positions in the request
+ q_pos = prompt_cache_len + offs_m # [M]
+
+ # q_gid from packed position_ids (aligned with Q rows)
+ q_gid = tl.load(
+ position_ids + cur_batch_in_all_start_index + offs_m,
+ mask=q_valid,
+ other=-2147483648,
+ ).to(tl.int32)
+
+ BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid
+
+ for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+
+ k_pos = start_n + offs_n # [N]
+ k_valid = k_pos < block_end_loc
+
+ # map logical pos -> mem_index (for K/V)
+ kv_loc = tl.load(
+ Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos,
+ mask=k_valid,
+ other=0,
+ ).to(tl.int64)
+
+ # load K
+ off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
+ k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0)
+
+ qk = tl.dot(q, k)
+
+ # k_gid:
+ # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false
+ # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len)
+ k_in_new = k_pos >= prompt_cache_len
+ k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new
+ k_gid_new = tl.load(
+ position_ids + cur_batch_in_all_start_index + k_new_idx,
+ mask=k_valid & k_in_new,
+ other=-2147483647,
+ ).to(tl.int32)
+
+ k_gid = tl.where(
+ k_in_new,
+ k_gid_new,
+ (k_pos.to(tl.int32) + BIG),
+ )
+
+ # mask: causal OR same gid (only possible inside NEW part)
+ mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :])
+ mask = mask & q_valid[:, None] & k_valid[None, :]
+
+ qk = tl.where(mask, qk * sm_scale, -1.0e8)
+
+ # online softmax
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
+ qk -= m_ij[:, None]
+ p = tl.math.exp2(qk)
+ l_ij = tl.sum(p, 1)
+
+ alpha = tl.math.exp2(m_i - m_ij)
+ l_i = l_i * alpha + l_ij
+ acc = acc * alpha[:, None]
+
+ # load V
+ off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+ v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0)
+
+ p = p.to(v.dtype)
+ acc = tl.dot(p, v, acc)
+
+ m_i = m_ij
+
+ acc = acc / l_i[:, None]
+
+ off_o = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ + cur_head * stride_oh
+ + offs_d[None, :] * stride_od
+ )
+ tl.store(Out + off_o, acc, mask=q_valid[:, None])
+
+
+@torch.no_grad()
+def context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_input_len,
+ req_to_token_indexs,
+):
+ # minimal safety: position_ids must cover packed q rows
+ assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
+
+ BLOCK_M = 128 if not is_tesla() else 64
+
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk and Lk == Lv
+ assert Lk in {16, 32, 64, 128, 256}
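+ # q/k pack [base | h | w] (2x the base head dim); scale by sqrt of the base dim and fold in log2(e) for the exp2-based softmax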
+ base_head_dim = Lq // 2
+ sm_scale = 1.0 / (base_head_dim ** 0.5) * 1.4426950408889634
+
+ batch, head = b_seq_len.shape[0], q.shape[1]
+ kv_group_num = q.shape[1] // k.shape[1]
+
+ grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1)
+
+ BLOCK_N = BLOCK_M
+ num_warps = 4 if Lk <= 64 else 8
+ num_stages = 1
+
+ _fwd_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ o,
+ position_ids,
+ b_start_loc,
+ b_seq_len,
+ req_to_token_indexs,
+ b_req_idx,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ req_to_token_indexs.stride(0),
+ req_to_token_indexs.stride(1),
+ kv_group_num=kv_group_num,
+ b_prompt_cache_len=b_prompt_cache_len,
+ H=head,
+ BLOCK_DMODEL=Lk,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+
+def reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+):
+ device = q.device
+ dtype = q.dtype
+ sum_q, Hq, D = q.shape
+ Hk = k.shape[1]
+ kv_group_num = Hq // Hk
+
+ batch = b_seq_len.shape[0]
+ out = torch.empty_like(q)
+ scale = 1.0 / math.sqrt(D // 2)  # match the kernel: scale by the base (un-doubled) head dim
+
+ for b in range(batch):
+ req = int(b_req_idx[b].item())
+ total_len = int(b_seq_len[b].item())
+ prompt_len = int(b_prompt_cache_len[b].item())
+ new_len = total_len - prompt_len
+
+ q_start = int(b_start_loc[b].item())
+ q_blk = q[q_start : q_start + new_len] # [M, Hq, D]
+ gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M]
+
+ # gather K/V for full request by logical pos -> mem_index
+ token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L]
+ k_blk = k[token_locs] # [L, Hk, D]
+ v_blk = v[token_locs] # [L, Hk, D]
+
+ # expand kv heads to q heads (GQA)
+ k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+ v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+
+ # positions
+ q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M]
+ k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L]
+
+ # build allow mask:
+ # causal always
+ allow = k_pos[None, :] <= q_pos[:, None]
+
+ # full-attn only inside NEW part by gid
+ # compare only when k_pos in NEW
+ k_in_new = k_pos >= prompt_len
+ k_rel = (k_pos - prompt_len).clamp_min(0) # [L]
+ # map k_rel to gid_new, but only valid where k_in_new
+ k_gid = torch.empty((total_len,), device=device, dtype=torch.int64)
+ k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new
+ k_gid[k_in_new] = gid_new[k_rel[k_in_new]]
+
+ allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
+
+ # scores: [Hq, M, L]
+ q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D]
+ k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L]
+ scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L]
+
+ neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32)
+ scores = torch.where(allow[None, :, :], scores, neg)
+
+ p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L]
+ v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D]
+ out_hq = torch.matmul(p, v_t) # [Hq, M, D]
+ out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D]
+
+ out[q_start : q_start + new_len] = out_blk
+
+ return out
+
+
+def make_test_case(
+ device="cuda",
+ dtype=torch.float16,
+ batch=3,
+ Hq=8,
+ Hk=4,
+ D=64,
+ seed=0,
+ base_index=50000,
+):
+ torch.manual_seed(seed)
+
+ # prompt (cached) len and new len
+ prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device)
+ new_lens = torch.randint(low=1, high=8, size=(batch,), device=device)
+ total_lens = (prompt_lens + new_lens).to(torch.int32)
+
+ max_total_len = int(total_lens.max().item())
+ max_new_len = int(new_lens.max().item())
+
+ # packed q start
+ b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32)
+ cur = 0
+ for b in range(batch):
+ b_start_loc[b] = cur
+ cur += int(new_lens[b].item())
+ sum_q = cur
+
+ b_seq_len = total_lens
+ b_prompt_cache_len = prompt_lens.to(torch.int32)
+
+ # one req per batch
+ num_req = batch
+ b_req_idx = torch.arange(batch, device=device, dtype=torch.int32)
+
+    # use a large global KV pool (offset by base_index) so mem indices are not trivially small
+ sum_kv = int(total_lens.sum().item())
+ kv_size = base_index + sum_kv + 1024
+ pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index
+
+ # Req_to_tokens [num_req, max_total_len]
+ req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32)
+ p = 0
+ for r in range(num_req):
+ L = int(total_lens[r].item())
+ req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32)
+ p += L
+
+ # position_ids_q: only NEW tokens, packed like q
+ position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32)
+ for b in range(batch):
+ M = int(new_lens[b].item())
+ start = int(b_start_loc[b].item())
+
+ gid = torch.arange(M, device=device, dtype=torch.int32)
+
+ # make one repeated block inside NEW part to simulate image tokens
+ if M >= 4 and torch.rand((), device=device).item() > 0.3:
+ s = int(torch.randint(0, M - 2, (1,), device=device).item())
+ e = min(M, s + 3)
+ gid[s:e] = gid[s]
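+            # e.g. with M=6 and s=2 this turns gid=[0,1,2,3,4,5] into [0,1,2,2,2,5],
+            # so three consecutive NEW tokens share one group id like an image block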
+
+ position_ids_q[start : start + M] = gid
+
+ q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype)
+ k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype)
+
+ return (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+
+def check_once(device="cuda", dtype=torch.float16, seed=0):
+ (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ ) = make_test_case(device=device, dtype=dtype, seed=seed)
+
+ context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+ ref = reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+ )
+
+ diff = (o - ref).abs()
+ max_abs = diff.max().item()
+ denom = ref.abs().max().item() + 1e-6
+ max_rel = max_abs / denom
+
+ print(f"seed={seed}, dtype={dtype}")
+ print(f"max_abs_error = {max_abs:.6e}")
+ print(f"max_rel_error = {max_rel:.6e}")
+ print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2))
+
+
+if __name__ == "__main__":
+ if not torch.cuda.is_available():
+ print("No CUDA, skip.")
+ else:
+ torch.cuda.synchronize()
+ check_once(dtype=torch.bfloat16, seed=0)
+ check_once(dtype=torch.bfloat16, seed=1)
+ check_once(dtype=torch.bfloat16, seed=2)
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
new file mode 100644
index 000000000..955f48bd8
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -0,0 +1,174 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_thwd_stride0: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ position_ids_stride0: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ BLOCK_SIZE: tl.constexpr,
+):
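+    # One program handles one batch element. First pass: for every image of this
+    # request, write the three position rows for the image tokens that fall inside
+    # the current chunk: t = the image's start index, h = patch row (off // image_w),
+    # w = patch column (off % image_w). Second pass: for every token after an image's
+    # end, add that image's delta (b_image_thwd[:, 3]) to the t row; the h/w rows
+    # are reloaded and stored back unchanged.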
+ cur_batch = tl.program_id(0)
+ cache_len = tl.load(b_ready_cache_len + cur_batch)
+ q_seq_len = tl.load(b_q_seq_len + cur_batch)
+ image_num = tl.load(b_image_nums + cur_batch)
+ image_start_num = tl.load(b_image_start_num + cur_batch)
+ start_loc = tl.load(b_start_loc + cur_batch)
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_start_idx = start_loc + local_image_start_idx - cache_len
+ image_len = tl.load(b_image_len + image_start_num + i)
+ # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+ image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+ for j in range(0, image_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+            # Video is not supported yet, so the temporal index t is always 0;
+            # off * 0 just broadcasts local_image_start_idx over the block.
+ t_pos = local_image_start_idx + off * 0
+ h_pos = off // image_w
+ w_pos = off % image_w
+ tl.store(
+ position_ids + off + image_start_idx,
+ t_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 + off + image_start_idx,
+ h_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 * 2 + off + image_start_idx,
+ w_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_len = tl.load(b_image_len + image_start_num + i)
+ image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
+ image_end = local_image_start_idx + image_len - cache_len
+ text_start = tl.maximum(0, image_end)
+ for j in range(text_start, q_seq_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+ t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta
+ h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0)
+ w_pos = tl.load(
+ position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0
+ )
+ tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len))
+ return
+
+
+def get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+) -> None:
+
+ batch_size = b_q_seq_len.shape[0]
+ assert batch_size == b_image_nums.shape[0]
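+    # position_ids is laid out as three rows (t, h, w) addressed via its stride(0);
+    # the kernel fills rows 0/1/2 for image tokens and shifts row 0 for the tokens
+    # that follow each image.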
+ grid = (batch_size,)
+ BLOCK_SIZE = 64
+ _get_neo_position_triton[grid](
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_thwd_stride0=b_image_thwd.stride(0),
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ position_ids_stride0=position_ids.stride(0),
+ b_ready_cache_len=b_ready_cache_len,
+ b_q_seq_len=b_q_seq_len,
+ b_start_loc=b_start_loc,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+
+def test():
+ b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda")
+ b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda")
+ b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
+ b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
+ b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda")
+ position_ids = (
+ torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ .unsqueeze(0)
+ .expand(3, -1)
+ .contiguous()
+ )
+ position_ids[1:].zero_()
+ b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
+ b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
+ b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
+ get_neo_position_triton(
+ b_image_start_idx,
+ b_image_thwd,
+ b_image_nums,
+ b_image_start_num,
+ b_image_len,
+ position_ids,
+ b_ready_cache_len,
+ b_q_seq_len,
+ b_start_loc,
+ )
+
+ print(position_ids)
+ # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
+
+ # position_ids = (
+ # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ # .unsqueeze(0)
+ # .expand(3, -1)
+ # .contiguous()
+ # )
+ # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda")
+ # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
+ # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
+
+ # get_neo_position_triton(
+ # b_image_start_idx,
+ # b_image_thwd,
+ # b_image_nums,
+ # b_image_start_num,
+ # b_image_len,
+ # position_ids,
+ # b_ready_cache_len,
+ # b_q_seq_len,
+ # b_start_loc,
+ # )
+
+ # print(f"old_value:\n{old_value}")
+ # print(f"position_ids:\n{position_ids}")
+ # assert torch.equal(old_value, position_ids)
+
+ """
+ tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8],
+ [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8],
+ [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
+ device='cuda:0', dtype=torch.int32)
+ """
diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py
new file mode 100644
index 000000000..aa008e18f
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/vision_process.py
@@ -0,0 +1,141 @@
+import math
+import torch
+import numpy as np
+from PIL import Image
+import torchvision.transforms as T
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60
+def smart_resize(
+ height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304
+) -> tuple[int, int]:
+ """
+ Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
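+
+    For example, with the default factor (32) and pixel bounds, smart_resize(1000, 700)
+    returns (992, 704): each side is rounded to the nearest multiple of 32 and the
+    resulting pixel count already lies within [min_pixels, max_pixels].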
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, floor_by_factor(height / beta, factor))
+ w_bar = max(factor, floor_by_factor(width / beta, factor))
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ return h_bar, w_bar
+
+
+def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs):
+ width, height = image.size
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=size_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+
+def preprocess_pixel_values(pixel_values, patch_size=16):
+ c, h, w = pixel_values.shape
+ grid_h = h // patch_size
+ grid_w = w // patch_size
+
+ flatten_pixel_values = (
+ pixel_values.view(c, grid_h, patch_size, grid_w, patch_size)
+ .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size]
+ .reshape(grid_h * grid_w, c * patch_size ** 2)
+ )
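+    # e.g. a 3 x 512 x 768 input with patch_size=16 gives grid_h=32, grid_w=48, so
+    # flatten_pixel_values has shape [1536, 768]: 32*48 patches of 3*16*16 values each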
+
+ grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device)
+
+ return flatten_pixel_values, grid_hw
+
+
+def get_contrasting_background(image):
+ """
+    Pick a background color (black or white) that contrasts with the average
+    color of the non-transparent (foreground) pixels.
+ """
+ image_np = np.array(image)
+ if (image_np[:, :, 3] == 0).any():
+ non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0]
+ if non_transparent_pixels.size == 0:
+ return None
+        pixel_mean = non_transparent_pixels.mean()
+        # pick black for bright foregrounds and white for dark ones; the per-channel
+        # mean lies in [0, 255], so compare against the midpoint 127.5
+        contrasting_color = (0, 0, 0) if pixel_mean > 127.5 else (255, 255, 255)
+ return contrasting_color
+ else:
+ return None
+
+
+def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False):
+ """
+    Preprocess a PIL image: convert it to RGB, resize it to a factor-aligned
+    resolution, normalize it, and flatten it into per-patch pixel values
+    (optionally upscaling 2x first).
+ """
+ if image.mode == "RGBA":
+ bg_color = get_contrasting_background(image)
+ if bg_color:
+ background = Image.new("RGB", image.size, bg_color)
+ background.paste(image, mask=image.split()[3])
+ image = background.convert("RGB")
+ else:
+ image = image.convert("RGB")
+ else:
+ image = image.convert("RGB")
+
+ if upscale:
+ image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
+
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(),
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+ ]
+ )
+
+ new_image = dynamic_preprocess_native_resolution(
+ image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size)
+
+ print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
+
+ return pixel_values, grid_hw
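+
+
+# Minimal usage sketch (illustrative; "demo.jpg" is a placeholder path):
+#   pixel_values, grid_hw = load_image_native(Image.open("demo.jpg"))
+#   # pixel_values: [grid_h * grid_w, 3 * patch_size**2]; grid_hw: tensor([[grid_h, grid_w]])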
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e0b2bd425..3563739f7 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -30,6 +30,7 @@
from ..models.qwen2_vl.model import QWen2VLTokenizer
from ..models.qwen3_vl.model import QWen3VLTokenizer
from ..models.internvl.model import InternvlTokenizer
+from ..models.neo_chat_moe.model import NeoChatTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
@@ -104,5 +105,7 @@ def get_tokenizer(
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
+ elif model_type == "neo_chat":
+ tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
return tokenizer
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index d3d1610f3..df5d66bcb 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -19,6 +19,7 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
+from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
@@ -78,6 +79,8 @@ def exposed_init_model(self, kvargs):
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
+ elif self.model_type == "neo_chat":
+ self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
else:
raise Exception(f"can not support {self.model_type} now")