From a6b233663a0bf7c49885ed89584906cebbe69545 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Sat, 18 Apr 2026 11:36:29 +0000 Subject: [PATCH 01/12] [Ascend] MTP speculative decoding support (qwen3_5_mtp_final_2) - ascend_cudagraph.py: multi-token decode graph mode support (4-tuple graph key with query_len, actual_seq_lengths_q buffers) - device/__init__.py: add patch_attention_is_tp (draft model TP), patch_ray_init (NPU Ray resource), MTP multi-token paths in GatedDelta conv1d and sigmoid_gating update kernels Co-Authored-By: Claude Opus 4.6 --- .../cudagraph/ascend_cudagraph.py | 114 ++++++++--- .../framework/lmdeploy_ext/device/__init__.py | 189 +++++++++++++++++- 2 files changed, 274 insertions(+), 29 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 314e4b0c..f84686c7 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -54,28 +54,38 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (1, max_tokens), dtype=torch.int32, device=device ) + # multi-token decode expands block_offsets per-token, so size by + # max_tokens; for single-token decode max_tokens == max_batches so + # this is backward-compatible. input_buffers["block_offsets"] = torch.zeros( - (max_batches, num_blocks), dtype=torch.int32, device=device + (max_tokens, num_blocks), dtype=torch.int32, device=device ) input_buffers["q_seqlens"] = torch.ones( max_batches, dtype=torch.int32, device=device ) - input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) + # kv_seqlens and kv_start_indices are per-token for paged-prefill + # (multi-token decode); use max_tokens to accommodate both cases. 
+ input_buffers["kv_seqlens"] = torch.ones(max_tokens, dtype=torch.int32) input_buffers["q_start_loc"] = torch.arange( max_batches + 1, dtype=torch.int32, device=device ) input_buffers["kv_start_indices"] = -torch.ones( - (max_batches), dtype=torch.int32, device=device + (max_tokens), dtype=torch.int32, device=device ) input_buffers["x_active_mask"] = torch.zeros( (max_batches), dtype=torch.bool, device=device ) + # actual_seq_lengths_q for multi-token decode (CPU tensor, cumulative) + input_buffers["actual_seq_lengths_q"] = torch.zeros( + max_batches, dtype=torch.int32 + ) + # ssm if graph_meta.is_ssm: input_buffers["state_ids"] = torch.full( @@ -111,8 +121,11 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers: BuffType = graph_meta.input_buffers - batch_size, num_blocks = block_offsets.size() + expanded_batch_size, num_blocks = block_offsets.size() num_tokens = input_ids.size(-1) + # q_seqlens is per-sequence (not expanded), so its size gives the + # true number of sequences even for multi-token decode. 
+ num_seqs = kv_seqlens.size(0) # fill buffer max_num_tokens = input_buffers["input_ids"].size(-1) @@ -125,19 +138,26 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["position_ids"].zero_() input_buffers["position_ids"][:, :num_tokens] = position_ids input_buffers["block_offsets"].zero_() - input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets + input_buffers["block_offsets"][:expanded_batch_size, :num_blocks] = block_offsets input_buffers["kv_seqlens"].fill_(0) - input_buffers["kv_seqlens"][:batch_size] = kv_seqlens + input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens input_buffers["kv_start_indices"].fill_(-1) - input_buffers["kv_start_indices"][:batch_size] = kv_start_indices + input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices if x_active_mask is not None: input_buffers["x_active_mask"].fill_(0) - input_buffers["x_active_mask"][:batch_size] = x_active_mask + input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask + + # multi-token decode: fill actual_seq_lengths_q + actual_seq_lengths_q = getattr(attn_metadata, 'actual_seq_lengths_q', None) + if actual_seq_lengths_q is not None: + input_buffers["actual_seq_lengths_q"].zero_() + input_buffers["actual_seq_lengths_q"][:actual_seq_lengths_q.size(0)] = actual_seq_lengths_q + attn_metadata.actual_seq_lengths_q = input_buffers["actual_seq_lengths_q"] # ssm if graph_meta.is_ssm: - input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc - input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1] + input_buffers["q_start_loc"][: num_seqs + 1] = q_start_loc + input_buffers["q_start_loc"][num_seqs + 1 :] = q_start_loc[-1] state_ids = kwargs["state_ids"] input_buffers["state_ids"].fill_(-1) @@ -153,7 +173,9 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds # create inputs # Use compatible size but cap at graph's max_batchs to avoid buffer overflow - new_batch_size = 
min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs) + # For multi-token decode, expanded_batch_size is per-token, so we + # compute padded size from the true sequence count. + new_batch_size = min(get_ascend_compatible_size(num_seqs), graph_meta.max_batchs) attn_metadata.block_offsets = input_buffers["block_offsets"] attn_metadata.kv_seqlens = input_buffers["kv_seqlens"] @@ -352,10 +374,15 @@ def forward(self, **kwargs): self.model.update_context_cudagraph(self.meta, context) if aclgraph_use_torch_npu_update(): self._graph.replay() + update_dict = { + "actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"], + } + # multi-token decode also needs actual_seq_lengths updated + actual_seq_lengths_q = self.meta.input_buffers.get("actual_seq_lengths_q") + if actual_seq_lengths_q is not None and actual_seq_lengths_q.any(): + update_dict["actual_seq_lengths"] = actual_seq_lengths_q self._graph.update( - cpu_update_input=[ - {"actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"]} - ] + cpu_update_input=[update_dict] ) else: update_attn_params(self.update_stream, self.meta, self.max_tokens) @@ -427,34 +454,74 @@ def _get_capture_tokens(self, batch_size: int): def get_graph_key( self, input_ids: torch.Tensor, + attn_metadata: Any, **kwargs, ): """Get graph key.""" context = self.ctx_mgr.current_context() is_decoding = context.is_decoding - num_tokens = input_ids.numel() meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch + + if is_decoding: + batch_size = None + q_seqlens = None + if attn_metadata is not None: + q_seqlens = getattr(attn_metadata, "q_seqlens", None) + if q_seqlens is None: + q_seqlens = getattr(context, "q_seqlens", None) + if q_seqlens is not None: + batch_size = q_seqlens.size(0) + elif kwargs.get("state_ids", None) is not None: + batch_size = kwargs["state_ids"].size(0) + + if batch_size is not None and batch_size > 0 and input_ids.size(-1) % batch_size == 0: + query_len = 
input_ids.size(-1) // batch_size + if meta.padding_batch_size is None: + new_batch_size = self._get_capture_tokens(batch_size) + else: + padding_num_tokens = meta.padding_batch_size + padding_batch_size = (padding_num_tokens + query_len - 1) // query_len + new_batch_size = self._get_capture_tokens(padding_batch_size) + return (new_batch_size, is_decoding, enable_microbatch, query_len) + + num_tokens = input_ids.numel() if meta.padding_batch_size is None: new_num_tokens = self._get_capture_tokens(num_tokens) else: new_num_tokens = self._get_capture_tokens(meta.padding_batch_size) - return (new_num_tokens, is_decoding, enable_microbatch) + return (new_num_tokens, is_decoding, enable_microbatch, 1) def __call__(self, **kwargs): """call.""" + import os as _os + _debug_graph = _os.environ.get('LMDEPLOY_DEBUG_GRAPH', '0') == '1' enable_graph = self.enable_graph(**kwargs) if not enable_graph: + if _debug_graph: + print(f'[GRAPH_DEBUG] eager path (enable_graph=False)', flush=True) with record_function("forward_eager"): ret = self.model(**kwargs) return self.model.make_output_buffers(ret) + if _debug_graph: + print(f'[GRAPH_DEBUG] graph path', flush=True) graph_key = self.get_graph_key(**kwargs) - max_tokens = graph_key[0] + max_batches = graph_key[0] is_decoding = graph_key[1] + decode_query_len = graph_key[3] + if is_decoding: + max_tokens = max_batches * decode_query_len + else: + max_tokens = max_batches + max_batches = self.max_batches if graph_key not in self._runner_map: - max_batches = max_tokens if is_decoding else self.max_batches + if _debug_graph: + print(f'[GRAPH_DEBUG] capturing new graph: key={graph_key} ' + f'max_batches={max_batches} max_tokens={max_tokens} ' + f'is_decoding={is_decoding} decode_query_len={decode_query_len}', + flush=True) runner = AscendSingleGraphRunner( self.model, max_batches=max_batches, @@ -468,6 +535,9 @@ def __call__(self, **kwargs): ) runner.capture(**kwargs) self._runner_map[graph_key] = runner + if _debug_graph: + 
print(f'[GRAPH_DEBUG] graph captured OK, total graphs={len(self._runner_map)}', + flush=True) else: runner = self._runner_map[graph_key] output = runner.forward(**kwargs) @@ -534,12 +604,10 @@ class GraphParams: _graph_params: Optional[GraphParams] = None -_graph_capture_sizes: set[int] = None def set_graph_params(aclgraph_capture_sizes: set[int]): global _graph_params - global _graph_capture_sizes if _graph_params is not None: raise ValueError("Graph parameters have already been set!") _graph_params = GraphParams( @@ -549,7 +617,6 @@ def set_graph_params(aclgraph_capture_sizes: set[int]): attn_params={size: [] for size in aclgraph_capture_sizes}, is_mla=False, ) - _graph_capture_sizes = aclgraph_capture_sizes def get_graph_params(): @@ -559,7 +626,6 @@ def get_graph_params(): def clear_graph_params(): """Clear global graph params and release references to KV cache tensors.""" global _graph_params - global _graph_capture_sizes if _graph_params is None: return @@ -575,10 +641,6 @@ def clear_graph_params(): _graph_params.workspaces.clear() finally: _graph_params = None - _graph_capture_sizes = None - # 清除 lru_cache,使下次推理时 _get_capture_batch_size_impl - # 重新执行并调用 set_graph_params 干净重建 - _get_capture_batch_size_impl.cache_clear() def update_attn_params(update_stream, forward_meta, runtime_size): diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 75a45b36..9964bd5d 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -263,6 +263,7 @@ def __init__( ): self.is_decoding = attn_metadata.is_decoding self.cu_seqlens = attn_metadata.q_start_loc + self.is_multi_token_decoding = getattr(attn_metadata, 'is_multi_token_decoding', False) # state_ids, fill invalid state with 0 self.state_ids = state_ids.clamp(0) @@ -341,8 +342,34 @@ def __call__( weight_reshaped = weight.squeeze(1) x = x.squeeze(0) - if gated_delta_meta.is_decoding: + 
is_multi_token_decode = ( + not gated_delta_meta.is_decoding + and getattr(gated_delta_meta, 'is_multi_token_decoding', False) + ) + + if gated_delta_meta.is_decoding or is_multi_token_decode: conv_state_indices = gated_delta_meta.conv_state_indices + # causal_conv1d_update_npu supports multi-token via + # seqlen > 1 in the Triton kernel's for-loop. + # For multi-token, x shape is (batch * seqlen, dim); + # reshape to (batch, seqlen, dim) for the update kernel. + if is_multi_token_decode: + cu = gated_delta_meta.cu_seqlens + num_seqs = cu.size(0) - 1 + seqlen = x.size(0) // num_seqs + if seqlen > 1: + x = x.view(num_seqs, seqlen, -1).contiguous() + out = self.causal_conv1d_update( + x, + conv_state, + weight_reshaped.t().contiguous(), + bias, + self.activation, + conv_state_indices=conv_state_indices, + validate_data=False, + ) + out = out.reshape(-1, out.size(-1)).unsqueeze(0) + return out, conv_state return self.conv1d_update( x, weight_reshaped, bias, conv_state, conv_state_indices ) @@ -374,8 +401,12 @@ def __call__( """call.""" is_decoding = gated_delta_meta.is_decoding + is_multi_token_decode = ( + not is_decoding + and getattr(gated_delta_meta, 'is_multi_token_decoding', False) + ) - if is_decoding: + if is_decoding or is_multi_token_decode: indices = gated_delta_meta.state_ids cu_seqlens = gated_delta_meta.cu_seqlens core_attn_out = self.fused_sigmoid_gating_delta_rule_update( @@ -431,6 +462,36 @@ def import_vendor_module(vendor_name_str): importlib.import_module(f".{vendor_name_str}", __package__) +def patch_attention_is_tp(): + """Monkey-patch Qwen3_5Attention to skip TP head division for draft model. + + The MTP draft model uses is_tp=False to keep full head counts on each + rank. Qwen3_5Attention already passes is_tp to build_qkv_proj and + build_o_proj, but Attention.__init__ always calls _update_num_heads. + We temporarily replace _update_num_heads with an identity function + during Qwen3_5Attention.__init__ when is_tp=False. 
+ """ + from lmdeploy.pytorch.nn import attention as _attn_mod + from lmdeploy.pytorch.models import qwen3_5 + + _orig_update = _attn_mod._update_num_heads + _identity_update = lambda nh, nkv: (nh, nkv) + _orig_init = qwen3_5.Qwen3_5Attention.__init__ + + def _patched_init(self, config, layer_idx, dtype=None, device=None, + prefix='', is_tp=True): + if not is_tp: + _attn_mod._update_num_heads = _identity_update + try: + _orig_init(self, config, layer_idx, dtype=dtype, device=device, + prefix=prefix, is_tp=is_tp) + finally: + if not is_tp: + _attn_mod._update_num_heads = _orig_update + + qwen3_5.Qwen3_5Attention.__init__ = _patched_init + + def patch_qwen3_5(): import torch from typing import List @@ -451,6 +512,8 @@ def patch_qwen3_5(): @classmethod def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): """build.""" + is_draft_model = kwargs.get("is_draft_model", False) + spec_method = kwargs.get("spec_method", None) text_config = hf_config.text_config # propagate quantization_config from top-level hf_config into text_config quantization_config = getattr(hf_config, "quantization_config", None) @@ -499,6 +562,23 @@ def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): cfg.check_env_func = _check_env_qwen3_next cfg.use_mrope = True + + # Speculative decoding support + if spec_method is not None: + assert spec_method == "qwen3_5_mtp" + cfg.model_paradigm = "ar_spec" + + if is_draft_model: + hf_config.architectures = ["Qwen3_5MTPModel"] + if getattr(hf_config, "auto_map", None): + hf_config.auto_map = {} + cfg.model_paradigm = "ar_spec" + cfg.num_layers = text_config.mtp_num_hidden_layers + cfg.states_shapes = [] + # Draft model uses is_tp=False — each rank runs the full model + # independently, so keep the replicated KV head count for correct + # cache allocation. 
+ return cfg def custom_prepare_inputs_for_generation( @@ -671,6 +751,107 @@ def custom_forward( ) +def patch_ray_init(): + """Monkey-patch lmdeploy's init_ray_cluster to register custom NPU resources. + + Ray does not auto-detect Ascend NPUs; without registering custom resources + at ray.init() time, placement groups requesting ``{'NPU': 1}`` never schedule + on a fresh local cluster. + """ + import os + import logging + import lmdeploy.pytorch.ray as _ray_mod + + logger = logging.getLogger('dlinfer.ray') + _orig_init_ray_cluster = _ray_mod.init_ray_cluster + + def _infer_local_ray_custom_resources(device_type, world_size): + if device_type == 'ascend': + n = None + try: + npu_mod = getattr(torch, 'npu', None) + if npu_mod is not None and callable(getattr(npu_mod, 'device_count', None)): + n = int(npu_mod.device_count()) + if n <= 0: + n = None + except Exception: + n = None + if n is None: + vis = os.environ.get('ASCEND_RT_VISIBLE_DEVICES', '').strip() + if vis: + n = len([x for x in vis.split(',') if x.strip() != '']) + if n is None or n <= 0: + n = int(world_size) + logger.warning( + 'Could not detect NPU count; registering Ray resource NPU=%d ' + 'from world_size.', n) + return {'NPU': float(n)} + if device_type == 'camb': + n = None + try: + mlu = getattr(torch, 'mlu', None) + if mlu is not None and callable(getattr(mlu, 'device_count', None)): + n = int(mlu.device_count()) + if n <= 0: + n = None + except Exception: + n = None + if n is None or n <= 0: + n = int(world_size) + logger.warning('Could not detect MLU count; registering MLU=%d.', n) + return {'MLU': float(n)} + return None + + def _patched_init_ray_cluster(world_size, ray_address=None, dp=1, device_type='cuda'): + """Same as original but registers custom resources at ray.init() for local clusters.""" + import ray + if not ray.is_initialized(): + num_cpus = world_size + object_store_memory = _ray_mod._get_obj_store_memory(dp=dp) + init_kwargs = dict( + ignore_reinit_error=True, + num_cpus=num_cpus, 
+ object_store_memory=object_store_memory, + ) + if ray_address is not None: + init_kwargs['address'] = ray_address + if ray_address is None: + custom_res = _infer_local_ray_custom_resources(device_type, world_size) + if custom_res: + init_kwargs['resources'] = custom_res + try: + ray.init(**init_kwargs) + except ValueError as e: + if e.args is not None and len(e.args) >= 1 and e.args[ + 0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.': + ray.init(address=ray_address, ignore_reinit_error=True) + else: + raise + + # Remaining logic unchanged from original init_ray_cluster + device_str = _ray_mod.get_device_str(device_type) + current_placement_group = ray.util.get_current_placement_group() + owned_pg = False + if not current_placement_group: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if world_size > num_devices_in_cluster: + _ray_mod.logger.warning( + 'The number of required %ss exceeds the total ' + 'number of available %ss in the placement group.', device_str, device_str) + placement_group_specs = [{device_str: 1.0} for _ in range(world_size)] + current_ip = ray.util.get_node_ip_address() + placement_group_specs[0][f'node:{current_ip}'] = 0.001 + current_placement_group = ray.util.placement_group(placement_group_specs, strategy='PACK') + _ray_mod._wait_until_pg_ready(current_placement_group) + owned_pg = True + + assert current_placement_group is not None + placement_group = current_placement_group + return placement_group, owned_pg + + _ray_mod.init_ray_cluster = _patched_init_ray_cluster + + def vendor_device_init(): import_vendor_module(vendor_name) patch_compiled_func() @@ -679,8 +860,10 @@ def vendor_device_init(): patch_contiguous_cache_engine() if vendor_name == "ascend": patch_state_cache_engine() - patch_gated_delta_net() + patch_gated_delta_net() # MUST be before patch_attention_is_tp + patch_attention_is_tp() patch_qwen3_5() + patch_ray_init() vendor_device_init() From 
ff1cc7d481ee3529b725984ffc46921fff988944 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Sun, 19 Apr 2026 09:03:09 +0000 Subject: [PATCH 02/12] [Ascend] Patch MTP graph and sampling runtime Move the Ascend-specific graph alignment, state replay, and sampling fallback into dlinfer so multi-token speculative decode stays stable without expanding lmdeploy core runtime changes. Made-with: Cursor --- .../cudagraph/ascend_cudagraph.py | 121 +++-- .../framework/lmdeploy_ext/device/__init__.py | 456 ++++++++++++++++-- dlinfer/vendor/ascend/torch_npu_ops.py | 64 ++- dlinfer/vendor/ascend/triton_ops/__init__.py | 6 +- .../vendor/ascend/triton_ops/fla/__init__.py | 6 +- .../ascend/triton_ops/fla/sigmoid_gating.py | 119 +++++ 6 files changed, 682 insertions(+), 90 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index f84686c7..51042569 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update(): # AscendCudaGraphMixin methods for cudagraph buffer management. 
+def AscendCudaGraphMixin_support_cuda_graph( + self, + input_ids: Tensor, + position_ids: Tensor, + past_key_values: List[List[Tensor]], + attn_metadata: Any = None, + inputs_embeds: Tensor = None, + **kwargs, +): + """Allow multi-token decode graph only when runtime length updates exist.""" + if attn_metadata is None: + return False + + is_decoding = getattr(attn_metadata, "is_decoding", False) + is_multi_token = getattr(attn_metadata, "is_multi_token_decoding", False) + if is_multi_token and not aclgraph_use_torch_npu_update(): + return False + return is_decoding or is_multi_token + + def AscendCudaGraphMixin_make_buffers_cudagraph( self, graph_meta: CudaGraphMeta, *args, **kwargs ) -> BuffType: @@ -54,21 +74,21 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (1, max_tokens), dtype=torch.int32, device=device ) - # multi-token decode expands block_offsets per-token, so size by - # max_tokens; for single-token decode max_tokens == max_batches so - # this is backward-compatible. + # TND paged attention consumes block tables per sequence. Keep the graph + # buffer batch-shaped even when one decode step contains multiple query + # tokens per sequence. input_buffers["block_offsets"] = torch.zeros( - (max_tokens, num_blocks), dtype=torch.int32, device=device + (max_batches, num_blocks), dtype=torch.int32, device=device ) - input_buffers["q_seqlens"] = torch.ones( + input_buffers["q_seqlens"] = torch.ones(max_batches, dtype=torch.int32) + + # actual_seq_lengths_kv is also tracked per sequence in the TND path. + input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) + input_buffers["kv_seqlens_device"] = torch.ones( max_batches, dtype=torch.int32, device=device ) - # kv_seqlens and kv_start_indices are per-token for paged-prefill - # (multi-token decode); use max_tokens to accommodate both cases. 
- input_buffers["kv_seqlens"] = torch.ones(max_tokens, dtype=torch.int32) - input_buffers["q_start_loc"] = torch.arange( max_batches + 1, dtype=torch.int32, device=device ) @@ -77,8 +97,9 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (max_tokens), dtype=torch.int32, device=device ) + # MoE routing still reasons in token space for multi-token verify. input_buffers["x_active_mask"] = torch.zeros( - (max_batches), dtype=torch.bool, device=device + (max_tokens), dtype=torch.bool, device=device ) # actual_seq_lengths_q for multi-token decode (CPU tensor, cumulative) @@ -91,6 +112,9 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( input_buffers["state_ids"] = torch.full( (max_batches,), -1, dtype=torch.int64, device=device ) + input_buffers["num_accepted_tokens"] = torch.ones( + max_batches, dtype=torch.int32, device=device + ) # mrope if graph_meta.use_mrope: @@ -121,11 +145,9 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers: BuffType = graph_meta.input_buffers - expanded_batch_size, num_blocks = block_offsets.size() + num_seqs, num_blocks = block_offsets.size() num_tokens = input_ids.size(-1) - # q_seqlens is per-sequence (not expanded), so its size gives the - # true number of sequences even for multi-token decode. 
- num_seqs = kv_seqlens.size(0) + q_seqlens: Tensor = attn_metadata.q_seqlens # fill buffer max_num_tokens = input_buffers["input_ids"].size(-1) @@ -137,10 +159,16 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["input_ids"][:, :num_tokens] = input_ids input_buffers["position_ids"].zero_() input_buffers["position_ids"][:, :num_tokens] = position_ids + input_buffers["q_seqlens"].fill_(1) + input_buffers["q_seqlens"][: q_seqlens.size(0)] = q_seqlens input_buffers["block_offsets"].zero_() - input_buffers["block_offsets"][:expanded_batch_size, :num_blocks] = block_offsets + input_buffers["block_offsets"][:num_seqs, :num_blocks] = block_offsets input_buffers["kv_seqlens"].fill_(0) input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens + input_buffers["kv_seqlens_device"].fill_(0) + input_buffers["kv_seqlens_device"][:num_seqs].copy_( + kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) + ) input_buffers["kv_start_indices"].fill_(-1) input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices if x_active_mask is not None: @@ -152,17 +180,52 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( if actual_seq_lengths_q is not None: input_buffers["actual_seq_lengths_q"].zero_() input_buffers["actual_seq_lengths_q"][:actual_seq_lengths_q.size(0)] = actual_seq_lengths_q + # TND graph replay uses fixed-size query buffers. Pad the cumulative + # query lengths to the compatible graph batch so the final element + # still matches the captured query token count. 
+ if actual_seq_lengths_q.numel() > 0: + pad_query_len = torch.diff( + actual_seq_lengths_q, + prepend=actual_seq_lengths_q.new_zeros(1), + )[-1] + last_q = input_buffers["actual_seq_lengths_q"][actual_seq_lengths_q.size(0) - 1] + for idx in range(actual_seq_lengths_q.size(0), input_buffers["actual_seq_lengths_q"].size(0)): + last_q = last_q + pad_query_len + input_buffers["actual_seq_lengths_q"][idx] = last_q + input_buffers["q_seqlens"].copy_(input_buffers["actual_seq_lengths_q"]) attn_metadata.actual_seq_lengths_q = input_buffers["actual_seq_lengths_q"] # ssm if graph_meta.is_ssm: - input_buffers["q_start_loc"][: num_seqs + 1] = q_start_loc - input_buffers["q_start_loc"][num_seqs + 1 :] = q_start_loc[-1] + input_buffers["q_start_loc"].fill_(0) + input_buffers["q_start_loc"][: q_start_loc.size(0)] = q_start_loc + if actual_seq_lengths_q is not None and q_start_loc.numel() > 1: + pad_query_len = q_start_loc[-1] - q_start_loc[-2] + last_q = q_start_loc[-1] + for idx in range(q_start_loc.size(0), input_buffers["q_start_loc"].size(0)): + last_q = last_q + pad_query_len + input_buffers["q_start_loc"][idx] = last_q + else: + input_buffers["q_start_loc"][q_start_loc.size(0):] = q_start_loc[-1] state_ids = kwargs["state_ids"] input_buffers["state_ids"].fill_(-1) input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) + num_accepted_tokens = getattr(attn_metadata, "num_accepted_tokens", None) + input_buffers["num_accepted_tokens"].fill_(1) + if num_accepted_tokens is not None: + input_buffers["num_accepted_tokens"][: num_accepted_tokens.size(0)].copy_( + num_accepted_tokens.to( + device=input_buffers["num_accepted_tokens"].device, + dtype=input_buffers["num_accepted_tokens"].dtype, + ) + ) + attn_metadata.num_accepted_tokens = input_buffers["num_accepted_tokens"] + # Keep linear-attention state math on the fixed graph buffer so its + # per-sequence cache lengths stay aligned with padded q_start_loc. 
+ attn_metadata.kv_seqlens_device = input_buffers["kv_seqlens_device"] + if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) if "inputs_embeds" not in input_buffers: @@ -171,12 +234,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( 1, max_num_tokens, emb_size ) input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds - # create inputs - # Use compatible size but cap at graph's max_batchs to avoid buffer overflow - # For multi-token decode, expanded_batch_size is per-token, so we - # compute padded size from the true sequence count. - new_batch_size = min(get_ascend_compatible_size(num_seqs), graph_meta.max_batchs) - + attn_metadata.q_seqlens = input_buffers["q_seqlens"] attn_metadata.block_offsets = input_buffers["block_offsets"] attn_metadata.kv_seqlens = input_buffers["kv_seqlens"] attn_metadata.kv_start_indices = input_buffers["kv_start_indices"] @@ -231,6 +289,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context): context.mrope_position_ids = input_buffers["mrope_position_ids"] +CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph @@ -385,7 +444,7 @@ def forward(self, **kwargs): cpu_update_input=[update_dict] ) else: - update_attn_params(self.update_stream, self.meta, self.max_tokens) + update_attn_params(self.update_stream, self.meta, self.max_batches) self._graph.replay() output_buffers = self.meta.output_buffers output = self.model.get_outputs_cudagraph(output_buffers, **kwargs) @@ -494,19 +553,13 @@ def get_graph_key( def __call__(self, **kwargs): """call.""" - import os as _os - _debug_graph = _os.environ.get('LMDEPLOY_DEBUG_GRAPH', '0') == '1' enable_graph = self.enable_graph(**kwargs) if not enable_graph: - if _debug_graph: - 
print(f'[GRAPH_DEBUG] eager path (enable_graph=False)', flush=True) with record_function("forward_eager"): ret = self.model(**kwargs) return self.model.make_output_buffers(ret) - if _debug_graph: - print(f'[GRAPH_DEBUG] graph path', flush=True) graph_key = self.get_graph_key(**kwargs) max_batches = graph_key[0] is_decoding = graph_key[1] @@ -517,11 +570,6 @@ def __call__(self, **kwargs): max_tokens = max_batches max_batches = self.max_batches if graph_key not in self._runner_map: - if _debug_graph: - print(f'[GRAPH_DEBUG] capturing new graph: key={graph_key} ' - f'max_batches={max_batches} max_tokens={max_tokens} ' - f'is_decoding={is_decoding} decode_query_len={decode_query_len}', - flush=True) runner = AscendSingleGraphRunner( self.model, max_batches=max_batches, @@ -535,9 +583,6 @@ def __call__(self, **kwargs): ) runner.capture(**kwargs) self._runner_map[graph_key] = runner - if _debug_graph: - print(f'[GRAPH_DEBUG] graph captured OK, total graphs={len(self._runner_map)}', - flush=True) else: runner = self._runner_map[graph_key] output = runner.forward(**kwargs) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 9964bd5d..c878b14f 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -100,6 +100,277 @@ async def async_sampling_logits( BaseModelAgent.async_sampling_logits = async_sampling_logits +def patch_rejection_sampler(): + from lmdeploy.pytorch.spec_decode import reject_sampler as _reject_sampler_mod + _orig_rejection_sample = _reject_sampler_mod.rejection_sample + + def _patched_rejection_sample( + target_logits, + draft_token_ids, + bonus_token_ids, + sampling_inputs, + draft_probs=None, + ): + if sampling_inputs.max_top_k == 1: + return _orig_rejection_sample( + target_logits, + draft_token_ids, + bonus_token_ids, + sampling_inputs, + draft_probs=draft_probs, + ) + + assert draft_probs is None or 
draft_probs.is_contiguous() + if not draft_token_ids.is_contiguous(): + draft_token_ids = draft_token_ids.contiguous() + + if not target_logits.is_contiguous(): + target_logits = target_logits.contiguous() + + batch_size, num_spec_tokens = draft_token_ids.shape + device = target_logits.device + + output_token_ids = torch.full( + (batch_size, num_spec_tokens + 1), + _reject_sampler_mod.PLACEHOLDER_TOKEN_ID, + dtype=torch.long, + device=device, + ) + + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + if sampling_inputs.top_k is not None: + is_greedy = (sampling_inputs.top_k == 1) + if not torch.is_tensor(is_greedy): + is_greedy = torch.full( + (batch_size,), bool(is_greedy), dtype=torch.bool, device=device + ) + else: + is_greedy = is_greedy.to(device=device, dtype=torch.bool) + else: + is_greedy = torch.zeros(batch_size, dtype=torch.bool, device=device) + + target_argmax = target_probs.argmax(dim=-1) + uniform_probs = torch.rand( + (batch_size, num_spec_tokens), dtype=torch.float64, device=device + ) + inv_q = torch.empty( + (batch_size, target_probs.shape[-1]), dtype=torch.float32, device=device + ) + inv_q.exponential_() + inv_q = inv_q.reciprocal() + + recovered_token_ids = torch.empty( + (batch_size, num_spec_tokens), dtype=torch.long, device=device + ) + zero = target_probs.new_tensor(0.0) + for batch_idx in range(batch_size): + if bool(is_greedy[batch_idx].item()): + continue + batch_inv_q = inv_q[batch_idx] + for pos in range(num_spec_tokens): + draft_token_id = draft_token_ids[batch_idx, pos] + if draft_probs is None: + prob = target_probs[batch_idx, pos].clone() + prob[draft_token_id] = 0.0 + else: + prob = torch.maximum( + target_probs[batch_idx, pos] - draft_probs[batch_idx, pos], + zero, + ) + recovered_token_ids[batch_idx, pos] = torch.argmax(prob * batch_inv_q) + + for batch_idx in range(batch_size): + rejected = False + if bool(is_greedy[batch_idx].item()): + for pos in range(num_spec_tokens): + token_id = target_argmax[batch_idx, 
pos] + output_token_ids[batch_idx, pos] = token_id + if draft_token_ids[batch_idx, pos] != token_id: + rejected = True + break + else: + for pos in range(num_spec_tokens): + draft_token_id = draft_token_ids[batch_idx, pos] + if draft_probs is None: + draft_prob = 1.0 + else: + draft_prob = float( + draft_probs[batch_idx, pos, draft_token_id].item() + ) + target_prob = float( + target_probs[batch_idx, pos, draft_token_id].item() + ) + uniform_prob = float(uniform_probs[batch_idx, pos].item()) + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + token_id = draft_token_id + else: + token_id = recovered_token_ids[batch_idx, pos] + rejected = True + output_token_ids[batch_idx, pos] = token_id + if rejected: + break + + if not rejected: + output_token_ids[batch_idx, num_spec_tokens] = bonus_token_ids[batch_idx] + + return _reject_sampler_mod._extract_outputs(output_token_ids, num_spec_tokens) + + _reject_sampler_mod.rejection_sample = _patched_rejection_sample + + +def patch_spec_decode_runtime(): + from torch.profiler import record_function + from lmdeploy.pytorch.engine.model_agent import BaseModelAgent + from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent + + _orig_build_cache_engine = BaseModelAgent.build_cache_engine + _orig_forward_impl = BaseModelAgent._forward_impl + + def _is_ascend_agent(agent): + return getattr(getattr(agent, "backend_config", None), "device_type", None) == "ascend" + + def _set_main_runtime(self, model, cache_engine, state_cache_engine, stream): + self.main_model = model + self.main_cache_engine = cache_engine + self.main_state_cache_engine = state_cache_engine + self.main_stream = stream + + def _maybe_snapshot_main_states(self, model_inputs): + self._main_state_snapshot = None + self._main_replay_inputs = None + state_cache_engine = getattr(self, "main_state_cache_engine", None) + main_model = getattr(self, "main_model", None) + if state_cache_engine is None or main_model is None: + return + if not 
model_inputs.is_decoding or model_inputs.max_q_seqlen <= 1: + return + mem_pool = getattr(state_cache_engine, "mem_pool", None) + if mem_pool is None or mem_pool.numel() == 0: + return + self._main_state_snapshot = mem_pool.clone() + self._main_replay_inputs = model_inputs.clone() + + def _build_replay_inputs(self, model_inputs, output_token_ids): + valid_mask = output_token_ids.ge(0) + seq_length = valid_mask.sum(dim=-1).to(model_inputs.seq_length.dtype) + if torch.any(seq_length <= 0): + return None + + replay_ids = [row[mask] for row, mask in zip(output_token_ids, valid_mask)] + input_ids = torch.cat(replay_ids, dim=0).unsqueeze(0) + + mrope_pos_ids = model_inputs.mrope_pos_ids + if mrope_pos_ids is not None: + mrope_chunks = [] + reshaped = mrope_pos_ids.unflatten(1, (-1, model_inputs.max_q_seqlen)) + for batch_idx, replay_len in enumerate(seq_length.tolist()): + mrope_chunks.append(reshaped[:, batch_idx, :replay_len]) + mrope_pos_ids = torch.cat(mrope_chunks, dim=1) + + max_q_seqlen = int(seq_length.max().item()) + max_kv_seqlen = int((model_inputs.history_lengths + seq_length).max().item()) + sum_kv_seqlen = int((model_inputs.history_lengths + seq_length).sum().item()) + return model_inputs.clone( + input_ids=input_ids, + seq_length=seq_length, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=sum_kv_seqlen, + target_hidden_states=None, + target_position_ids=None, + target_inputs_embeds=None, + mrope_pos_ids=mrope_pos_ids, + ) + + def _maybe_replay_main_states(self, extra_inputs): + state_snapshot = getattr(self, "_main_state_snapshot", None) + replay_template = getattr(self, "_main_replay_inputs", None) + if state_snapshot is None or replay_template is None: + return + try: + if extra_inputs.num_rejected_tokens is None or not torch.any(extra_inputs.num_rejected_tokens > 0): + return + replay_inputs = self._build_replay_inputs(replay_template, extra_inputs.output_token_ids) + if replay_inputs is None: + return + + main_model = 
getattr(self, "main_model", None) + main_cache_engine = getattr(self, "main_cache_engine", None) + state_cache_engine = getattr(self, "main_state_cache_engine", None) + if main_model is None or main_cache_engine is None or state_cache_engine is None: + return + + state_cache_engine.mem_pool.copy_(state_snapshot) + from lmdeploy.pytorch.engine.model_agent.agent import model_forward as _main_model_forward + + _main_model_forward( + main_model, + replay_inputs, + main_cache_engine, + state_cache_engine, + stream=getattr(self, "main_stream", None), + ) + finally: + self._main_state_snapshot = None + self._main_replay_inputs = None + + def _patched_build_cache_engine(self): + if _is_ascend_agent(self): + state_shapes = getattr(self.model_config, "states_shapes", []) + self.cache_config.states_shapes = state_shapes + if self.cache_config.num_state_caches is None and len(state_shapes) > 0: + self.cache_config.num_state_caches = int(self.cache_config.max_batches + 1) + + _orig_build_cache_engine(self) + + if ( + _is_ascend_agent(self) + and self.spec_agent is not None + and self.spec_agent.is_enabled() + ): + self.spec_agent.set_main_runtime( + self.patched_model, + self.cache_engine, + self.state_cache_engine, + self.stream, + ) + + def _patched_forward_impl(self, inputs): + if ( + _is_ascend_agent(self) + and self.spec_agent is not None + and self.spec_agent.is_enabled() + ): + self.spec_agent.maybe_snapshot_main_states(inputs) + return _orig_forward_impl(self, inputs) + + async def _patched_async_model_forward( + self, + model_inputs, + extra_inputs, + sampling_inputs, + ): + with record_function("spec_rejection_sampling"): + draft_extra_inputs = await self._rejection_sampling( + model_inputs, extra_inputs, sampling_inputs + ) + self._maybe_replay_main_states(draft_extra_inputs) + draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main( + model_inputs, draft_extra_inputs + ) + return await self._async_model_forward( + draft_model_inputs, 
draft_extra_inputs, sampling_inputs + ) + + BaseModelAgent.build_cache_engine = _patched_build_cache_engine + BaseModelAgent._forward_impl = _patched_forward_impl + SpecModelAgent.set_main_runtime = _set_main_runtime + SpecModelAgent.maybe_snapshot_main_states = _maybe_snapshot_main_states + SpecModelAgent._build_replay_inputs = _build_replay_inputs + SpecModelAgent._maybe_replay_main_states = _maybe_replay_main_states + SpecModelAgent.async_model_forward = _patched_async_model_forward + + ##### patch cache engine ##### def patch_contiguous_cache_engine(): from lmdeploy.pytorch.config import CacheConfig, ModelConfig @@ -241,6 +512,7 @@ def patch_gated_delta_net(): from lmdeploy.pytorch.nn import gated_delta from lmdeploy.pytorch.nn.gated_delta import GatedDeltaMeta + from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from dlinfer.vendor.ascend.triton_ops import RMSNormGated from dlinfer.vendor.ascend.triton_ops import ( @@ -249,6 +521,7 @@ def patch_gated_delta_net(): ) from dlinfer.vendor.ascend.triton_ops import ( chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, fused_sigmoid_gating_delta_rule_update, ) @@ -261,14 +534,44 @@ def __init__( state_ids: torch.Tensor, attn_metadata: Any, ): - self.is_decoding = attn_metadata.is_decoding - self.cu_seqlens = attn_metadata.q_start_loc self.is_multi_token_decoding = getattr(attn_metadata, 'is_multi_token_decoding', False) + # Keep decode semantics for linear-attention state updates even when + # full attention uses a prefill-style TND verify path. 
+ self.is_decoding = attn_metadata.is_decoding or self.is_multi_token_decoding + self.cu_seqlens = attn_metadata.q_start_loc + self.num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + + query_lens = None + num_seqs = 1 + if self.cu_seqlens is not None: + query_lens = torch.diff(self.cu_seqlens).to(torch.int32) + num_seqs = max(int(self.cu_seqlens.numel()) - 1, 1) + self.max_query_len = max(num_tokens // num_seqs, 1) + self.cache_seqlens = None + self.spec_state_offsets = None + kv_seqlens_device = getattr(attn_metadata, 'kv_seqlens_device', None) + if query_lens is not None and kv_seqlens_device is not None: + kv_seqlens = kv_seqlens_device.to(dtype=torch.int32) + self.cache_seqlens = (kv_seqlens - query_lens).contiguous() + if self.num_spec_tokens > 0 and not self.is_decoding: + state_slots = 1 + self.num_spec_tokens + self.spec_state_offsets = ( + torch.remainder(self.cache_seqlens, state_slots), + torch.remainder(kv_seqlens, state_slots), + ) + self.num_accepted_tokens = getattr(attn_metadata, 'num_accepted_tokens', None) + if self.num_accepted_tokens is None and self.is_multi_token_decoding and query_lens is not None: + self.num_accepted_tokens = torch.ones(query_lens.size(0), dtype=torch.int32, device=self.cu_seqlens.device) + elif self.num_accepted_tokens is not None: + self.num_accepted_tokens = self.num_accepted_tokens.to( + device=self.cu_seqlens.device if self.cu_seqlens is not None else state_ids.device, + dtype=torch.int32, + ).contiguous() # state_ids, fill invalid state with 0 self.state_ids = state_ids.clamp(0) self.has_initial_state = attn_metadata.has_initial_state - self.conv_state_indices = self.state_ids + self.conv_state_indices = self.state_ids.to(torch.int32) def build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs): device = kwargs["device"] @@ -318,7 +621,17 @@ def conv1d_update( bias: torch.Tensor, conv_state: torch.Tensor, conv_state_indices: torch.Tensor, + gated_delta_meta: GatedDeltaMeta, ): + update_kwargs = {} + 
validate_data = True + if getattr(gated_delta_meta, 'is_multi_token_decoding', False): + update_kwargs.update( + num_accepted_tokens=gated_delta_meta.num_accepted_tokens, + query_start_loc=gated_delta_meta.cu_seqlens, + max_query_len=gated_delta_meta.max_query_len, + ) + validate_data = False out = self.causal_conv1d_update( x, conv_state, @@ -326,7 +639,8 @@ def conv1d_update( bias, self.activation, conv_state_indices=conv_state_indices, - validate_data=True, + validate_data=validate_data, + **update_kwargs, ) return out.unsqueeze(0), conv_state @@ -342,36 +656,10 @@ def __call__( weight_reshaped = weight.squeeze(1) x = x.squeeze(0) - is_multi_token_decode = ( - not gated_delta_meta.is_decoding - and getattr(gated_delta_meta, 'is_multi_token_decoding', False) - ) - - if gated_delta_meta.is_decoding or is_multi_token_decode: + if gated_delta_meta.is_decoding: conv_state_indices = gated_delta_meta.conv_state_indices - # causal_conv1d_update_npu supports multi-token via - # seqlen > 1 in the Triton kernel's for-loop. - # For multi-token, x shape is (batch * seqlen, dim); - # reshape to (batch, seqlen, dim) for the update kernel. 
- if is_multi_token_decode: - cu = gated_delta_meta.cu_seqlens - num_seqs = cu.size(0) - 1 - seqlen = x.size(0) // num_seqs - if seqlen > 1: - x = x.view(num_seqs, seqlen, -1).contiguous() - out = self.causal_conv1d_update( - x, - conv_state, - weight_reshaped.t().contiguous(), - bias, - self.activation, - conv_state_indices=conv_state_indices, - validate_data=False, - ) - out = out.reshape(-1, out.size(-1)).unsqueeze(0) - return out, conv_state return self.conv1d_update( - x, weight_reshaped, bias, conv_state, conv_state_indices + x, weight_reshaped, bias, conv_state, conv_state_indices, gated_delta_meta ) return self.conv1d_func( x, weight_reshaped, bias, conv_state, gated_delta_meta=gated_delta_meta @@ -380,12 +668,32 @@ def __call__( class AscendGatedDelta: def __init__(self, use_qk_l2norm_in_kernel: bool = True): + self.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule self.fused_sigmoid_gating_delta_rule_update = ( fused_sigmoid_gating_delta_rule_update ) self.chunk_gated_delta_rule = chunk_gated_delta_rule self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + @staticmethod + def _get_decode_state_indices( + state_ids: torch.Tensor, + cache_seqlens: torch.Tensor, + num_slots: int, + query_len: int, + ) -> torch.Tensor: + """Map LMDeploy's per-sequence slot layout to per-token state ids.""" + token_offsets = torch.arange( + query_len, + device=cache_seqlens.device, + dtype=torch.int64, + ) + slot_offsets = torch.remainder( + cache_seqlens.to(torch.int64)[:, None] + token_offsets[None], + num_slots, + ) + return (state_ids.to(torch.int64)[:, None] * num_slots + slot_offsets).contiguous() + def __call__( self, query: torch.Tensor, @@ -401,14 +709,53 @@ def __call__( """call.""" is_decoding = gated_delta_meta.is_decoding - is_multi_token_decode = ( - not is_decoding - and getattr(gated_delta_meta, 'is_multi_token_decoding', False) - ) + is_multi_token_decode = getattr(gated_delta_meta, 'is_multi_token_decoding', False) + beta = b.sigmoid() + 
# If the model is loaded in fp16, without the .float() here, A might be -inf + g = (-A_log.float().exp()) * F.softplus(a.float() + dt_bias) - if is_decoding or is_multi_token_decode: + if is_decoding: indices = gated_delta_meta.state_ids cu_seqlens = gated_delta_meta.cu_seqlens + if is_multi_token_decode: + query_len = gated_delta_meta.max_query_len + state_slots = recurrent_state.size(1) + flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) + state_indices = self._get_decode_state_indices( + indices, + gated_delta_meta.cache_seqlens, + state_slots, + query_len, + ) + core_attn_out, _ = self.fused_recurrent_gated_delta_rule( + q=query.contiguous(), + k=key.contiguous(), + v=value.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + initial_state=flat_recurrent_state, + inplace_final_state=True, + cu_seqlens=cu_seqlens, + ssm_state_indices=state_indices, + num_accepted_tokens=gated_delta_meta.num_accepted_tokens, + use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, + ) + return core_attn_out, None + + # Single-token decode: use the optimized update kernel + initial_state_source = recurrent_state + initial_state_indices = indices + if recurrent_state.dim() == 5: + state_slots = recurrent_state.size(1) + flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) + slot_offsets = torch.remainder( + gated_delta_meta.cache_seqlens.to(torch.int64), + state_slots, + ) + initial_state_source = flat_recurrent_state + initial_state_indices = ( + indices.to(torch.int64) * state_slots + slot_offsets + ).contiguous() core_attn_out = self.fused_sigmoid_gating_delta_rule_update( A_log=A_log, dt_bias=dt_bias, @@ -417,8 +764,8 @@ def __call__( v=value.contiguous(), a=a.contiguous(), b=b.contiguous(), - initial_state_source=recurrent_state, - initial_state_indices=indices, + initial_state_source=initial_state_source, + initial_state_indices=initial_state_indices, cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=True, 
softplus_beta=1.0, @@ -426,11 +773,11 @@ def __call__( ) last_recurrent_state = None else: - beta = b.sigmoid() - # If the model is loaded in fp16, without the .float() here, A might be -inf - g = (-A_log.float().exp()) * F.softplus(a.float() + dt_bias) - - initial_state = recurrent_state[gated_delta_meta.state_ids] + if gated_delta_meta.spec_state_offsets is not None: + read_offsets, write_offsets = gated_delta_meta.spec_state_offsets + initial_state = recurrent_state[gated_delta_meta.state_ids, read_offsets] + else: + initial_state = recurrent_state[gated_delta_meta.state_ids] initial_state[~gated_delta_meta.has_initial_state, ...] = 0 core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( q=query, @@ -444,9 +791,14 @@ def __call__( head_first=False, use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) - recurrent_state[gated_delta_meta.state_ids] = last_recurrent_state.to( - recurrent_state.dtype - ) + if gated_delta_meta.spec_state_offsets is not None: + recurrent_state[gated_delta_meta.state_ids, write_offsets] = last_recurrent_state.to( + recurrent_state.dtype + ) + else: + recurrent_state[gated_delta_meta.state_ids] = last_recurrent_state.to( + recurrent_state.dtype + ) return core_attn_out, last_recurrent_state @@ -514,6 +866,7 @@ def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): """build.""" is_draft_model = kwargs.get("is_draft_model", False) spec_method = kwargs.get("spec_method", None) + num_spec_tokens = kwargs.get("num_spec_tokens", 0) text_config = hf_config.text_config # propagate quantization_config from top-level hf_config into text_config quantization_config = getattr(hf_config, "quantization_config", None) @@ -538,11 +891,14 @@ def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): key_dim = head_k_dim * num_k_heads value_dim = head_v_dim * num_v_heads conv_dim = key_dim * 2 + value_dim - conv_kernel_size = text_config.linear_conv_kernel_dim + conv_kernel_size = 
text_config.linear_conv_kernel_dim + num_spec_tokens # Ascend Patch conv_state_shape = (conv_kernel_size, conv_dim) - recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim) + if num_spec_tokens > 0: + recurrent_state_shape = (1 + num_spec_tokens, num_v_heads, head_k_dim, head_v_dim) + else: + recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim) device_type = kwargs.get("device_type", "cuda") if is_bf16_supported(device_type): @@ -859,6 +1215,8 @@ def vendor_device_init(): if vendor_name in ["camb", "ascend"]: patch_contiguous_cache_engine() if vendor_name == "ascend": + patch_rejection_sampler() + patch_spec_decode_runtime() patch_state_cache_engine() patch_gated_delta_net() # MUST be before patch_attention_is_tp patch_attention_is_tp() diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 2a12aabe..48dff7c4 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -1,5 +1,4 @@ # Copyright (c) 2024, DeepLink. All rights reserved. -import os import math import torch import torch.distributed as dist @@ -457,6 +456,69 @@ def paged_prefill_attention( scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1]) query = query.contiguous() + if q_seq_len.dim() != 1 or kv_seq_len.dim() != 1: + raise ValueError("TND paged prefill expects 1D actual_seq_lengths tensors.") + if block_table.size(0) != q_seq_len.numel() or kv_seq_len.numel() != q_seq_len.numel(): + raise ValueError("TND paged prefill expects per-sequence block_table and kv_seq_len.") + q_seq_len_cpu = get_cpu_seq_len(q_seq_len) + kv_seq_len_cpu = get_cpu_seq_len(kv_seq_len) + if ( + q_seq_len_cpu.numel() > 0 + and int(q_seq_len_cpu.max().item()) > 1 + and torch.any(kv_seq_len_cpu > q_seq_len_cpu) + ): + # Ascend TND fused infer attention is still unstable for speculative + # multi-token verify. 
Fall back to token-wise paged decode semantics so + # each speculative token only attends to history plus accepted prefix. + q_seq_len_per_seq = torch.diff( + q_seq_len_cpu, + prepend=q_seq_len_cpu.new_zeros(1), + ) + history_lens = kv_seq_len_cpu - q_seq_len_per_seq + expanded_kv_seq_len = torch.cat([ + torch.arange( + int(history_len.item()) + 1, + int(final_len.item()) + 1, + dtype=kv_seq_len_cpu.dtype, + ) + for history_len, final_len in zip(history_lens, kv_seq_len_cpu) + ]) + expanded_q_seq_len = torch.arange( + 1, + expanded_kv_seq_len.numel() + 1, + dtype=q_seq_len_cpu.dtype, + ) + expanded_block_table = block_table.repeat_interleave( + q_seq_len_per_seq.to(device=block_table.device, dtype=torch.int64), + dim=0, + ) + key_headsize, value_headsize = key_cache.shape[-1], value_cache.shape[-1] + if key_headsize == value_headsize: + return decode_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + scale_value=scale_value, + block_table=expanded_block_table, + block_size=block_size, + q_seq_len=expanded_q_seq_len, + kv_seq_len=expanded_kv_seq_len, + softmax_scale=softmax_scale, + attn_output=attn_output, + ) + return decode_attention_mla( + query=query, + key_cache=key_cache, + num_kv_heads=num_kv_heads, + num_q_heads=num_q_heads, + scale_value=scale_value, + block_table=expanded_block_table, + kv_seq_len=expanded_kv_seq_len, + mla_vheadsize=value_cache.shape[-1], + attn_output=attn_output, + ) block_num = key_cache.size(0) key_cache = key_cache.view(block_num, block_size, -1) value_cache = value_cache.view(block_num, block_size, -1) diff --git a/dlinfer/vendor/ascend/triton_ops/__init__.py b/dlinfer/vendor/ascend/triton_ops/__init__.py index 49a22221..8ac1a579 100644 --- a/dlinfer/vendor/ascend/triton_ops/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/__init__.py @@ -6,11 +6,15 @@ "causal_conv1d_fn", "causal_conv1d_update_npu", "chunk_gated_delta_rule", + 
"fused_recurrent_gated_delta_rule", "fused_sigmoid_gating_delta_rule_update", "RMSNormGated", ] from .fla.chunk import chunk_gated_delta_rule -from .fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .fla.sigmoid_gating import ( + fused_recurrent_gated_delta_rule, + fused_sigmoid_gating_delta_rule_update, +) from .rms_norm_gated import RMSNormGated from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update_npu diff --git a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py index e2eea080..94ed346c 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from .chunk import chunk_gated_delta_rule -from .sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .sigmoid_gating import ( + fused_recurrent_gated_delta_rule, + fused_sigmoid_gating_delta_rule_update, +) __all__ = [ "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", "fused_sigmoid_gating_delta_rule_update", ] diff --git a/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py b/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py index c7fafa3b..46e547f8 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py @@ -418,3 +418,122 @@ def fused_sigmoid_gating_delta_rule_update( ) o = o.squeeze(0) return o + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.Tensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused recurrent gated-delta forward with optional spec-decode offsets.""" + 
B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor | None = None, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = 
None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Wrapper for recurrent gated-delta forward.""" + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + return fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) From d96947ff788da3e4662a9d6bddfe3e3f57df4a61 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Sun, 19 Apr 2026 09:51:28 +0000 Subject: [PATCH 03/12] [Ascend] Reduce MTP replay state snapshot peak memory Snapshot only the active state-cache rows during speculative replay so Ascend no longer clones the full state pool for rejection recovery. 
Made-with: Cursor --- .../framework/lmdeploy_ext/device/__init__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index c878b14f..5981a6a6 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -237,6 +237,7 @@ def _set_main_runtime(self, model, cache_engine, state_cache_engine, stream): def _maybe_snapshot_main_states(self, model_inputs): self._main_state_snapshot = None + self._main_state_ids = None self._main_replay_inputs = None state_cache_engine = getattr(self, "main_state_cache_engine", None) main_model = getattr(self, "main_model", None) @@ -247,7 +248,16 @@ def _maybe_snapshot_main_states(self, model_inputs): mem_pool = getattr(state_cache_engine, "mem_pool", None) if mem_pool is None or mem_pool.numel() == 0: return - self._main_state_snapshot = mem_pool.clone() + state_offsets = getattr(model_inputs, "state_offsets", None) + if state_offsets is None: + return + active_state_ids = state_offsets[state_offsets >= 0] + if active_state_ids.numel() == 0: + return + active_state_ids = active_state_ids.to(device=mem_pool.device, dtype=torch.long) + # Only snapshot rows touched by the current batch. 
+ self._main_state_ids = active_state_ids + self._main_state_snapshot = mem_pool.index_select(0, active_state_ids).clone() self._main_replay_inputs = model_inputs.clone() def _build_replay_inputs(self, model_inputs, output_token_ids): @@ -284,8 +294,9 @@ def _build_replay_inputs(self, model_inputs, output_token_ids): def _maybe_replay_main_states(self, extra_inputs): state_snapshot = getattr(self, "_main_state_snapshot", None) + state_ids = getattr(self, "_main_state_ids", None) replay_template = getattr(self, "_main_replay_inputs", None) - if state_snapshot is None or replay_template is None: + if state_snapshot is None or state_ids is None or replay_template is None: return try: if extra_inputs.num_rejected_tokens is None or not torch.any(extra_inputs.num_rejected_tokens > 0): @@ -300,7 +311,7 @@ def _maybe_replay_main_states(self, extra_inputs): if main_model is None or main_cache_engine is None or state_cache_engine is None: return - state_cache_engine.mem_pool.copy_(state_snapshot) + state_cache_engine.mem_pool.index_copy_(0, state_ids, state_snapshot) from lmdeploy.pytorch.engine.model_agent.agent import model_forward as _main_model_forward _main_model_forward( @@ -312,6 +323,7 @@ def _maybe_replay_main_states(self, extra_inputs): ) finally: self._main_state_snapshot = None + self._main_state_ids = None self._main_replay_inputs = None def _patched_build_cache_engine(self): From 38e8d1c2d73fdc812785cae49f065bc8a436b624 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 28 Apr 2026 08:06:30 +0000 Subject: [PATCH 04/12] [Ascend] fix gdn kernel in speculative decoding --- .../framework/lmdeploy_ext/device/__init__.py | 10 +- dlinfer/vendor/ascend/triton_ops/__init__.py | 6 +- .../vendor/ascend/triton_ops/fla/__init__.py | 6 +- .../ascend/triton_ops/fla/fused_recurrent.py | 401 ++++++++++++++++++ .../ascend/triton_ops/fla/sigmoid_gating.py | 119 ------ 5 files changed, 410 insertions(+), 132 deletions(-) create mode 100644 
dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 5981a6a6..6c72e593 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -366,7 +366,7 @@ async def _patched_async_model_forward( draft_extra_inputs = await self._rejection_sampling( model_inputs, extra_inputs, sampling_inputs ) - self._maybe_replay_main_states(draft_extra_inputs) + # self._maybe_replay_main_states(draft_extra_inputs) draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main( model_inputs, draft_extra_inputs ) @@ -533,8 +533,8 @@ def patch_gated_delta_net(): ) from dlinfer.vendor.ascend.triton_ops import ( chunk_gated_delta_rule, - fused_recurrent_gated_delta_rule, fused_sigmoid_gating_delta_rule_update, + fused_recurrent_gated_delta_rule, ) class AscendGatedDeltaMeta: @@ -739,6 +739,7 @@ def __call__( state_slots, query_len, ) + state_indices, _ = torch.sort(state_indices, dim=1) core_attn_out, _ = self.fused_recurrent_gated_delta_rule( q=query.contiguous(), k=key.contiguous(), @@ -786,8 +787,7 @@ def __call__( last_recurrent_state = None else: if gated_delta_meta.spec_state_offsets is not None: - read_offsets, write_offsets = gated_delta_meta.spec_state_offsets - initial_state = recurrent_state[gated_delta_meta.state_ids, read_offsets] + initial_state = recurrent_state[gated_delta_meta.state_ids, 0].transpose(-1, -2).contiguous() else: initial_state = recurrent_state[gated_delta_meta.state_ids] initial_state[~gated_delta_meta.has_initial_state, ...] 
= 0 @@ -804,7 +804,7 @@ def __call__( use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) if gated_delta_meta.spec_state_offsets is not None: - recurrent_state[gated_delta_meta.state_ids, write_offsets] = last_recurrent_state.to( + recurrent_state[gated_delta_meta.state_ids, 0] = last_recurrent_state.transpose(-1, -2).to( recurrent_state.dtype ) else: diff --git a/dlinfer/vendor/ascend/triton_ops/__init__.py b/dlinfer/vendor/ascend/triton_ops/__init__.py index 8ac1a579..f08c5f4b 100644 --- a/dlinfer/vendor/ascend/triton_ops/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/__init__.py @@ -12,9 +12,7 @@ ] from .fla.chunk import chunk_gated_delta_rule -from .fla.sigmoid_gating import ( - fused_recurrent_gated_delta_rule, - fused_sigmoid_gating_delta_rule_update, -) +from .fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .fla.fused_recurrent import fused_recurrent_gated_delta_rule from .rms_norm_gated import RMSNormGated from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update_npu diff --git a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py index 94ed346c..e7104dec 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from .chunk import chunk_gated_delta_rule -from .sigmoid_gating import ( - fused_recurrent_gated_delta_rule, - fused_sigmoid_gating_delta_rule_update, -) +from .sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .fused_recurrent import fused_recurrent_gated_delta_rule __all__ = [ "chunk_gated_delta_rule", diff --git a/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py new file mode 100644 index 00000000..1d05ee4c --- /dev/null +++ b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py @@ -0,0 +1,401 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +import torch +import triton +import triton.language as tl + +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = 
i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_v[:, None] & mask_k[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + # Load state index and check for PAD_SLOT_ID (-1) + state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + # Skip if state index is invalid (PAD_SLOT_ID = -1) + if state_idx < 0: + return + p_h0 = h0 + state_idx * stride_init_state_token + else: + p_h0 = h0 + bos * HV * V * K + p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BV, BK] + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[None, :]) + # [BV] + b_v -= 
tl.sum(b_h * b_k[None, :], 1) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BV, BK] + b_h += b_v[:, None] * b_k[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[None, :], 1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Load state index and check for PAD_SLOT_ID (-1) + final_state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t + ).to(tl.int64) + # Only store if state index is valid (not PAD_SLOT_ID) + if final_state_idx >= 0: + p_ht = ht + final_state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if 
inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + 
inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, V, K]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, V, K]`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, V, K, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." 
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py b/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py index 46e547f8..c7fafa3b 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/sigmoid_gating.py @@ -418,122 +418,3 @@ def fused_sigmoid_gating_delta_rule_update( ) o = o.squeeze(0) return o - - -def fused_recurrent_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: torch.Tensor | None = None, - ssm_state_indices: torch.Tensor | None = None, - num_accepted_tokens: torch.Tensor | None = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused recurrent gated-delta forward with optional spec-decode offsets.""" - B, T, H, K, V = *k.shape, v.shape[-1] - HV = v.shape[2] - N = B if cu_seqlens is None else len(cu_seqlens) - 1 - BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) - NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) - assert NK == 1, "NK > 1 is not supported yet" - num_stages = 3 - num_warps = 1 - - o = q.new_empty(NK, *v.shape) - if inplace_final_state: - final_state = initial_state - else: - final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) - - stride_init_state_token = initial_state.stride(0) - stride_final_state_token = final_state.stride(0) - - if ssm_state_indices is None: - stride_indices_seq, stride_indices_tok = 1, 1 - elif 
ssm_state_indices.ndim == 1: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 - else: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - - grid = (NK, NV, N * HV) - fused_recurrent_gated_delta_rule_fwd_kernel[grid]( - q=q, - k=k, - v=v, - g=g, - beta=beta, - o=o, - h0=initial_state, - ht=final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - scale=scale, - N=N, - T=T, - B=B, - H=H, - HV=HV, - K=K, - V=V, - BK=BK, - BV=BV, - stride_init_state_token=stride_init_state_token, - stride_final_state_token=stride_final_state_token, - stride_indices_seq=stride_indices_seq, - stride_indices_tok=stride_indices_tok, - IS_BETA_HEADWISE=beta.ndim == v.ndim, - USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, - INPLACE_FINAL_STATE=inplace_final_state, - IS_KDA=False, - num_warps=num_warps, - num_stages=num_stages, - ) - o = o.squeeze(0) - return o, final_state - - -def fused_recurrent_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor | None = None, - scale: float | None = None, - initial_state: torch.Tensor | None = None, - inplace_final_state: bool = True, - cu_seqlens: torch.LongTensor | None = None, - ssm_state_indices: torch.Tensor | None = None, - num_accepted_tokens: torch.Tensor | None = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - """Wrapper for recurrent gated-delta forward.""" - if cu_seqlens is not None and q.shape[0] != 1: - raise ValueError( - f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
- ) - if scale is None: - scale = k.shape[-1] ** -0.5 - else: - assert scale > 0, "scale must be positive" - if beta is None: - beta = torch.ones_like(q[..., 0]) - return fused_recurrent_gated_delta_rule_fwd( - q=q.contiguous(), - k=k.contiguous(), - v=v.contiguous(), - g=g.contiguous(), - beta=beta.contiguous(), - scale=scale, - initial_state=initial_state, - inplace_final_state=inplace_final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) From 435eaa1d46e435f9ea11d62bab47a10abab730c1 Mon Sep 17 00:00:00 2001 From: Super User Date: Sun, 10 May 2026 14:06:07 +0000 Subject: [PATCH 05/12] fix graph --- .../cudagraph/ascend_cudagraph.py | 65 ++++++++++++++++++- dlinfer/vendor/ascend/torch_npu_ops.py | 59 ----------------- 2 files changed, 64 insertions(+), 60 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 51042569..7afc4fcd 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -107,6 +107,16 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( max_batches, dtype=torch.int32 ) + # attention mask buffer for multi-token decode (kept on-device so graph + # replay can see in-place updates via the same data pointer captured at + # graph-capture time). Shape matches the fixed 2048×2048 mask used by + # npu_fused_infer_attention_score. Initialised to all-False (no masking) + # so a stale buffer value is permissive rather than destructive. 
+ _ATTN_MASK_WIDTH = 2048 + input_buffers["attention_mask_buf"] = torch.zeros( + _ATTN_MASK_WIDTH, _ATTN_MASK_WIDTH, dtype=torch.bool, device=device + ) + # ssm if graph_meta.is_ssm: input_buffers["state_ids"] = torch.full( @@ -193,8 +203,41 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( last_q = last_q + pad_query_len input_buffers["actual_seq_lengths_q"][idx] = last_q input_buffers["q_seqlens"].copy_(input_buffers["actual_seq_lengths_q"]) + + # Fix: phantom sequences (indices num_seqs:max_batches) were padded + # with kv_seqlens=0 above. A kv_len of 0 triggers the + # paged_prefill_attention fallback with history_lens = 0 - q_per_seq + # < 0, producing negative values in expanded_kv_seq_len that cause + # undefined NPU kernel behaviour. Set phantom kv_seqlens equal to + # the per-sequence query length so history_lens = 0 (no prior KV + # beyond the phantom tokens themselves). Their attention outputs are + # discarded by get_outputs_cudagraph so the wrong KV data read from + # block 0 has no effect on real-sequence outputs. + per_seq_q_len = int(pad_query_len.item()) + if num_seqs < input_buffers["kv_seqlens"].size(0): + input_buffers["kv_seqlens"][num_seqs:] = per_seq_q_len + input_buffers["kv_seqlens_device"][num_seqs:] = per_seq_q_len + attn_metadata.actual_seq_lengths_q = input_buffers["actual_seq_lengths_q"] + # Fix: the attention mask is computed by op_backend per-step using the + # current max_kv_seq_len (diagonal = max_kv_seq_len - max_q_seq_len + 1). + # In graph mode the mask tensor is captured by pointer at warmup time + # (when kv_seqlens=[q_len,q_len] -> diagonal=1) and never updated, so + # inference steps with larger kv_seqlens use a far-too-restrictive mask + # (e.g. diagonal=1 instead of 19) causing query tokens to attend only to + # a single KV entry -> near-zero hidden states -> token-0 output. 
+ # + # Fix: pre-allocate a device buffer, copy the fresh per-step mask into + # it in-place, and redirect attn_metadata.attention_mask to the buffer. + # Because the graph captured the buffer's data pointer, in-place updates + # are visible during replay. + fresh_mask = getattr(attn_metadata, 'attention_mask', None) + attn_mask_buf = input_buffers.get("attention_mask_buf") + if fresh_mask and attn_mask_buf is not None: + attn_mask_buf.copy_(fresh_mask[0]) + attn_metadata.attention_mask = [attn_mask_buf] + # ssm if graph_meta.is_ssm: input_buffers["q_start_loc"].fill_(0) @@ -225,6 +268,22 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( # Keep linear-attention state math on the fixed graph buffer so its # per-sequence cache lengths stay aligned with padded q_start_loc. attn_metadata.kv_seqlens_device = input_buffers["kv_seqlens_device"] + elif actual_seq_lengths_q is not None: + # Fix: non-SSM models (e.g. MTP draft model) also need q_start_loc + # updated during multi-token decode. The is_ssm branch above handles + # SSM models; without this branch the buffer stays at its initial + # torch.arange value [0,1,2,...] causing paged_prefill_attention to + # slice the wrong query tokens from the TND-layout Q tensor. 
+ input_buffers["q_start_loc"].fill_(0) + input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc + if q_start_loc.numel() > 1: + pad_query_len = q_start_loc[-1] - q_start_loc[-2] + last_q = q_start_loc[-1] + for idx in range(q_start_loc.size(0), input_buffers["q_start_loc"].size(0)): + last_q = last_q + pad_query_len + input_buffers["q_start_loc"][idx] = last_q + else: + input_buffers["q_start_loc"][q_start_loc.size(0):] = q_start_loc[-1] if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) @@ -432,7 +491,6 @@ def forward(self, **kwargs): context = self.ctx_mgr.current_context() self.model.update_context_cudagraph(self.meta, context) if aclgraph_use_torch_npu_update(): - self._graph.replay() update_dict = { "actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"], } @@ -440,9 +498,14 @@ def forward(self, **kwargs): actual_seq_lengths_q = self.meta.input_buffers.get("actual_seq_lengths_q") if actual_seq_lengths_q is not None and actual_seq_lengths_q.any(): update_dict["actual_seq_lengths"] = actual_seq_lengths_q + # Fix: update CPU inputs BEFORE replay so the current step uses the + # fresh seq-length values; the original order replayed first and + # only updated for the next replay, leaving the first iteration + # with stale captured values. 
self._graph.update( cpu_update_input=[update_dict] ) + self._graph.replay() else: update_attn_params(self.update_stream, self.meta, self.max_batches) self._graph.replay() diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 48dff7c4..2b9c6814 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -460,65 +460,6 @@ def paged_prefill_attention( raise ValueError("TND paged prefill expects 1D actual_seq_lengths tensors.") if block_table.size(0) != q_seq_len.numel() or kv_seq_len.numel() != q_seq_len.numel(): raise ValueError("TND paged prefill expects per-sequence block_table and kv_seq_len.") - q_seq_len_cpu = get_cpu_seq_len(q_seq_len) - kv_seq_len_cpu = get_cpu_seq_len(kv_seq_len) - if ( - q_seq_len_cpu.numel() > 0 - and int(q_seq_len_cpu.max().item()) > 1 - and torch.any(kv_seq_len_cpu > q_seq_len_cpu) - ): - # Ascend TND fused infer attention is still unstable for speculative - # multi-token verify. Fall back to token-wise paged decode semantics so - # each speculative token only attends to history plus accepted prefix. 
- q_seq_len_per_seq = torch.diff( - q_seq_len_cpu, - prepend=q_seq_len_cpu.new_zeros(1), - ) - history_lens = kv_seq_len_cpu - q_seq_len_per_seq - expanded_kv_seq_len = torch.cat([ - torch.arange( - int(history_len.item()) + 1, - int(final_len.item()) + 1, - dtype=kv_seq_len_cpu.dtype, - ) - for history_len, final_len in zip(history_lens, kv_seq_len_cpu) - ]) - expanded_q_seq_len = torch.arange( - 1, - expanded_kv_seq_len.numel() + 1, - dtype=q_seq_len_cpu.dtype, - ) - expanded_block_table = block_table.repeat_interleave( - q_seq_len_per_seq.to(device=block_table.device, dtype=torch.int64), - dim=0, - ) - key_headsize, value_headsize = key_cache.shape[-1], value_cache.shape[-1] - if key_headsize == value_headsize: - return decode_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - scale_value=scale_value, - block_table=expanded_block_table, - block_size=block_size, - q_seq_len=expanded_q_seq_len, - kv_seq_len=expanded_kv_seq_len, - softmax_scale=softmax_scale, - attn_output=attn_output, - ) - return decode_attention_mla( - query=query, - key_cache=key_cache, - num_kv_heads=num_kv_heads, - num_q_heads=num_q_heads, - scale_value=scale_value, - block_table=expanded_block_table, - kv_seq_len=expanded_kv_seq_len, - mla_vheadsize=value_cache.shape[-1], - attn_output=attn_output, - ) block_num = key_cache.size(0) key_cache = key_cache.view(block_num, block_size, -1) value_cache = value_cache.view(block_num, block_size, -1) From 24a7b46639a34b4b2feda9b6e02dcdba954a37bc Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 15 May 2026 03:23:01 +0000 Subject: [PATCH 06/12] fix: ensure state is contiguous --- dlinfer/framework/lmdeploy_ext/device/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 6c72e593..210bc260 100644 --- 
a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -804,7 +804,7 @@ def __call__( use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) if gated_delta_meta.spec_state_offsets is not None: - recurrent_state[gated_delta_meta.state_ids, 0] = last_recurrent_state.transpose(-1, -2).to( + recurrent_state[gated_delta_meta.state_ids, 0] = last_recurrent_state.transpose(-1, -2).contiguous().to( recurrent_state.dtype ) else: From c53ea6f540674bd87d5713a0a9a6ea3c643a5ac1 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 15 May 2026 03:23:38 +0000 Subject: [PATCH 07/12] refactor fill buffers --- .../cudagraph/ascend_cudagraph.py | 114 +++++++----------- 1 file changed, 43 insertions(+), 71 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 7afc4fcd..8738bf3d 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -169,8 +169,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["input_ids"][:, :num_tokens] = input_ids input_buffers["position_ids"].zero_() input_buffers["position_ids"][:, :num_tokens] = position_ids - input_buffers["q_seqlens"].fill_(1) - input_buffers["q_seqlens"][: q_seqlens.size(0)] = q_seqlens input_buffers["block_offsets"].zero_() input_buffers["block_offsets"][:num_seqs, :num_blocks] = block_offsets input_buffers["kv_seqlens"].fill_(0) @@ -185,53 +183,43 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["x_active_mask"].fill_(0) input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask - # multi-token decode: fill actual_seq_lengths_q actual_seq_lengths_q = getattr(attn_metadata, 'actual_seq_lengths_q', None) if actual_seq_lengths_q is not None: - input_buffers["actual_seq_lengths_q"].zero_() + # multi-token decode: fill actual_seq_lengths_q 
+ bs = input_buffers["actual_seq_lengths_q"].size(0) + pad_query_len = 5 + padding_tensor = torch.arange(1, bs + 1) * pad_query_len + input_buffers["actual_seq_lengths_q"].copy_(padding_tensor) input_buffers["actual_seq_lengths_q"][:actual_seq_lengths_q.size(0)] = actual_seq_lengths_q - # TND graph replay uses fixed-size query buffers. Pad the cumulative - # query lengths to the compatible graph batch so the final element - # still matches the captured query token count. - if actual_seq_lengths_q.numel() > 0: - pad_query_len = torch.diff( - actual_seq_lengths_q, - prepend=actual_seq_lengths_q.new_zeros(1), - )[-1] - last_q = input_buffers["actual_seq_lengths_q"][actual_seq_lengths_q.size(0) - 1] - for idx in range(actual_seq_lengths_q.size(0), input_buffers["actual_seq_lengths_q"].size(0)): - last_q = last_q + pad_query_len - input_buffers["actual_seq_lengths_q"][idx] = last_q - input_buffers["q_seqlens"].copy_(input_buffers["actual_seq_lengths_q"]) - - # Fix: phantom sequences (indices num_seqs:max_batches) were padded - # with kv_seqlens=0 above. A kv_len of 0 triggers the - # paged_prefill_attention fallback with history_lens = 0 - q_per_seq - # < 0, producing negative values in expanded_kv_seq_len that cause - # undefined NPU kernel behaviour. Set phantom kv_seqlens equal to - # the per-sequence query length so history_lens = 0 (no prior KV - # beyond the phantom tokens themselves). Their attention outputs are - # discarded by get_outputs_cudagraph so the wrong KV data read from - # block 0 has no effect on real-sequence outputs. 
- per_seq_q_len = int(pad_query_len.item()) - if num_seqs < input_buffers["kv_seqlens"].size(0): - input_buffers["kv_seqlens"][num_seqs:] = per_seq_q_len - input_buffers["kv_seqlens_device"][num_seqs:] = per_seq_q_len - - attn_metadata.actual_seq_lengths_q = input_buffers["actual_seq_lengths_q"] - - # Fix: the attention mask is computed by op_backend per-step using the - # current max_kv_seq_len (diagonal = max_kv_seq_len - max_q_seq_len + 1). - # In graph mode the mask tensor is captured by pointer at warmup time - # (when kv_seqlens=[q_len,q_len] -> diagonal=1) and never updated, so - # inference steps with larger kv_seqlens use a far-too-restrictive mask - # (e.g. diagonal=1 instead of 19) causing query tokens to attend only to - # a single KV entry -> near-zero hidden states -> token-0 output. - # - # Fix: pre-allocate a device buffer, copy the fresh per-step mask into - # it in-place, and redirect attn_metadata.attention_mask to the buffer. - # Because the graph captured the buffer's data pointer, in-place updates - # are visible during replay. 
+ + # actual_seq_lengths_q通过q_seqlens传入进去 + input_buffers["q_seqlens"].copy_(input_buffers["actual_seq_lengths_q"]) + else: + # single-token decode: fill q_seqlens + bs = input_buffers["q_seqlens"].size(0) + padding_tensor = torch.arange(1, bs + 1) + input_buffers["q_seqlens"].copy_(padding_tensor) + input_buffers["q_seqlens"][: q_seqlens.size(0)] = q_seqlens + + + if actual_seq_lengths_q is not None: + # multi-token decode + input_buffers["kv_seqlens"].fill_(5) + input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens + input_buffers["kv_seqlens_device"].fill_(5) + input_buffers["kv_seqlens_device"][:num_seqs].copy_( + kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) + ) + else: + # single-token decode + input_buffers["kv_seqlens"].fill_(1) + input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens + input_buffers["kv_seqlens_device"].fill_(1) + input_buffers["kv_seqlens_device"][:num_seqs].copy_( + kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) + ) + + if actual_seq_lengths_q is not None: fresh_mask = getattr(attn_metadata, 'attention_mask', None) attn_mask_buf = input_buffers.get("attention_mask_buf") if fresh_mask and attn_mask_buf is not None: @@ -240,19 +228,19 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( # ssm if graph_meta.is_ssm: - input_buffers["q_start_loc"].fill_(0) - input_buffers["q_start_loc"][: q_start_loc.size(0)] = q_start_loc + # main model verify if actual_seq_lengths_q is not None and q_start_loc.numel() > 1: - pad_query_len = q_start_loc[-1] - q_start_loc[-2] - last_q = q_start_loc[-1] - for idx in range(q_start_loc.size(0), input_buffers["q_start_loc"].size(0)): - last_q = last_q + pad_query_len - input_buffers["q_start_loc"][idx] = last_q + # multi-token decode + bs = input_buffers["q_start_loc"].size(0) + pad_query_len = 5 + padding_tensor = torch.arange(0, bs) * pad_query_len + input_buffers["q_start_loc"].copy_(padding_tensor) + input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc else: + # 
single-token decode input_buffers["q_start_loc"][q_start_loc.size(0):] = q_start_loc[-1] - state_ids = kwargs["state_ids"] - input_buffers["state_ids"].fill_(-1) + input_buffers["state_ids"].fill_(0) input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) num_accepted_tokens = getattr(attn_metadata, "num_accepted_tokens", None) @@ -268,22 +256,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( # Keep linear-attention state math on the fixed graph buffer so its # per-sequence cache lengths stay aligned with padded q_start_loc. attn_metadata.kv_seqlens_device = input_buffers["kv_seqlens_device"] - elif actual_seq_lengths_q is not None: - # Fix: non-SSM models (e.g. MTP draft model) also need q_start_loc - # updated during multi-token decode. The is_ssm branch above handles - # SSM models; without this branch the buffer stays at its initial - # torch.arange value [0,1,2,...] causing paged_prefill_attention to - # slice the wrong query tokens from the TND-layout Q tensor. - input_buffers["q_start_loc"].fill_(0) - input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc - if q_start_loc.numel() > 1: - pad_query_len = q_start_loc[-1] - q_start_loc[-2] - last_q = q_start_loc[-1] - for idx in range(q_start_loc.size(0), input_buffers["q_start_loc"].size(0)): - last_q = last_q + pad_query_len - input_buffers["q_start_loc"][idx] = last_q - else: - input_buffers["q_start_loc"][q_start_loc.size(0):] = q_start_loc[-1] if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) From bf7913edf776fa74170f1406bc6ea735dd4a8e59 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Fri, 15 May 2026 09:49:38 +0000 Subject: [PATCH 08/12] impl ring buffer for gdn state --- .../framework/lmdeploy_ext/device/__init__.py | 23 ++++--- .../ascend/triton_ops/fla/fused_recurrent.py | 63 ++++++++++++++++--- 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py 
b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 210bc260..00c3b122 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -730,16 +730,8 @@ def __call__( indices = gated_delta_meta.state_ids cu_seqlens = gated_delta_meta.cu_seqlens if is_multi_token_decode: - query_len = gated_delta_meta.max_query_len state_slots = recurrent_state.size(1) flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) - state_indices = self._get_decode_state_indices( - indices, - gated_delta_meta.cache_seqlens, - state_slots, - query_len, - ) - state_indices, _ = torch.sort(state_indices, dim=1) core_attn_out, _ = self.fused_recurrent_gated_delta_rule( q=query.contiguous(), k=key.contiguous(), @@ -749,8 +741,9 @@ def __call__( initial_state=flat_recurrent_state, inplace_final_state=True, cu_seqlens=cu_seqlens, - ssm_state_indices=state_indices, - num_accepted_tokens=gated_delta_meta.num_accepted_tokens, + cache_seqlens_rb=gated_delta_meta.cache_seqlens, + state_ids_rb=indices, + num_state=state_slots, use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) return core_attn_out, None @@ -787,7 +780,10 @@ def __call__( last_recurrent_state = None else: if gated_delta_meta.spec_state_offsets is not None: - initial_state = recurrent_state[gated_delta_meta.state_ids, 0].transpose(-1, -2).contiguous() + state_ids = gated_delta_meta.state_ids + # Circular-buffer read slot: history_len % NUM_STATE + read_slots = gated_delta_meta.spec_state_offsets[0] + initial_state = recurrent_state[state_ids, read_slots].transpose(-1, -2).contiguous() else: initial_state = recurrent_state[gated_delta_meta.state_ids] initial_state[~gated_delta_meta.has_initial_state, ...] 
= 0 @@ -804,7 +800,10 @@ def __call__( use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) if gated_delta_meta.spec_state_offsets is not None: - recurrent_state[gated_delta_meta.state_ids, 0] = last_recurrent_state.transpose(-1, -2).contiguous().to( + state_ids = gated_delta_meta.state_ids + # Circular-buffer write slot: (history_len + query_len) % NUM_STATE + write_slots = gated_delta_meta.spec_state_offsets[1] + recurrent_state[state_ids, write_slots] = last_recurrent_state.transpose(-1, -2).to( recurrent_state.dtype ) else: diff --git a/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py index 1d05ee4c..6709bcad 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py @@ -29,6 +29,7 @@ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "IS_CIRCULAR_BUFFER": lambda args: args["cache_seqlens_rb"] is not None, } ) @triton.jit(do_not_specialize=["N", "T"]) @@ -44,6 +45,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( cu_seqlens, ssm_state_indices, num_accepted_tokens, + cache_seqlens_rb, # [N] history lengths for circular-buffer read/write + state_ids_rb, # [N] per-sequence base slot index (= state_id) scale, N: tl.int64, # num of sequences T: tl.int64, # num of tokens @@ -54,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + NUM_STATE: tl.constexpr, # circular buffer size (= state_slots = 1 + num_spec_tokens) stride_init_state_token: tl.constexpr, stride_final_state_token: tl.constexpr, stride_indices_seq: tl.constexpr, @@ -65,6 +69,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, + IS_CIRCULAR_BUFFER: 
tl.constexpr, # use NVIDIA-style circular-buffer state indexing IS_KDA: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -108,8 +113,23 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( mask_h = mask_v[:, None] & mask_k[None, :] b_h = tl.zeros([BV, BK], dtype=tl.float32) + + # Pre-load circular-buffer addressing params once (before the token loop). + # IS_CIRCULAR_BUFFER is constexpr so the dead branch is eliminated at compile time. + if IS_CIRCULAR_BUFFER: + h_rb = tl.load(cache_seqlens_rb + i_n).to(tl.int64) + s_id_rb = tl.load(state_ids_rb + i_n).to(tl.int64) + # Skip padding sequences: invalid state_ids are clamped to 0 on the host side. + # Letting them proceed would corrupt slot-0 state and cause NaN in graph mode. + if s_id_rb <= 0: + return + if USE_INITIAL_STATE: - if IS_CONTINUOUS_BATCHING: + if IS_CIRCULAR_BUFFER: + # Circular-buffer read: slot = cache_seqlens % NUM_STATE + read_slot = s_id_rb * NUM_STATE + h_rb % NUM_STATE + p_h0 = h0 + read_slot * stride_init_state_token + elif IS_CONTINUOUS_BATCHING: if IS_SPEC_DECODING: i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 else: @@ -119,7 +139,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( tl.int64 ) # Skip if state index is invalid (PAD_SLOT_ID = -1) - if state_idx < 0: + if state_idx <= 0: return p_h0 = h0 + state_idx * stride_init_state_token else: @@ -158,15 +178,22 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( # keep the states for multi-query tokens if INPLACE_FINAL_STATE: - # Load state index and check for PAD_SLOT_ID (-1) - final_state_idx = tl.load( - ssm_state_indices + i_n * stride_indices_seq + i_t - ).to(tl.int64) - # Only store if state index is valid (not PAD_SLOT_ID) - if final_state_idx >= 0: - p_ht = ht + final_state_idx * stride_final_state_token + if IS_CIRCULAR_BUFFER: + # Circular-buffer write: slot = (cache_seqlens + i_t + 1) % NUM_STATE + write_slot = s_id_rb * NUM_STATE + (h_rb + i_t + 1) % NUM_STATE + p_ht = ht + 
write_slot * stride_final_state_token p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + # Load state index and check for PAD_SLOT_ID (-1) + final_state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t + ).to(tl.int64) + # Only store if state index is valid (not PAD_SLOT_ID) + if final_state_idx > 0: + p_ht = ht + final_state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) else: p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] @@ -195,6 +222,9 @@ def fused_recurrent_gated_delta_rule_fwd( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] @@ -235,6 +265,8 @@ def fused_recurrent_gated_delta_rule_fwd( cu_seqlens=cu_seqlens, ssm_state_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, + cache_seqlens_rb=cache_seqlens_rb, + state_ids_rb=state_ids_rb, scale=scale, N=N, T=T, @@ -245,6 +277,7 @@ def fused_recurrent_gated_delta_rule_fwd( V=V, BK=BK, BV=BV, + NUM_STATE=num_state, stride_init_state_token=stride_init_state_token, stride_final_state_token=stride_final_state_token, stride_indices_seq=stride_indices_seq, @@ -275,6 +308,9 @@ def forward( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, use_qk_l2norm_in_kernel: bool = False, ): o, final_state = 
fused_recurrent_gated_delta_rule_fwd( @@ -289,6 +325,9 @@ def forward( cu_seqlens=cu_seqlens, ssm_state_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, + cache_seqlens_rb=cache_seqlens_rb, + state_ids_rb=state_ids_rb, + num_state=num_state, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, ) @@ -307,6 +346,9 @@ def fused_recurrent_gated_delta_rule( cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: r""" @@ -396,6 +438,9 @@ def fused_recurrent_gated_delta_rule( cu_seqlens, ssm_state_indices, num_accepted_tokens, + cache_seqlens_rb, + state_ids_rb, + num_state, use_qk_l2norm_in_kernel, ) return o, final_state From 88152ac97cde990abbfa3f5b267d3a4814671177 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Mon, 18 May 2026 02:29:36 +0000 Subject: [PATCH 09/12] impl ring buffer for conv1d state --- .../framework/lmdeploy_ext/device/__init__.py | 44 ++- .../vendor/ascend/triton_ops/causal_conv1d.py | 340 ++++++++++++------ 2 files changed, 272 insertions(+), 112 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 00c3b122..334e264f 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -561,6 +561,8 @@ def __init__( self.max_query_len = max(num_tokens // num_seqs, 1) self.cache_seqlens = None self.spec_state_offsets = None + self.spec_conv_offsets = None + self.conv_kernel_size = conv_kernel_size kv_seqlens_device = getattr(attn_metadata, 'kv_seqlens_device', None) if query_lens is not None and kv_seqlens_device is not None: kv_seqlens = kv_seqlens_device.to(dtype=torch.int32) @@ -571,6 +573,28 @@ def __init__( 
torch.remainder(self.cache_seqlens, state_slots), torch.remainder(kv_seqlens, state_slots), ) + # Conv ring buffer: state_len = linear_conv_kernel_dim + num_spec_tokens. + # `conv_kernel_size` here is the conv width (linear_conv_kernel_dim). + state_len = conv_kernel_size + self.num_spec_tokens + range_idx = torch.arange( + -conv_kernel_size, + 0, + device=self.cache_seqlens.device, + dtype=torch.int32, + ) + # Read the (conv_kernel_size - 1) tokens preceding the current write + # window from the circular buffer. + read_conv_offsets = torch.remainder( + self.cache_seqlens[:, None] + range_idx[1:][None], + state_len, + ).to(torch.int64) + # Write the last conv_kernel_size tokens of this prefill batch into + # circular-buffer slots so the next decode read aligns naturally. + write_conv_offsets = torch.remainder( + kv_seqlens[:, None] + range_idx[None], + state_len, + ).to(torch.int64) + self.spec_conv_offsets = (read_conv_offsets, write_conv_offsets) self.num_accepted_tokens = getattr(attn_metadata, 'num_accepted_tokens', None) if self.num_accepted_tokens is None and self.is_multi_token_decoding and query_lens is not None: self.num_accepted_tokens = torch.ones(query_lens.size(0), dtype=torch.int32, device=self.cu_seqlens.device) @@ -610,6 +634,12 @@ def conv1d_func( out: (b, seqlen, dim) conv_state: (b, dim, kernel_size) """ + spec_conv_offsets = getattr(gated_delta_meta, "spec_conv_offsets", None) + if spec_conv_offsets is not None: + read_conv_offsets, write_conv_offsets = spec_conv_offsets + else: + read_conv_offsets, write_conv_offsets = None, None + out = self.causal_conv1d_fn( x.t(), weight, @@ -619,6 +649,8 @@ def conv1d_func( has_initial_state=gated_delta_meta.has_initial_state, cache_indices=gated_delta_meta.conv_state_indices, query_start_loc=gated_delta_meta.cu_seqlens, + read_conv_offsets=read_conv_offsets, + write_conv_offsets=write_conv_offsets, ) out = out.t().unsqueeze(0) @@ -637,7 +669,17 @@ def conv1d_update( ): update_kwargs = {} validate_data = 
True - if getattr(gated_delta_meta, 'is_multi_token_decoding', False): + if getattr(gated_delta_meta, 'cache_seqlens', None) is not None and gated_delta_meta.is_decoding: + # Ring-buffer decode path: positions are derived from cache_seqlens. + update_kwargs['cache_seqlens'] = gated_delta_meta.cache_seqlens + if getattr(gated_delta_meta, 'is_multi_token_decoding', False): + # Multi-token decode uses varlen format (2-D x tensor); must keep + # IS_VARLEN=True by passing query_start_loc, otherwise x gets incorrectly + # unsqueezed and cache_seqlens is accessed out-of-bounds. + update_kwargs['query_start_loc'] = gated_delta_meta.cu_seqlens + update_kwargs['max_query_len'] = gated_delta_meta.max_query_len + validate_data = False + elif getattr(gated_delta_meta, 'is_multi_token_decoding', False): update_kwargs.update( num_accepted_tokens=gated_delta_meta.num_accepted_tokens, query_start_loc=gated_delta_meta.cu_seqlens, diff --git a/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py b/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py index 00c919e7..e249bfae 100644 --- a/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py +++ b/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py @@ -70,6 +70,8 @@ def causal_conv1d_fn( query_start_loc: Optional[torch.Tensor] = None, metadata: Optional[Any] = None, pad_slot_id: int = PAD_SLOT_ID, + read_conv_offsets: Optional[torch.Tensor] = None, + write_conv_offsets: Optional[torch.Tensor] = None, ): """ Prefill-phase varlen causal conv1d using PyTorch reference implementation. @@ -80,7 +82,17 @@ def causal_conv1d_fn( query_start_loc: (batch + 1) int32 cache_indices: (batch) int32 has_initial_state: (batch) bool - conv_states: (..., dim, width - 1) + conv_states: (..., dim, state_len) — state_len == width-1 in legacy linear + layout; state_len == width + num_spec_tokens in ring layout. + + Ring-buffer mode (used by MTP / speculative decoding) is enabled when + `write_conv_offsets` is provided. 
Then `conv_states` is treated as a ring + of size `state_len = conv_states.shape[-1]`: + * initial state (when has_initial_state[i]) is gathered from + read_conv_offsets[i] (shape (width-1,)); + * the last `width` tokens of each sequence are scattered to + write_conv_offsets[i] (shape (width,)). + When `write_conv_offsets` is None the legacy semantics apply unchanged. """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") @@ -91,6 +103,7 @@ def causal_conv1d_fn( if query_start_loc is None: raise ValueError("query_start_loc is required for prefill mode") + is_ring = write_conv_offsets is not None seqlens = query_start_loc[1:] - query_start_loc[:-1] seqlens = seqlens.tolist() splits = torch.split(x, seqlens, dim=-1) @@ -100,23 +113,58 @@ def causal_conv1d_fn( x_s = splits[i] if cache_indices[i] == PAD_SLOT_ID: continue + + if has_initial_state[i]: + if is_ring: + slot_idx = read_conv_offsets[i] + init_state = ( + conv_states[cache_indices[i]] + .index_select(-1, slot_idx) + .unsqueeze(0) + ) + else: + init_state = conv_states[cache_indices[i]][..., : (width - 1)] + else: + init_state = None + out_ref_b = causal_conv1d_ref( x_s, weight, bias, activation=activation, - return_final_states=True, - final_states_out=conv_states[cache_indices[i]][ - ..., : (width - 1) - ].unsqueeze(0), - initial_states=( - conv_states[cache_indices[i]][..., : (width - 1)] - if has_initial_state[i] - else None + return_final_states=not is_ring, + final_states_out=( + None + if is_ring + else conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0) ), + initial_states=init_state, ) out_chunks.append(out_ref_b[0]) out = torch.cat(out_chunks, dim=-1) + + if is_ring: + # Vectorised ring write: place each sequence's trailing `width` tokens + # at their circular slots in conv_states. Replaces the post-hoc overwrite + # the caller used to do, and removes the redundant linear-slot copy_(). 
+ K = write_conv_offsets.size(1) + tok_offsets = torch.arange( + -K, 0, device=query_start_loc.device, dtype=query_start_loc.dtype + ) + token_idx = ( + (query_start_loc[1:, None] + tok_offsets[None]).clamp_min(0).to(torch.int64) + ) + x_gather = x.index_select(1, token_idx.reshape(-1)).reshape( + x.size(0), token_idx.size(0), token_idx.size(1) + ) + cache_idx = cache_indices.to(torch.int64) + dim_idx = torch.arange(conv_states.size(1), device=conv_states.device) + conv_states[ + cache_idx[:, None, None], + dim_idx[None, :, None], + write_conv_offsets[:, None, :], + ] = x_gather.permute(1, 0, 2).to(conv_states.dtype) + return out @@ -131,6 +179,7 @@ def _causal_conv1d_update_kernel_npu_tiled( query_start_loc_ptr, block_idx_last_scheduled_token, initial_state_idx, + cache_seqlens_ptr, o_ptr, batch: tl.int32, dim: tl.constexpr, @@ -158,6 +207,7 @@ def _causal_conv1d_update_kernel_npu_tiled( IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, + IS_CIRCULAR_BUFFER: tl.constexpr, BLOCK_N: tl.constexpr, B_TILE: tl.constexpr, T_CHUNK: tl.constexpr, @@ -249,17 +299,26 @@ def _causal_conv1d_update_kernel_npu_tiled( lane_active = lane_active & (seqlen_run > 0) - if IS_SPEC_DECODING: - conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to( - tl.int64 - ) - - 1 - ) - shift = tl.full((), 1, tl.int32) - else: + if IS_CIRCULAR_BUFFER: + cb_raw = tl.load(cache_seqlens_ptr + b, mask=lane_active, other=0).to(tl.int32) + cb_pos = cb_raw % state_len # write-start: wrapped current position + cb_read_start = (cb_pos - (KERNEL_WIDTH - 1) + state_len) % state_len conv_state_token_offset = tl.full((), 0, tl.int64) - shift = seqlen_run + shift = tl.full((), 0, tl.int32) + else: + cb_pos = tl.full((), 0, tl.int32) + cb_read_start = tl.full((), 0, tl.int32) + if IS_SPEC_DECODING: + conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to( + tl.int64 + ) + - 1 + ) + shift = 
tl.full((), 1, tl.int32) + else: + conv_state_token_offset = tl.full((), 0, tl.int64) + shift = seqlen_run conv_states_base = ( conv_state_ptr @@ -275,42 +334,77 @@ def _causal_conv1d_update_kernel_npu_tiled( col2 = tl.zeros((BLOCK_N,), dtype=tl.float16) col3 = tl.zeros((BLOCK_N,), dtype=tl.float16) col4 = tl.zeros((BLOCK_N,), dtype=tl.float16) - if KERNEL_WIDTH >= 2: - col0 = tl.load( - prior_tokens + 0 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 3: - col1 = tl.load( - prior_tokens + 1 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 4: - col2 = tl.load( - prior_tokens + 2 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 5: - col3 = tl.load( - prior_tokens + 3 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 6: - col4 = tl.load( - prior_tokens + 4 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - - conv_states_offset = tl.load( - conv_state_indices_ptr + b * stride_state_indices + current_last_index, - mask=lane_active, - other=0, - ).to(tl.int64) + if IS_CIRCULAR_BUFFER: + if KERNEL_WIDTH >= 2: + col0 = tl.load( + conv_states_base + cb_read_start * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load( + conv_states_base + ((cb_read_start + 1) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load( + conv_states_base + ((cb_read_start + 2) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load( + conv_states_base + ((cb_read_start + 3) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 6: 
+ col4 = tl.load( + conv_states_base + ((cb_read_start + 4) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + else: + if KERNEL_WIDTH >= 2: + col0 = tl.load( + prior_tokens + 0 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load( + prior_tokens + 1 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load( + prior_tokens + 2 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load( + prior_tokens + 3 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 6: + col4 = tl.load( + prior_tokens + 4 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + + if not IS_CIRCULAR_BUFFER: + conv_states_offset = tl.load( + conv_state_indices_ptr + b * stride_state_indices + current_last_index, + mask=lane_active, + other=0, + ).to(tl.int64) + else: + conv_states_offset = tl.full((), 0, tl.int64) use_shift = seqlen_run < state_len_run use_tail = seqlen_run >= state_len_run @@ -335,61 +429,79 @@ def _causal_conv1d_update_kernel_npu_tiled( ) x_base = x_ptr + x_offset + idx_feats * stride_x_dim - for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): - dst_tok = (t0 + tok_vec).to(tl.int32) - src_tok = (dst_tok + shift).to(tl.int32) - m_tok = ( - use_shift - & (dst_tok < keep_shift) - & (src_tok < state_len_run) - & (dst_tok < state_len_run) - ) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_input_coord < num_cache_lines) - & (conv_states_offset < num_cache_lines) - ) - src_ptrs = ( - state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok - ) - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - vals = tl.load(src_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, vals, 
mask=m) - - for t0 in tl.static_range(0, seqlen, T_CHUNK): - x_tok = (t0 + tok_vec).to(tl.int32) - dst_tok = (keep_shift + x_tok).to(tl.int32) - m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_offset < num_cache_lines) - ) - x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - x_vals = tl.load(x_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, x_vals, mask=m) - - for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): - dst_tok = (t0 + tok_vec).to(tl.int32) - x_tok = (tail_start + dst_tok).to(tl.int32) - m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_offset < num_cache_lines) - ) - x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - x_vals = tl.load(x_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, x_vals, mask=m) + if IS_CIRCULAR_BUFFER: + # Circular write: write each input token to its circular position. + # No shift needed — positions are derived from cache_seqlens. 
+ for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) + write_tok = (cb_pos + x_tok) % state_len + m = ( + (lane_active & (x_tok < seqlen_run))[:, None] + & mask_w[None, :] + & (conv_states_input_coord < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + conv_states_base[None, :] + write_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + else: + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) + src_tok = (dst_tok + shift).to(tl.int32) + m_tok = ( + use_shift + & (dst_tok < keep_shift) + & (src_tok < state_len_run) + & (dst_tok < state_len_run) + ) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_input_coord < num_cache_lines) + & (conv_states_offset < num_cache_lines) + ) + src_ptrs = ( + state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok + ) + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + vals = tl.load(src_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, vals, mask=m) + + for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) + dst_tok = (keep_shift + x_tok).to(tl.int32) + m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_offset < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) + x_tok = (tail_start + dst_tok).to(tl.int32) + m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & 
(conv_states_offset < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) x_base_1d = x_base o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim @@ -513,6 +625,7 @@ def causal_conv1d_update_npu( pad_slot_id: int = PAD_SLOT_ID, block_idx_last_scheduled_token: Optional[torch.Tensor] = None, initial_state_idx: Optional[torch.Tensor] = None, + cache_seqlens: Optional[torch.Tensor] = None, validate_data=False, ): if validate_data: @@ -557,7 +670,10 @@ def causal_conv1d_update_npu( conv_state_indices.stride(0) if conv_state_indices is not None else 0 ) - if num_accepted_tokens is not None: + if cache_seqlens is not None: + # Circular buffer: use the full allocated state size as the modulus. + eff_state_len = state_len_total + elif num_accepted_tokens is not None: eff_state_len = width - 1 + (seqlen - 1) else: eff_state_len = width - 1 @@ -591,6 +707,7 @@ def grid(META): query_start_loc, block_idx_last_scheduled_token, initial_state_idx, + cache_seqlens, out, batch, dim, @@ -618,6 +735,7 @@ def grid(META): IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, + IS_CIRCULAR_BUFFER=cache_seqlens is not None, BLOCK_N=block_n, B_TILE=b_tile, T_CHUNK=t_chunk, From 0345241970c44dd0ef6bb3875b63ab40bbf92ad2 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Mon, 18 May 2026 08:07:03 +0000 Subject: [PATCH 10/12] Refactor GDN and conv1d computation flow --- .../framework/lmdeploy_ext/device/__init__.py | 167 +++++------------- 1 file changed, 49 insertions(+), 118 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 334e264f..c746ada6 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ 
b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -546,65 +546,16 @@ def __init__( state_ids: torch.Tensor, attn_metadata: Any, ): - self.is_multi_token_decoding = getattr(attn_metadata, 'is_multi_token_decoding', False) - # Keep decode semantics for linear-attention state updates even when - # full attention uses a prefill-style TND verify path. - self.is_decoding = attn_metadata.is_decoding or self.is_multi_token_decoding + self.is_decoding = attn_metadata.is_decoding self.cu_seqlens = attn_metadata.q_start_loc - self.num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens - - query_lens = None - num_seqs = 1 - if self.cu_seqlens is not None: - query_lens = torch.diff(self.cu_seqlens).to(torch.int32) - num_seqs = max(int(self.cu_seqlens.numel()) - 1, 1) - self.max_query_len = max(num_tokens // num_seqs, 1) - self.cache_seqlens = None - self.spec_state_offsets = None - self.spec_conv_offsets = None - self.conv_kernel_size = conv_kernel_size - kv_seqlens_device = getattr(attn_metadata, 'kv_seqlens_device', None) - if query_lens is not None and kv_seqlens_device is not None: - kv_seqlens = kv_seqlens_device.to(dtype=torch.int32) - self.cache_seqlens = (kv_seqlens - query_lens).contiguous() - if self.num_spec_tokens > 0 and not self.is_decoding: - state_slots = 1 + self.num_spec_tokens - self.spec_state_offsets = ( - torch.remainder(self.cache_seqlens, state_slots), - torch.remainder(kv_seqlens, state_slots), - ) - # Conv ring buffer: state_len = linear_conv_kernel_dim + num_spec_tokens. - # `conv_kernel_size` here is the conv width (linear_conv_kernel_dim). - state_len = conv_kernel_size + self.num_spec_tokens - range_idx = torch.arange( - -conv_kernel_size, - 0, - device=self.cache_seqlens.device, - dtype=torch.int32, - ) - # Read the (conv_kernel_size - 1) tokens preceding the current write - # window from the circular buffer. 
- read_conv_offsets = torch.remainder( - self.cache_seqlens[:, None] + range_idx[1:][None], - state_len, - ).to(torch.int64) - # Write the last conv_kernel_size tokens of this prefill batch into - # circular-buffer slots so the next decode read aligns naturally. - write_conv_offsets = torch.remainder( - kv_seqlens[:, None] + range_idx[None], - state_len, - ).to(torch.int64) - self.spec_conv_offsets = (read_conv_offsets, write_conv_offsets) - self.num_accepted_tokens = getattr(attn_metadata, 'num_accepted_tokens', None) - if self.num_accepted_tokens is None and self.is_multi_token_decoding and query_lens is not None: - self.num_accepted_tokens = torch.ones(query_lens.size(0), dtype=torch.int32, device=self.cu_seqlens.device) - elif self.num_accepted_tokens is not None: - self.num_accepted_tokens = self.num_accepted_tokens.to( - device=self.cu_seqlens.device if self.cu_seqlens is not None else state_ids.device, - dtype=torch.int32, - ).contiguous() - - # state_ids, fill invalid state with 0 + self.is_multi_token_decoding = attn_metadata.is_multi_token_decoding + self.max_q_seq_len = attn_metadata.max_q_seq_len + + self.num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + self.cache_seqlens = getattr(attn_metadata, 'cache_seqlens', None) + self.spec_state_offsets = getattr(attn_metadata, 'spec_state_offsets', None) + self.spec_conv_offsets = getattr(attn_metadata, 'spec_conv_offsets', None) + self.state_ids = state_ids.clamp(0) self.has_initial_state = attn_metadata.has_initial_state self.conv_state_indices = self.state_ids.to(torch.int32) @@ -639,7 +590,7 @@ def conv1d_func( read_conv_offsets, write_conv_offsets = spec_conv_offsets else: read_conv_offsets, write_conv_offsets = None, None - + out = self.causal_conv1d_fn( x.t(), weight, @@ -668,24 +619,21 @@ def conv1d_update( gated_delta_meta: GatedDeltaMeta, ): update_kwargs = {} - validate_data = True - if getattr(gated_delta_meta, 'cache_seqlens', None) is not None and gated_delta_meta.is_decoding: 
- # Ring-buffer decode path: positions are derived from cache_seqlens. - update_kwargs['cache_seqlens'] = gated_delta_meta.cache_seqlens - if getattr(gated_delta_meta, 'is_multi_token_decoding', False): - # Multi-token decode uses varlen format (2-D x tensor); must keep - # IS_VARLEN=True by passing query_start_loc, otherwise x gets incorrectly - # unsqueezed and cache_seqlens is accessed out-of-bounds. - update_kwargs['query_start_loc'] = gated_delta_meta.cu_seqlens - update_kwargs['max_query_len'] = gated_delta_meta.max_query_len - validate_data = False - elif getattr(gated_delta_meta, 'is_multi_token_decoding', False): - update_kwargs.update( - num_accepted_tokens=gated_delta_meta.num_accepted_tokens, - query_start_loc=gated_delta_meta.cu_seqlens, - max_query_len=gated_delta_meta.max_query_len, - ) - validate_data = False + validate_data = False + + cache_seqlens = gated_delta_meta.cache_seqlens + is_multi_token_decoding = gated_delta_meta.is_multi_token_decoding + + # Ring-buffer decode path: positions are derived from cache_seqlens. + update_kwargs['cache_seqlens'] = gated_delta_meta.cache_seqlens + + if is_multi_token_decoding: + # Multi-token decode uses varlen format (2-D x tensor); must keep + # IS_VARLEN=True by passing query_start_loc, otherwise x gets incorrectly + # unsqueezed and cache_seqlens is accessed out-of-bounds. 
+ update_kwargs['query_start_loc'] = gated_delta_meta.cu_seqlens + update_kwargs['max_query_len'] = gated_delta_meta.max_q_seq_len + out = self.causal_conv1d_update( x, conv_state, @@ -710,7 +658,7 @@ def __call__( weight_reshaped = weight.squeeze(1) x = x.squeeze(0) - if gated_delta_meta.is_decoding: + if gated_delta_meta.is_decoding or gated_delta_meta.is_multi_token_decoding: conv_state_indices = gated_delta_meta.conv_state_indices return self.conv1d_update( x, weight_reshaped, bias, conv_state, conv_state_indices, gated_delta_meta @@ -763,15 +711,30 @@ def __call__( """call.""" is_decoding = gated_delta_meta.is_decoding - is_multi_token_decode = getattr(gated_delta_meta, 'is_multi_token_decoding', False) + is_multi_token_decoding = gated_delta_meta.is_multi_token_decoding + beta = b.sigmoid() # If the model is loaded in fp16, without the .float() here, A might be -inf g = (-A_log.float().exp()) * F.softplus(a.float() + dt_bias) - + if is_decoding: - indices = gated_delta_meta.state_ids - cu_seqlens = gated_delta_meta.cu_seqlens - if is_multi_token_decode: + core_attn_out = self.fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + dt_bias=dt_bias, + q=query, + k=key, + v=value.contiguous(), + a=a.contiguous(), + b=b.contiguous(), + initial_state_source=recurrent_state, + initial_state_indices=gated_delta_meta.state_ids, + cu_seqlens=gated_delta_meta.cu_seqlens, + use_qk_l2norm_in_kernel=True, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + return core_attn_out, None + elif is_multi_token_decoding: state_slots = recurrent_state.size(1) flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) core_attn_out, _ = self.fused_recurrent_gated_delta_rule( @@ -782,44 +745,13 @@ def __call__( beta=beta.contiguous(), initial_state=flat_recurrent_state, inplace_final_state=True, - cu_seqlens=cu_seqlens, + cu_seqlens=gated_delta_meta.cu_seqlens, cache_seqlens_rb=gated_delta_meta.cache_seqlens, - state_ids_rb=indices, + 
state_ids_rb=gated_delta_meta.state_ids, num_state=state_slots, use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) return core_attn_out, None - - # Single-token decode: use the optimized update kernel - initial_state_source = recurrent_state - initial_state_indices = indices - if recurrent_state.dim() == 5: - state_slots = recurrent_state.size(1) - flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) - slot_offsets = torch.remainder( - gated_delta_meta.cache_seqlens.to(torch.int64), - state_slots, - ) - initial_state_source = flat_recurrent_state - initial_state_indices = ( - indices.to(torch.int64) * state_slots + slot_offsets - ).contiguous() - core_attn_out = self.fused_sigmoid_gating_delta_rule_update( - A_log=A_log, - dt_bias=dt_bias, - q=query, - k=key, - v=value.contiguous(), - a=a.contiguous(), - b=b.contiguous(), - initial_state_source=initial_state_source, - initial_state_indices=initial_state_indices, - cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, - softplus_beta=1.0, - softplus_threshold=20.0, - ) - last_recurrent_state = None else: if gated_delta_meta.spec_state_offsets is not None: state_ids = gated_delta_meta.state_ids @@ -852,8 +784,7 @@ def __call__( recurrent_state[gated_delta_meta.state_ids] = last_recurrent_state.to( recurrent_state.dtype ) - - return core_attn_out, last_recurrent_state + return core_attn_out, last_recurrent_state gated_delta.GatedDeltaMeta = AscendGatedDeltaMeta gated_delta.CausalConv1dFunc = AscendCausalConv1dFunc From 9946730c5f1472c5ca0cb0071f9474df176ef86e Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Mon, 18 May 2026 11:52:43 +0000 Subject: [PATCH 11/12] refactor gdn buffer --- .../cudagraph/ascend_cudagraph.py | 129 +++--------------- dlinfer/vendor/ascend/torch_npu_ops.py | 4 - 2 files changed, 22 insertions(+), 111 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py 
index 8738bf3d..6c5ba95e 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -83,11 +83,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( input_buffers["q_seqlens"] = torch.ones(max_batches, dtype=torch.int32) - # actual_seq_lengths_kv is also tracked per sequence in the TND path. input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) - input_buffers["kv_seqlens_device"] = torch.ones( - max_batches, dtype=torch.int32, device=device - ) input_buffers["q_start_loc"] = torch.arange( max_batches + 1, dtype=torch.int32, device=device @@ -102,29 +98,16 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (max_tokens), dtype=torch.bool, device=device ) - # actual_seq_lengths_q for multi-token decode (CPU tensor, cumulative) - input_buffers["actual_seq_lengths_q"] = torch.zeros( - max_batches, dtype=torch.int32 - ) - - # attention mask buffer for multi-token decode (kept on-device so graph - # replay can see in-place updates via the same data pointer captured at - # graph-capture time). Shape matches the fixed 2048×2048 mask used by - # npu_fused_infer_attention_score. Initialised to all-False (no masking) - # so a stale buffer value is permissive rather than destructive. 
- _ATTN_MASK_WIDTH = 2048 - input_buffers["attention_mask_buf"] = torch.zeros( - _ATTN_MASK_WIDTH, _ATTN_MASK_WIDTH, dtype=torch.bool, device=device - ) + input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1) # ssm if graph_meta.is_ssm: input_buffers["state_ids"] = torch.full( (max_batches,), -1, dtype=torch.int64, device=device ) - input_buffers["num_accepted_tokens"] = torch.ones( + input_buffers["cache_seqlens"] = torch.zeros( max_batches, dtype=torch.int32, device=device - ) + ) # mrope if graph_meta.use_mrope: @@ -152,6 +135,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( moe_metadata = get_step_ctx_manager().current_context().moe_metadata x_active_mask: Tensor = moe_metadata.x_active_mask q_start_loc: Tensor = attn_metadata.q_start_loc + cache_seqlens: Tensor = attn_metadata.cache_seqlens input_buffers: BuffType = graph_meta.input_buffers @@ -171,91 +155,32 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["position_ids"][:, :num_tokens] = position_ids input_buffers["block_offsets"].zero_() input_buffers["block_offsets"][:num_seqs, :num_blocks] = block_offsets + input_buffers["q_seqlens"].fill_(0) + input_buffers["q_seqlens"][: num_seqs] = q_seqlens input_buffers["kv_seqlens"].fill_(0) input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens - input_buffers["kv_seqlens_device"].fill_(0) - input_buffers["kv_seqlens_device"][:num_seqs].copy_( - kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) - ) input_buffers["kv_start_indices"].fill_(-1) input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices if x_active_mask is not None: input_buffers["x_active_mask"].fill_(0) input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask - actual_seq_lengths_q = getattr(attn_metadata, 'actual_seq_lengths_q', None) - if actual_seq_lengths_q is not None: - # multi-token decode: fill actual_seq_lengths_q - bs = 
input_buffers["actual_seq_lengths_q"].size(0) - pad_query_len = 5 - padding_tensor = torch.arange(1, bs + 1) * pad_query_len - input_buffers["actual_seq_lengths_q"].copy_(padding_tensor) - input_buffers["actual_seq_lengths_q"][:actual_seq_lengths_q.size(0)] = actual_seq_lengths_q - - # actual_seq_lengths_q通过q_seqlens传入进去 - input_buffers["q_seqlens"].copy_(input_buffers["actual_seq_lengths_q"]) - else: - # single-token decode: fill q_seqlens - bs = input_buffers["q_seqlens"].size(0) - padding_tensor = torch.arange(1, bs + 1) - input_buffers["q_seqlens"].copy_(padding_tensor) - input_buffers["q_seqlens"][: q_seqlens.size(0)] = q_seqlens - - - if actual_seq_lengths_q is not None: - # multi-token decode - input_buffers["kv_seqlens"].fill_(5) - input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens - input_buffers["kv_seqlens_device"].fill_(5) - input_buffers["kv_seqlens_device"][:num_seqs].copy_( - kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) - ) - else: - # single-token decode - input_buffers["kv_seqlens"].fill_(1) - input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens - input_buffers["kv_seqlens_device"].fill_(1) - input_buffers["kv_seqlens_device"][:num_seqs].copy_( - kv_seqlens.to(device=input_buffers["kv_seqlens_device"].device) - ) - - if actual_seq_lengths_q is not None: - fresh_mask = getattr(attn_metadata, 'attention_mask', None) - attn_mask_buf = input_buffers.get("attention_mask_buf") - if fresh_mask and attn_mask_buf is not None: - attn_mask_buf.copy_(fresh_mask[0]) - attn_metadata.attention_mask = [attn_mask_buf] - - # ssm if graph_meta.is_ssm: - # main model verify - if actual_seq_lengths_q is not None and q_start_loc.numel() > 1: - # multi-token decode - bs = input_buffers["q_start_loc"].size(0) - pad_query_len = 5 - padding_tensor = torch.arange(0, bs) * pad_query_len - input_buffers["q_start_loc"].copy_(padding_tensor) - input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc - else: - # single-token decode - 
input_buffers["q_start_loc"][q_start_loc.size(0):] = q_start_loc[-1] + bs = input_buffers["q_start_loc"].size(0) + max_q_seq_len = attn_metadata.max_q_seq_len + padding_tensor = torch.arange(0, bs) * max_q_seq_len + input_buffers["q_start_loc"].copy_(padding_tensor) + input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc + state_ids = kwargs["state_ids"] input_buffers["state_ids"].fill_(0) input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) - - num_accepted_tokens = getattr(attn_metadata, "num_accepted_tokens", None) - input_buffers["num_accepted_tokens"].fill_(1) - if num_accepted_tokens is not None: - input_buffers["num_accepted_tokens"][: num_accepted_tokens.size(0)].copy_( - num_accepted_tokens.to( - device=input_buffers["num_accepted_tokens"].device, - dtype=input_buffers["num_accepted_tokens"].dtype, - ) - ) - attn_metadata.num_accepted_tokens = input_buffers["num_accepted_tokens"] - # Keep linear-attention state math on the fixed graph buffer so its - # per-sequence cache lengths stay aligned with padded q_start_loc. 
- attn_metadata.kv_seqlens_device = input_buffers["kv_seqlens_device"] + + input_buffers["cache_seqlens"].fill_(0) + input_buffers["cache_seqlens"][: num_seqs].copy_(cache_seqlens) + + attn_metadata.cache_seqlens = input_buffers["cache_seqlens"] + attn_metadata.attention_mask = [input_buffers["attention_mask"]] if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) @@ -286,7 +211,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( new_inputs.update(kwargs) - # ssm: override kwargs' variable-length state_ids with the fixed-size buffer if graph_meta.is_ssm: new_inputs["state_ids"] = input_buffers["state_ids"] @@ -463,21 +387,12 @@ def forward(self, **kwargs): context = self.ctx_mgr.current_context() self.model.update_context_cudagraph(self.meta, context) if aclgraph_use_torch_npu_update(): - update_dict = { - "actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"], - } - # multi-token decode also needs actual_seq_lengths updated - actual_seq_lengths_q = self.meta.input_buffers.get("actual_seq_lengths_q") - if actual_seq_lengths_q is not None and actual_seq_lengths_q.any(): - update_dict["actual_seq_lengths"] = actual_seq_lengths_q - # Fix: update CPU inputs BEFORE replay so the current step uses the - # fresh seq-length values; the original order replayed first and - # only updated for the next replay, leaving the first iteration - # with stale captured values. 
+ self._graph.replay() self._graph.update( - cpu_update_input=[update_dict] + cpu_update_input=[ + {"actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"]} + ] ) - self._graph.replay() else: update_attn_params(self.update_stream, self.meta, self.max_batches) self._graph.replay() diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 2b9c6814..62734102 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -456,10 +456,6 @@ def paged_prefill_attention( scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1]) query = query.contiguous() - if q_seq_len.dim() != 1 or kv_seq_len.dim() != 1: - raise ValueError("TND paged prefill expects 1D actual_seq_lengths tensors.") - if block_table.size(0) != q_seq_len.numel() or kv_seq_len.numel() != q_seq_len.numel(): - raise ValueError("TND paged prefill expects per-sequence block_table and kv_seq_len.") block_num = key_cache.size(0) key_cache = key_cache.view(block_num, block_size, -1) value_cache = value_cache.view(block_num, block_size, -1) From 1fb586bb3690c3d8d27f02feb136053a5b2f7505 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 19 May 2026 03:33:54 +0000 Subject: [PATCH 12/12] fix ascend graph --- .../cudagraph/ascend_cudagraph.py | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 6c5ba95e..69443b0a 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -74,9 +74,6 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (1, max_tokens), dtype=torch.int32, device=device ) - # TND paged attention consumes block tables per sequence. 
Keep the graph - # buffer batch-shaped even when one decode step contains multiple query - # tokens per sequence. input_buffers["block_offsets"] = torch.zeros( (max_batches, num_blocks), dtype=torch.int32, device=device ) @@ -93,7 +90,6 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (max_tokens), dtype=torch.int32, device=device ) - # MoE routing still reasons in token space for multi-token verify. input_buffers["x_active_mask"] = torch.zeros( (max_tokens), dtype=torch.bool, device=device ) @@ -139,7 +135,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers: BuffType = graph_meta.input_buffers - num_seqs, num_blocks = block_offsets.size() + batch_size, num_blocks = block_offsets.size() num_tokens = input_ids.size(-1) q_seqlens: Tensor = attn_metadata.q_seqlens @@ -154,11 +150,11 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["position_ids"].zero_() input_buffers["position_ids"][:, :num_tokens] = position_ids input_buffers["block_offsets"].zero_() - input_buffers["block_offsets"][:num_seqs, :num_blocks] = block_offsets + input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets input_buffers["q_seqlens"].fill_(0) - input_buffers["q_seqlens"][: num_seqs] = q_seqlens + input_buffers["q_seqlens"][: batch_size] = q_seqlens input_buffers["kv_seqlens"].fill_(0) - input_buffers["kv_seqlens"][:num_seqs] = kv_seqlens + input_buffers["kv_seqlens"][:batch_size] = kv_seqlens input_buffers["kv_start_indices"].fill_(-1) input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices if x_active_mask is not None: @@ -174,10 +170,10 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( state_ids = kwargs["state_ids"] input_buffers["state_ids"].fill_(0) - input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) + input_buffers["state_ids"][: batch_size].copy_(state_ids) input_buffers["cache_seqlens"].fill_(0) - input_buffers["cache_seqlens"][: num_seqs].copy_(cache_seqlens) + input_buffers["cache_seqlens"][: 
batch_size].copy_(cache_seqlens) attn_metadata.cache_seqlens = input_buffers["cache_seqlens"] attn_metadata.attention_mask = [input_buffers["attention_mask"]] @@ -467,32 +463,22 @@ def get_graph_key( **kwargs, ): """Get graph key.""" - context = self.ctx_mgr.current_context() - is_decoding = context.is_decoding + is_decoding = attn_metadata.is_decoding + is_multi_token_decoding = attn_metadata.is_multi_token_decoding meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch - if is_decoding: - batch_size = None - q_seqlens = None - if attn_metadata is not None: - q_seqlens = getattr(attn_metadata, "q_seqlens", None) - if q_seqlens is None: - q_seqlens = getattr(context, "q_seqlens", None) - if q_seqlens is not None: - batch_size = q_seqlens.size(0) - elif kwargs.get("state_ids", None) is not None: - batch_size = kwargs["state_ids"].size(0) - - if batch_size is not None and batch_size > 0 and input_ids.size(-1) % batch_size == 0: - query_len = input_ids.size(-1) // batch_size - if meta.padding_batch_size is None: - new_batch_size = self._get_capture_tokens(batch_size) - else: - padding_num_tokens = meta.padding_batch_size - padding_batch_size = (padding_num_tokens + query_len - 1) // query_len - new_batch_size = self._get_capture_tokens(padding_batch_size) - return (new_batch_size, is_decoding, enable_microbatch, query_len) + if is_multi_token_decoding: + q_seqlens = attn_metadata.q_seqlens + max_q_seq_len = attn_metadata.max_q_seq_len + batch_size = q_seqlens.size(0) + if meta.padding_batch_size is None: + new_batch_size = self._get_capture_tokens(batch_size) + else: + padding_num_tokens = meta.padding_batch_size + padding_batch_size = (padding_num_tokens + max_q_seq_len - 1) // max_q_seq_len + new_batch_size = self._get_capture_tokens(padding_batch_size) + return (new_batch_size, is_multi_token_decoding, enable_microbatch, max_q_seq_len) num_tokens = input_ids.numel() if meta.padding_batch_size is None: @@ -512,10 
+498,10 @@ def __call__(self, **kwargs): graph_key = self.get_graph_key(**kwargs) max_batches = graph_key[0] - is_decoding = graph_key[1] - decode_query_len = graph_key[3] - if is_decoding: - max_tokens = max_batches * decode_query_len + is_decoding_or_multi_token_decoding = graph_key[1] + max_q_seq_len = graph_key[3] + if is_decoding_or_multi_token_decoding: + max_tokens = max_batches * max_q_seq_len else: max_tokens = max_batches max_batches = self.max_batches @@ -599,10 +585,12 @@ class GraphParams: _graph_params: Optional[GraphParams] = None +_graph_capture_sizes: set[int] = None def set_graph_params(aclgraph_capture_sizes: set[int]): global _graph_params + global _graph_capture_sizes if _graph_params is not None: raise ValueError("Graph parameters have already been set!") _graph_params = GraphParams( @@ -612,6 +600,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]): attn_params={size: [] for size in aclgraph_capture_sizes}, is_mla=False, ) + _graph_capture_sizes = aclgraph_capture_sizes def get_graph_params(): @@ -621,6 +610,7 @@ def get_graph_params(): def clear_graph_params(): """Clear global graph params and release references to KV cache tensors.""" global _graph_params + global _graph_capture_sizes if _graph_params is None: return @@ -636,6 +626,10 @@ def clear_graph_params(): _graph_params.workspaces.clear() finally: _graph_params = None + _graph_capture_sizes = None + # Clear the lru_cache so that the next inference re-runs _get_capture_batch_size_impl + # and calls set_graph_params again to rebuild from a clean state + _get_capture_batch_size_impl.cache_clear() def update_attn_params(update_stream, forward_meta, runtime_size):