diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 314e4b0c..69443b0a 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update(): # AscendCudaGraphMixin methods for cudagraph buffer management. +def AscendCudaGraphMixin_support_cuda_graph( + self, + input_ids: Tensor, + position_ids: Tensor, + past_key_values: List[List[Tensor]], + attn_metadata: Any = None, + inputs_embeds: Tensor = None, + **kwargs, +): + """Allow multi-token decode graph only when runtime length updates exist.""" + if attn_metadata is None: + return False + + is_decoding = getattr(attn_metadata, "is_decoding", False) + is_multi_token = getattr(attn_metadata, "is_multi_token_decoding", False) + if is_multi_token and not aclgraph_use_torch_npu_update(): + return False + return is_decoding or is_multi_token + + def AscendCudaGraphMixin_make_buffers_cudagraph( self, graph_meta: CudaGraphMeta, *args, **kwargs ) -> BuffType: @@ -58,9 +78,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( (max_batches, num_blocks), dtype=torch.int32, device=device ) - input_buffers["q_seqlens"] = torch.ones( - max_batches, dtype=torch.int32, device=device - ) + input_buffers["q_seqlens"] = torch.ones(max_batches, dtype=torch.int32) input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) @@ -69,18 +87,23 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( ) input_buffers["kv_start_indices"] = -torch.ones( - (max_batches), dtype=torch.int32, device=device + (max_tokens), dtype=torch.int32, device=device ) input_buffers["x_active_mask"] = torch.zeros( - (max_batches), dtype=torch.bool, device=device + (max_tokens), dtype=torch.bool, device=device ) + input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1) + # ssm if graph_meta.is_ssm: input_buffers["state_ids"] = torch.full( (max_batches,), -1, dtype=torch.int64, device=device ) + input_buffers["cache_seqlens"] = torch.zeros( + max_batches, dtype=torch.int32, device=device + ) # mrope if graph_meta.use_mrope: @@ -108,11 +131,13 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( moe_metadata = get_step_ctx_manager().current_context().moe_metadata x_active_mask: Tensor = moe_metadata.x_active_mask q_start_loc: Tensor = attn_metadata.q_start_loc + cache_seqlens: Tensor = attn_metadata.cache_seqlens input_buffers: BuffType = graph_meta.input_buffers batch_size, num_blocks = block_offsets.size() num_tokens = input_ids.size(-1) + q_seqlens: Tensor = attn_metadata.q_seqlens # fill buffer max_num_tokens = input_buffers["input_ids"].size(-1) @@ -126,22 +151,32 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( input_buffers["position_ids"][:, :num_tokens] = position_ids input_buffers["block_offsets"].zero_() input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets + input_buffers["q_seqlens"].fill_(0) + input_buffers["q_seqlens"][: batch_size] = q_seqlens input_buffers["kv_seqlens"].fill_(0) input_buffers["kv_seqlens"][:batch_size] = kv_seqlens input_buffers["kv_start_indices"].fill_(-1) - input_buffers["kv_start_indices"][:batch_size] = kv_start_indices + input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices if x_active_mask is not None: input_buffers["x_active_mask"].fill_(0) - input_buffers["x_active_mask"][:batch_size] = x_active_mask + input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask - # ssm if graph_meta.is_ssm: - input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc - input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1] + bs = input_buffers["q_start_loc"].size(0) + max_q_seq_len = attn_metadata.max_q_seq_len + padding_tensor = torch.arange(0, bs) * max_q_seq_len + input_buffers["q_start_loc"].copy_(padding_tensor) + input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc state_ids = kwargs["state_ids"] - input_buffers["state_ids"].fill_(-1) - input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) + input_buffers["state_ids"].fill_(0) + input_buffers["state_ids"][: batch_size].copy_(state_ids) + + input_buffers["cache_seqlens"].fill_(0) + input_buffers["cache_seqlens"][: batch_size].copy_(cache_seqlens) + + attn_metadata.cache_seqlens = input_buffers["cache_seqlens"] + attn_metadata.attention_mask = [input_buffers["attention_mask"]] if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) @@ -151,10 +186,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( 1, max_num_tokens, emb_size ) input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds - # create inputs - # Use compatible size but cap at graph's max_batchs to avoid buffer overflow - new_batch_size = min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs) - + attn_metadata.q_seqlens = input_buffers["q_seqlens"] attn_metadata.block_offsets = input_buffers["block_offsets"] attn_metadata.kv_seqlens = input_buffers["kv_seqlens"] attn_metadata.kv_start_indices = input_buffers["kv_start_indices"] @@ -175,7 +207,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( new_inputs.update(kwargs) - # ssm: override kwargs' variable-length state_ids with the fixed-size buffer if graph_meta.is_ssm: new_inputs["state_ids"] = input_buffers["state_ids"] @@ -209,6 +240,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context): context.mrope_position_ids = input_buffers["mrope_position_ids"] +CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph @@ -358,7 +390,7 @@ def forward(self, **kwargs): ] ) else: - update_attn_params(self.update_stream, self.meta, self.max_tokens) + update_attn_params(self.update_stream, self.meta, self.max_batches) self._graph.replay() output_buffers = self.meta.output_buffers output = self.model.get_outputs_cudagraph(output_buffers, **kwargs) @@ -427,19 +459,33 @@ def _get_capture_tokens(self, batch_size: int): def get_graph_key( self, input_ids: torch.Tensor, + attn_metadata: Any, **kwargs, ): """Get graph key.""" - context = self.ctx_mgr.current_context() - is_decoding = context.is_decoding - num_tokens = input_ids.numel() + is_decoding = attn_metadata.is_decoding + is_multi_token_decoding = attn_metadata.is_multi_token_decoding meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch + + if is_multi_token_decoding: + q_seqlens = attn_metadata.q_seqlens + max_q_seq_len = attn_metadata.max_q_seq_len + batch_size = q_seqlens.size(0) + if meta.padding_batch_size is None: + new_batch_size = self._get_capture_tokens(batch_size) + else: + padding_num_tokens = meta.padding_batch_size + padding_batch_size = (padding_num_tokens + max_q_seq_len - 1) // max_q_seq_len + new_batch_size = self._get_capture_tokens(padding_batch_size) + return (new_batch_size, is_multi_token_decoding, enable_microbatch, max_q_seq_len) + + num_tokens = input_ids.numel() if meta.padding_batch_size is None: new_num_tokens = self._get_capture_tokens(num_tokens) else: new_num_tokens = self._get_capture_tokens(meta.padding_batch_size) - return (new_num_tokens, is_decoding, enable_microbatch) + return (new_num_tokens, is_decoding, enable_microbatch, 1) def __call__(self, **kwargs): """call.""" @@ -451,10 +497,15 @@ def __call__(self, **kwargs): return self.model.make_output_buffers(ret) graph_key = self.get_graph_key(**kwargs) - max_tokens = graph_key[0] - is_decoding = graph_key[1] + max_batches = graph_key[0] + is_decoding_or_multi_token_decoding = graph_key[1] + max_q_seq_len = graph_key[3] + if is_decoding_or_multi_token_decoding: + max_tokens = max_batches * max_q_seq_len + else: + max_tokens = max_batches + max_batches = self.max_batches if graph_key not in self._runner_map: - max_batches = max_tokens if is_decoding else self.max_batches runner = AscendSingleGraphRunner( self.model, max_batches=max_batches, diff --git a/dlinfer/framework/lmdeploy_ext/device/__init__.py b/dlinfer/framework/lmdeploy_ext/device/__init__.py index 75a45b36..c746ada6 100644 --- a/dlinfer/framework/lmdeploy_ext/device/__init__.py +++ b/dlinfer/framework/lmdeploy_ext/device/__init__.py @@ -100,6 +100,289 @@ async def async_sampling_logits( BaseModelAgent.async_sampling_logits = async_sampling_logits +def patch_rejection_sampler(): + from lmdeploy.pytorch.spec_decode import reject_sampler as _reject_sampler_mod + _orig_rejection_sample = _reject_sampler_mod.rejection_sample + + def _patched_rejection_sample( + target_logits, + draft_token_ids, + bonus_token_ids, + sampling_inputs, + draft_probs=None, + ): + if sampling_inputs.max_top_k == 1: + return _orig_rejection_sample( + target_logits, + draft_token_ids, + bonus_token_ids, + sampling_inputs, + draft_probs=draft_probs, + ) + + assert draft_probs is None or draft_probs.is_contiguous() + if not draft_token_ids.is_contiguous(): + draft_token_ids = draft_token_ids.contiguous() + + if not target_logits.is_contiguous(): + target_logits = target_logits.contiguous() + + batch_size, num_spec_tokens = draft_token_ids.shape + device = target_logits.device + + output_token_ids = torch.full( + (batch_size, num_spec_tokens + 1), + _reject_sampler_mod.PLACEHOLDER_TOKEN_ID, + dtype=torch.long, + device=device, + ) + + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + if sampling_inputs.top_k is not None: + is_greedy = (sampling_inputs.top_k == 1) + if not torch.is_tensor(is_greedy): + is_greedy = torch.full( + (batch_size,), bool(is_greedy), dtype=torch.bool, device=device + ) + else: + is_greedy = is_greedy.to(device=device, dtype=torch.bool) + else: + is_greedy = torch.zeros(batch_size, dtype=torch.bool, device=device) + + target_argmax = target_probs.argmax(dim=-1) + uniform_probs = torch.rand( + (batch_size, num_spec_tokens), dtype=torch.float64, device=device + ) + inv_q = torch.empty( + (batch_size, target_probs.shape[-1]), dtype=torch.float32, device=device + ) + inv_q.exponential_() + inv_q = inv_q.reciprocal() + + recovered_token_ids = torch.empty( + (batch_size, num_spec_tokens), dtype=torch.long, device=device + ) + zero = target_probs.new_tensor(0.0) + for batch_idx in range(batch_size): + if bool(is_greedy[batch_idx].item()): + continue + batch_inv_q = inv_q[batch_idx] + for pos in range(num_spec_tokens): + draft_token_id = draft_token_ids[batch_idx, pos] + if draft_probs is None: + prob = target_probs[batch_idx, pos].clone() + prob[draft_token_id] = 0.0 + else: + prob = torch.maximum( + target_probs[batch_idx, pos] - draft_probs[batch_idx, pos], + zero, + ) + recovered_token_ids[batch_idx, pos] = torch.argmax(prob * batch_inv_q) + + for batch_idx in range(batch_size): + rejected = False + if bool(is_greedy[batch_idx].item()): + for pos in range(num_spec_tokens): + token_id = target_argmax[batch_idx, pos] + output_token_ids[batch_idx, pos] = token_id + if draft_token_ids[batch_idx, pos] != token_id: + rejected = True + break + else: + for pos in range(num_spec_tokens): + draft_token_id = draft_token_ids[batch_idx, pos] + if draft_probs is None: + draft_prob = 1.0 + else: + draft_prob = float( + draft_probs[batch_idx, pos, draft_token_id].item() + ) + target_prob = float( + target_probs[batch_idx, pos, draft_token_id].item() + ) + uniform_prob = float(uniform_probs[batch_idx, pos].item()) + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + token_id = draft_token_id + else: + token_id = recovered_token_ids[batch_idx, pos] + rejected = True + output_token_ids[batch_idx, pos] = token_id + if rejected: + break + + if not rejected: + output_token_ids[batch_idx, num_spec_tokens] = bonus_token_ids[batch_idx] + + return _reject_sampler_mod._extract_outputs(output_token_ids, num_spec_tokens) + + _reject_sampler_mod.rejection_sample = _patched_rejection_sample + + +def patch_spec_decode_runtime(): + from torch.profiler import record_function + from lmdeploy.pytorch.engine.model_agent import BaseModelAgent + from lmdeploy.pytorch.spec_decode.spec_agent import SpecModelAgent + + _orig_build_cache_engine = BaseModelAgent.build_cache_engine + _orig_forward_impl = BaseModelAgent._forward_impl + + def _is_ascend_agent(agent): + return getattr(getattr(agent, "backend_config", None), "device_type", None) == "ascend" + + def _set_main_runtime(self, model, cache_engine, state_cache_engine, stream): + self.main_model = model + self.main_cache_engine = cache_engine + self.main_state_cache_engine = state_cache_engine + self.main_stream = stream + + def _maybe_snapshot_main_states(self, model_inputs): + self._main_state_snapshot = None + self._main_state_ids = None + self._main_replay_inputs = None + state_cache_engine = getattr(self, "main_state_cache_engine", None) + main_model = getattr(self, "main_model", None) + if state_cache_engine is None or main_model is None: + return + if not model_inputs.is_decoding or model_inputs.max_q_seqlen <= 1: + return + mem_pool = getattr(state_cache_engine, "mem_pool", None) + if mem_pool is None or mem_pool.numel() == 0: + return + state_offsets = getattr(model_inputs, "state_offsets", None) + if state_offsets is None: + return + active_state_ids = state_offsets[state_offsets >= 0] + if active_state_ids.numel() == 0: + return + active_state_ids = active_state_ids.to(device=mem_pool.device, dtype=torch.long) + # Only snapshot rows touched by the current batch. + self._main_state_ids = active_state_ids + self._main_state_snapshot = mem_pool.index_select(0, active_state_ids).clone() + self._main_replay_inputs = model_inputs.clone() + + def _build_replay_inputs(self, model_inputs, output_token_ids): + valid_mask = output_token_ids.ge(0) + seq_length = valid_mask.sum(dim=-1).to(model_inputs.seq_length.dtype) + if torch.any(seq_length <= 0): + return None + + replay_ids = [row[mask] for row, mask in zip(output_token_ids, valid_mask)] + input_ids = torch.cat(replay_ids, dim=0).unsqueeze(0) + + mrope_pos_ids = model_inputs.mrope_pos_ids + if mrope_pos_ids is not None: + mrope_chunks = [] + reshaped = mrope_pos_ids.unflatten(1, (-1, model_inputs.max_q_seqlen)) + for batch_idx, replay_len in enumerate(seq_length.tolist()): + mrope_chunks.append(reshaped[:, batch_idx, :replay_len]) + mrope_pos_ids = torch.cat(mrope_chunks, dim=1) + + max_q_seqlen = int(seq_length.max().item()) + max_kv_seqlen = int((model_inputs.history_lengths + seq_length).max().item()) + sum_kv_seqlen = int((model_inputs.history_lengths + seq_length).sum().item()) + return model_inputs.clone( + input_ids=input_ids, + seq_length=seq_length, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=sum_kv_seqlen, + target_hidden_states=None, + target_position_ids=None, + target_inputs_embeds=None, + mrope_pos_ids=mrope_pos_ids, + ) + + def _maybe_replay_main_states(self, extra_inputs): + state_snapshot = getattr(self, "_main_state_snapshot", None) + state_ids = getattr(self, "_main_state_ids", None) + replay_template = getattr(self, "_main_replay_inputs", None) + if state_snapshot is None or state_ids is None or replay_template is None: + return + try: + if extra_inputs.num_rejected_tokens is None or not torch.any(extra_inputs.num_rejected_tokens > 0): + return + replay_inputs = self._build_replay_inputs(replay_template, extra_inputs.output_token_ids) + if replay_inputs is None: + return + + main_model = getattr(self, "main_model", None) + main_cache_engine = getattr(self, "main_cache_engine", None) + state_cache_engine = getattr(self, "main_state_cache_engine", None) + if main_model is None or main_cache_engine is None or state_cache_engine is None: + return + + state_cache_engine.mem_pool.index_copy_(0, state_ids, state_snapshot) + from lmdeploy.pytorch.engine.model_agent.agent import model_forward as _main_model_forward + + _main_model_forward( + main_model, + replay_inputs, + main_cache_engine, + state_cache_engine, + stream=getattr(self, "main_stream", None), + ) + finally: + self._main_state_snapshot = None + self._main_state_ids = None + self._main_replay_inputs = None + + def _patched_build_cache_engine(self): + if _is_ascend_agent(self): + state_shapes = getattr(self.model_config, "states_shapes", []) + self.cache_config.states_shapes = state_shapes + if self.cache_config.num_state_caches is None and len(state_shapes) > 0: + self.cache_config.num_state_caches = int(self.cache_config.max_batches + 1) + + _orig_build_cache_engine(self) + + if ( + _is_ascend_agent(self) + and self.spec_agent is not None + and self.spec_agent.is_enabled() + ): + self.spec_agent.set_main_runtime( + self.patched_model, + self.cache_engine, + self.state_cache_engine, + self.stream, + ) + + def _patched_forward_impl(self, inputs): + if ( + _is_ascend_agent(self) + and self.spec_agent is not None + and self.spec_agent.is_enabled() + ): + self.spec_agent.maybe_snapshot_main_states(inputs) + return _orig_forward_impl(self, inputs) + + async def _patched_async_model_forward( + self, + model_inputs, + extra_inputs, + sampling_inputs, + ): + with record_function("spec_rejection_sampling"): + draft_extra_inputs = await self._rejection_sampling( + model_inputs, extra_inputs, sampling_inputs + ) + # self._maybe_replay_main_states(draft_extra_inputs) + draft_model_inputs, draft_extra_inputs = self._prepare_inputs_from_main( + model_inputs, draft_extra_inputs + ) + return await self._async_model_forward( + draft_model_inputs, draft_extra_inputs, sampling_inputs + ) + + BaseModelAgent.build_cache_engine = _patched_build_cache_engine + BaseModelAgent._forward_impl = _patched_forward_impl + SpecModelAgent.set_main_runtime = _set_main_runtime + SpecModelAgent.maybe_snapshot_main_states = _maybe_snapshot_main_states + SpecModelAgent._build_replay_inputs = _build_replay_inputs + SpecModelAgent._maybe_replay_main_states = _maybe_replay_main_states + SpecModelAgent.async_model_forward = _patched_async_model_forward + + ##### patch cache engine ##### def patch_contiguous_cache_engine(): from lmdeploy.pytorch.config import CacheConfig, ModelConfig @@ -241,6 +524,7 @@ def patch_gated_delta_net(): from lmdeploy.pytorch.nn import gated_delta from lmdeploy.pytorch.nn.gated_delta import GatedDeltaMeta + from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from dlinfer.vendor.ascend.triton_ops import RMSNormGated from dlinfer.vendor.ascend.triton_ops import ( @@ -250,6 +534,7 @@ def patch_gated_delta_net(): from dlinfer.vendor.ascend.triton_ops import ( chunk_gated_delta_rule, fused_sigmoid_gating_delta_rule_update, + fused_recurrent_gated_delta_rule, ) class AscendGatedDeltaMeta: @@ -263,11 +548,17 @@ def __init__( ): self.is_decoding = attn_metadata.is_decoding self.cu_seqlens = attn_metadata.q_start_loc - - # state_ids, fill invalid state with 0 + self.is_multi_token_decoding = attn_metadata.is_multi_token_decoding + self.max_q_seq_len = attn_metadata.max_q_seq_len + + self.num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + self.cache_seqlens = getattr(attn_metadata, 'cache_seqlens', None) + self.spec_state_offsets = getattr(attn_metadata, 'spec_state_offsets', None) + self.spec_conv_offsets = getattr(attn_metadata, 'spec_conv_offsets', None) + self.state_ids = state_ids.clamp(0) self.has_initial_state = attn_metadata.has_initial_state - self.conv_state_indices = self.state_ids + self.conv_state_indices = self.state_ids.to(torch.int32) def build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs): device = kwargs["device"] @@ -294,6 +585,12 @@ def conv1d_func( out: (b, seqlen, dim) conv_state: (b, dim, kernel_size) """ + spec_conv_offsets = getattr(gated_delta_meta, "spec_conv_offsets", None) + if spec_conv_offsets is not None: + read_conv_offsets, write_conv_offsets = spec_conv_offsets + else: + read_conv_offsets, write_conv_offsets = None, None + out = self.causal_conv1d_fn( x.t(), weight, @@ -303,6 +600,8 @@ def conv1d_func( has_initial_state=gated_delta_meta.has_initial_state, cache_indices=gated_delta_meta.conv_state_indices, query_start_loc=gated_delta_meta.cu_seqlens, + read_conv_offsets=read_conv_offsets, + write_conv_offsets=write_conv_offsets, ) out = out.t().unsqueeze(0) @@ -317,7 +616,24 @@ def conv1d_update( bias: torch.Tensor, conv_state: torch.Tensor, conv_state_indices: torch.Tensor, + gated_delta_meta: GatedDeltaMeta, ): + update_kwargs = {} + validate_data = False + + cache_seqlens = gated_delta_meta.cache_seqlens + is_multi_token_decoding = gated_delta_meta.is_multi_token_decoding + + # Ring-buffer decode path: positions are derived from cache_seqlens. + update_kwargs['cache_seqlens'] = gated_delta_meta.cache_seqlens + + if is_multi_token_decoding: + # Multi-token decode uses varlen format (2-D x tensor); must keep + # IS_VARLEN=True by passing query_start_loc, otherwise x gets incorrectly + # unsqueezed and cache_seqlens is accessed out-of-bounds. + update_kwargs['query_start_loc'] = gated_delta_meta.cu_seqlens + update_kwargs['max_query_len'] = gated_delta_meta.max_q_seq_len + out = self.causal_conv1d_update( x, conv_state, @@ -325,7 +641,8 @@ def conv1d_update( bias, self.activation, conv_state_indices=conv_state_indices, - validate_data=True, + validate_data=validate_data, + **update_kwargs, ) return out.unsqueeze(0), conv_state @@ -341,10 +658,10 @@ def __call__( weight_reshaped = weight.squeeze(1) x = x.squeeze(0) - if gated_delta_meta.is_decoding: + if gated_delta_meta.is_decoding or gated_delta_meta.is_multi_token_decoding: conv_state_indices = gated_delta_meta.conv_state_indices return self.conv1d_update( - x, weight_reshaped, bias, conv_state, conv_state_indices + x, weight_reshaped, bias, conv_state, conv_state_indices, gated_delta_meta ) return self.conv1d_func( x, weight_reshaped, bias, conv_state, gated_delta_meta=gated_delta_meta @@ -353,12 +670,32 @@ def __call__( class AscendGatedDelta: def __init__(self, use_qk_l2norm_in_kernel: bool = True): + self.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule self.fused_sigmoid_gating_delta_rule_update = ( fused_sigmoid_gating_delta_rule_update ) self.chunk_gated_delta_rule = chunk_gated_delta_rule self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + @staticmethod + def _get_decode_state_indices( + state_ids: torch.Tensor, + cache_seqlens: torch.Tensor, + num_slots: int, + query_len: int, + ) -> torch.Tensor: + """Map LMDeploy's per-sequence slot layout to per-token state ids.""" + token_offsets = torch.arange( + query_len, + device=cache_seqlens.device, + dtype=torch.int64, + ) + slot_offsets = torch.remainder( + cache_seqlens.to(torch.int64)[:, None] + token_offsets[None], + num_slots, + ) + return (state_ids.to(torch.int64)[:, None] * num_slots + slot_offsets).contiguous() + def __call__( self, query: torch.Tensor, @@ -374,10 +711,13 @@ def __call__( """call.""" is_decoding = gated_delta_meta.is_decoding - + is_multi_token_decoding = gated_delta_meta.is_multi_token_decoding + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = (-A_log.float().exp()) * F.softplus(a.float() + dt_bias) + if is_decoding: - indices = gated_delta_meta.state_ids - cu_seqlens = gated_delta_meta.cu_seqlens core_attn_out = self.fused_sigmoid_gating_delta_rule_update( A_log=A_log, dt_bias=dt_bias, @@ -387,19 +727,39 @@ def __call__( a=a.contiguous(), b=b.contiguous(), initial_state_source=recurrent_state, - initial_state_indices=indices, - cu_seqlens=cu_seqlens, + initial_state_indices=gated_delta_meta.state_ids, + cu_seqlens=gated_delta_meta.cu_seqlens, use_qk_l2norm_in_kernel=True, softplus_beta=1.0, softplus_threshold=20.0, ) - last_recurrent_state = None + return core_attn_out, None + elif is_multi_token_decoding: + state_slots = recurrent_state.size(1) + flat_recurrent_state = recurrent_state.view(-1, *recurrent_state.shape[2:]) + core_attn_out, _ = self.fused_recurrent_gated_delta_rule( + q=query.contiguous(), + k=key.contiguous(), + v=value.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + initial_state=flat_recurrent_state, + inplace_final_state=True, + cu_seqlens=gated_delta_meta.cu_seqlens, + cache_seqlens_rb=gated_delta_meta.cache_seqlens, + state_ids_rb=gated_delta_meta.state_ids, + num_state=state_slots, + use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, + ) + return core_attn_out, None else: - beta = b.sigmoid() - # If the model is loaded in fp16, without the .float() here, A might be -inf - g = (-A_log.float().exp()) * F.softplus(a.float() + dt_bias) - - initial_state = recurrent_state[gated_delta_meta.state_ids] + if gated_delta_meta.spec_state_offsets is not None: + state_ids = gated_delta_meta.state_ids + # Circular-buffer read slot: history_len % NUM_STATE + read_slots = gated_delta_meta.spec_state_offsets[0] + initial_state = recurrent_state[state_ids, read_slots].transpose(-1, -2).contiguous() + else: + initial_state = recurrent_state[gated_delta_meta.state_ids] initial_state[~gated_delta_meta.has_initial_state, ...] = 0 core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( q=query, @@ -413,11 +773,18 @@ def __call__( head_first=False, use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel, ) - recurrent_state[gated_delta_meta.state_ids] = last_recurrent_state.to( - recurrent_state.dtype - ) - - return core_attn_out, last_recurrent_state + if gated_delta_meta.spec_state_offsets is not None: + state_ids = gated_delta_meta.state_ids + # Circular-buffer write slot: (history_len + query_len) % NUM_STATE + write_slots = gated_delta_meta.spec_state_offsets[1] + recurrent_state[state_ids, write_slots] = last_recurrent_state.transpose(-1, -2).to( + recurrent_state.dtype + ) + else: + recurrent_state[gated_delta_meta.state_ids] = last_recurrent_state.to( + recurrent_state.dtype + ) + return core_attn_out, last_recurrent_state gated_delta.GatedDeltaMeta = AscendGatedDeltaMeta gated_delta.CausalConv1dFunc = AscendCausalConv1dFunc @@ -431,6 +798,36 @@ def import_vendor_module(vendor_name_str): importlib.import_module(f".{vendor_name_str}", __package__) +def patch_attention_is_tp(): + """Monkey-patch Qwen3_5Attention to skip TP head division for draft model. + + The MTP draft model uses is_tp=False to keep full head counts on each + rank. Qwen3_5Attention already passes is_tp to build_qkv_proj and + build_o_proj, but Attention.__init__ always calls _update_num_heads. + We temporarily replace _update_num_heads with an identity function + during Qwen3_5Attention.__init__ when is_tp=False. + """ + from lmdeploy.pytorch.nn import attention as _attn_mod + from lmdeploy.pytorch.models import qwen3_5 + + _orig_update = _attn_mod._update_num_heads + _identity_update = lambda nh, nkv: (nh, nkv) + _orig_init = qwen3_5.Qwen3_5Attention.__init__ + + def _patched_init(self, config, layer_idx, dtype=None, device=None, + prefix='', is_tp=True): + if not is_tp: + _attn_mod._update_num_heads = _identity_update + try: + _orig_init(self, config, layer_idx, dtype=dtype, device=device, + prefix=prefix, is_tp=is_tp) + finally: + if not is_tp: + _attn_mod._update_num_heads = _orig_update + + qwen3_5.Qwen3_5Attention.__init__ = _patched_init + + def patch_qwen3_5(): import torch from typing import List @@ -451,6 +848,9 @@ def patch_qwen3_5(): @classmethod def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): """build.""" + is_draft_model = kwargs.get("is_draft_model", False) + spec_method = kwargs.get("spec_method", None) + num_spec_tokens = kwargs.get("num_spec_tokens", 0) text_config = hf_config.text_config # propagate quantization_config from top-level hf_config into text_config quantization_config = getattr(hf_config, "quantization_config", None) @@ -475,11 +875,14 @@ def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): key_dim = head_k_dim * num_k_heads value_dim = head_v_dim * num_v_heads conv_dim = key_dim * 2 + value_dim - conv_kernel_size = text_config.linear_conv_kernel_dim + conv_kernel_size = text_config.linear_conv_kernel_dim + num_spec_tokens # Ascend Patch conv_state_shape = (conv_kernel_size, conv_dim) - recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim) + if num_spec_tokens > 0: + recurrent_state_shape = (1 + num_spec_tokens, num_v_heads, head_k_dim, head_v_dim) + else: + recurrent_state_shape = (num_v_heads, head_k_dim, head_v_dim) device_type = kwargs.get("device_type", "cuda") if is_bf16_supported(device_type): @@ -499,6 +902,23 @@ def custom_build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): cfg.check_env_func = _check_env_qwen3_next cfg.use_mrope = True + + # Speculative decoding support + if spec_method is not None: + assert spec_method == "qwen3_5_mtp" + cfg.model_paradigm = "ar_spec" + + if is_draft_model: + hf_config.architectures = ["Qwen3_5MTPModel"] + if getattr(hf_config, "auto_map", None): + hf_config.auto_map = {} + cfg.model_paradigm = "ar_spec" + cfg.num_layers = text_config.mtp_num_hidden_layers + cfg.states_shapes = [] + # Draft model uses is_tp=False — each rank runs the full model + # independently, so keep the replicated KV head count for correct + # cache allocation. + return cfg def custom_prepare_inputs_for_generation( @@ -671,6 +1091,107 @@ def custom_forward( ) +def patch_ray_init(): + """Monkey-patch lmdeploy's init_ray_cluster to register custom NPU resources. + + Ray does not auto-detect Ascend NPUs; without registering custom resources + at ray.init() time, placement groups requesting ``{'NPU': 1}`` never schedule + on a fresh local cluster. + """ + import os + import logging + import lmdeploy.pytorch.ray as _ray_mod + + logger = logging.getLogger('dlinfer.ray') + _orig_init_ray_cluster = _ray_mod.init_ray_cluster + + def _infer_local_ray_custom_resources(device_type, world_size): + if device_type == 'ascend': + n = None + try: + npu_mod = getattr(torch, 'npu', None) + if npu_mod is not None and callable(getattr(npu_mod, 'device_count', None)): + n = int(npu_mod.device_count()) + if n <= 0: + n = None + except Exception: + n = None + if n is None: + vis = os.environ.get('ASCEND_RT_VISIBLE_DEVICES', '').strip() + if vis: + n = len([x for x in vis.split(',') if x.strip() != '']) + if n is None or n <= 0: + n = int(world_size) + logger.warning( + 'Could not detect NPU count; registering Ray resource NPU=%d ' + 'from world_size.', n) + return {'NPU': float(n)} + if device_type == 'camb': + n = None + try: + mlu = getattr(torch, 'mlu', None) + if mlu is not None and callable(getattr(mlu, 'device_count', None)): + n = int(mlu.device_count()) + if n <= 0: + n = None + except Exception: + n = None + if n is None or n <= 0: + n = int(world_size) + logger.warning('Could not detect MLU count; registering MLU=%d.', n) + return {'MLU': float(n)} + return None + + def _patched_init_ray_cluster(world_size, ray_address=None, dp=1, device_type='cuda'): + """Same as original but registers custom resources at ray.init() for local clusters.""" + import ray + if not ray.is_initialized(): + num_cpus = world_size + object_store_memory = _ray_mod._get_obj_store_memory(dp=dp) + init_kwargs = dict( + ignore_reinit_error=True, + num_cpus=num_cpus, + object_store_memory=object_store_memory, + ) + if ray_address is not None: + init_kwargs['address'] = ray_address + if ray_address is None: + custom_res = _infer_local_ray_custom_resources(device_type, world_size) + if custom_res: + init_kwargs['resources'] = custom_res + try: + ray.init(**init_kwargs) + except ValueError as e: + if e.args is not None and len(e.args) >= 1 and e.args[ + 0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.': + ray.init(address=ray_address, ignore_reinit_error=True) + else: + raise + + # Remaining logic unchanged from original init_ray_cluster + device_str = _ray_mod.get_device_str(device_type) + current_placement_group = ray.util.get_current_placement_group() + owned_pg = False + if not current_placement_group: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if world_size > num_devices_in_cluster: + _ray_mod.logger.warning( + 'The number of required %ss exceeds the total ' + 'number of available %ss in the placement group.', device_str, device_str) + placement_group_specs = [{device_str: 1.0} for _ in range(world_size)] + current_ip = ray.util.get_node_ip_address() + placement_group_specs[0][f'node:{current_ip}'] = 0.001 + current_placement_group = ray.util.placement_group(placement_group_specs, strategy='PACK') + _ray_mod._wait_until_pg_ready(current_placement_group) + owned_pg = True + + assert current_placement_group is not None + placement_group = current_placement_group + return placement_group, owned_pg + + _ray_mod.init_ray_cluster = _patched_init_ray_cluster + + def vendor_device_init(): import_vendor_module(vendor_name) patch_compiled_func() @@ -678,9 +1199,13 @@ def vendor_device_init(): if vendor_name in ["camb", "ascend"]: patch_contiguous_cache_engine() if vendor_name == "ascend": + patch_rejection_sampler() + patch_spec_decode_runtime() patch_state_cache_engine() - patch_gated_delta_net() + patch_gated_delta_net() # MUST be before patch_attention_is_tp + patch_attention_is_tp() patch_qwen3_5() + patch_ray_init() vendor_device_init() diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index 2a12aabe..62734102 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -1,5 +1,4 @@ # Copyright (c) 2024, DeepLink. All rights reserved. -import os import math import torch import torch.distributed as dist diff --git a/dlinfer/vendor/ascend/triton_ops/__init__.py b/dlinfer/vendor/ascend/triton_ops/__init__.py index 49a22221..f08c5f4b 100644 --- a/dlinfer/vendor/ascend/triton_ops/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/__init__.py @@ -6,11 +6,13 @@ "causal_conv1d_fn", "causal_conv1d_update_npu", "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", "fused_sigmoid_gating_delta_rule_update", "RMSNormGated", ] from .fla.chunk import chunk_gated_delta_rule from .fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .fla.fused_recurrent import fused_recurrent_gated_delta_rule from .rms_norm_gated import RMSNormGated from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update_npu diff --git a/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py b/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py index 00c919e7..e249bfae 100644 --- a/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py +++ b/dlinfer/vendor/ascend/triton_ops/causal_conv1d.py @@ -70,6 +70,8 @@ def causal_conv1d_fn( query_start_loc: Optional[torch.Tensor] = None, metadata: Optional[Any] = None, pad_slot_id: int = PAD_SLOT_ID, + read_conv_offsets: Optional[torch.Tensor] = None, + write_conv_offsets: Optional[torch.Tensor] = None, ): """ Prefill-phase varlen causal conv1d using PyTorch reference implementation. @@ -80,7 +82,17 @@ def causal_conv1d_fn( query_start_loc: (batch + 1) int32 cache_indices: (batch) int32 has_initial_state: (batch) bool - conv_states: (..., dim, width - 1) + conv_states: (..., dim, state_len) — state_len == width-1 in legacy linear + layout; state_len == width + num_spec_tokens in ring layout. + + Ring-buffer mode (used by MTP / speculative decoding) is enabled when + `write_conv_offsets` is provided. Then `conv_states` is treated as a ring + of size `state_len = conv_states.shape[-1]`: + * initial state (when has_initial_state[i]) is gathered from + read_conv_offsets[i] (shape (width-1,)); + * the last `width` tokens of each sequence are scattered to + write_conv_offsets[i] (shape (width,)). + When `write_conv_offsets` is None the legacy semantics apply unchanged. """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") @@ -91,6 +103,7 @@ def causal_conv1d_fn( if query_start_loc is None: raise ValueError("query_start_loc is required for prefill mode") + is_ring = write_conv_offsets is not None seqlens = query_start_loc[1:] - query_start_loc[:-1] seqlens = seqlens.tolist() splits = torch.split(x, seqlens, dim=-1) @@ -100,23 +113,58 @@ def causal_conv1d_fn( x_s = splits[i] if cache_indices[i] == PAD_SLOT_ID: continue + + if has_initial_state[i]: + if is_ring: + slot_idx = read_conv_offsets[i] + init_state = ( + conv_states[cache_indices[i]] + .index_select(-1, slot_idx) + .unsqueeze(0) + ) + else: + init_state = conv_states[cache_indices[i]][..., : (width - 1)] + else: + init_state = None + out_ref_b = causal_conv1d_ref( x_s, weight, bias, activation=activation, - return_final_states=True, - final_states_out=conv_states[cache_indices[i]][ - ..., : (width - 1) - ].unsqueeze(0), - initial_states=( - conv_states[cache_indices[i]][..., : (width - 1)] - if has_initial_state[i] - else None + return_final_states=not is_ring, + final_states_out=( + None + if is_ring + else conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0) ), + initial_states=init_state, ) out_chunks.append(out_ref_b[0]) out = torch.cat(out_chunks, dim=-1) + + if is_ring: + # Vectorised ring write: place each sequence's trailing `width` tokens + # at their circular slots in conv_states. Replaces the post-hoc overwrite + # the caller used to do, and removes the redundant linear-slot copy_(). + K = write_conv_offsets.size(1) + tok_offsets = torch.arange( + -K, 0, device=query_start_loc.device, dtype=query_start_loc.dtype + ) + token_idx = ( + (query_start_loc[1:, None] + tok_offsets[None]).clamp_min(0).to(torch.int64) + ) + x_gather = x.index_select(1, token_idx.reshape(-1)).reshape( + x.size(0), token_idx.size(0), token_idx.size(1) + ) + cache_idx = cache_indices.to(torch.int64) + dim_idx = torch.arange(conv_states.size(1), device=conv_states.device) + conv_states[ + cache_idx[:, None, None], + dim_idx[None, :, None], + write_conv_offsets[:, None, :], + ] = x_gather.permute(1, 0, 2).to(conv_states.dtype) + return out @@ -131,6 +179,7 @@ def _causal_conv1d_update_kernel_npu_tiled( query_start_loc_ptr, block_idx_last_scheduled_token, initial_state_idx, + cache_seqlens_ptr, o_ptr, batch: tl.int32, dim: tl.constexpr, @@ -158,6 +207,7 @@ def _causal_conv1d_update_kernel_npu_tiled( IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, + IS_CIRCULAR_BUFFER: tl.constexpr, BLOCK_N: tl.constexpr, B_TILE: tl.constexpr, T_CHUNK: tl.constexpr, @@ -249,17 +299,26 @@ def _causal_conv1d_update_kernel_npu_tiled( lane_active = lane_active & (seqlen_run > 0) - if IS_SPEC_DECODING: - conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to( - tl.int64 - ) - - 1 - ) - shift = tl.full((), 1, tl.int32) - else: + if IS_CIRCULAR_BUFFER: + cb_raw = tl.load(cache_seqlens_ptr + b, mask=lane_active, other=0).to(tl.int32) + cb_pos = cb_raw % state_len # write-start: wrapped current position + cb_read_start = (cb_pos - (KERNEL_WIDTH - 1) + state_len) % state_len conv_state_token_offset = tl.full((), 0, tl.int64) - shift = seqlen_run + shift = tl.full((), 0, tl.int32) + else: + cb_pos = tl.full((), 0, tl.int32) + cb_read_start = tl.full((), 0, tl.int32) + if IS_SPEC_DECODING: + conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + b, mask=lane_active, other=1).to( + tl.int64 + ) + - 1 + ) + shift = tl.full((), 1, tl.int32) + else: + conv_state_token_offset = tl.full((), 0, tl.int64) + shift = seqlen_run conv_states_base = ( conv_state_ptr @@ -275,42 +334,77 @@ def _causal_conv1d_update_kernel_npu_tiled( col2 = tl.zeros((BLOCK_N,), dtype=tl.float16) col3 = tl.zeros((BLOCK_N,), dtype=tl.float16) col4 = tl.zeros((BLOCK_N,), dtype=tl.float16) - if KERNEL_WIDTH >= 2: - col0 = tl.load( - prior_tokens + 0 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 3: - col1 = tl.load( - prior_tokens + 1 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 4: - col2 = tl.load( - prior_tokens + 2 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 5: - col3 = tl.load( - prior_tokens + 3 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - if KERNEL_WIDTH >= 6: - col4 = tl.load( - prior_tokens + 4 * stride_conv_state_tok, - mask=lane_active & mask_w, - other=0.0, - ).to(tl.float16) - - conv_states_offset = tl.load( - conv_state_indices_ptr + b * stride_state_indices + current_last_index, - mask=lane_active, - other=0, - ).to(tl.int64) + if IS_CIRCULAR_BUFFER: + if KERNEL_WIDTH >= 2: + col0 = tl.load( + conv_states_base + cb_read_start * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load( + conv_states_base + ((cb_read_start + 1) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load( + conv_states_base + ((cb_read_start + 2) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load( + conv_states_base + ((cb_read_start + 3) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 6: + col4 = tl.load( + conv_states_base + ((cb_read_start + 4) % state_len) * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + else: + if KERNEL_WIDTH >= 2: + col0 = tl.load( + prior_tokens + 0 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 3: + col1 = tl.load( + prior_tokens + 1 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 4: + col2 = tl.load( + prior_tokens + 2 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 5: + col3 = tl.load( + prior_tokens + 3 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + if KERNEL_WIDTH >= 6: + col4 = tl.load( + prior_tokens + 4 * stride_conv_state_tok, + mask=lane_active & mask_w, + other=0.0, + ).to(tl.float16) + + if not IS_CIRCULAR_BUFFER: + conv_states_offset = tl.load( + conv_state_indices_ptr + b * stride_state_indices + current_last_index, + mask=lane_active, + other=0, + ).to(tl.int64) + else: + conv_states_offset = tl.full((), 0, tl.int64) use_shift = seqlen_run < state_len_run use_tail = seqlen_run >= state_len_run @@ -335,61 +429,79 @@ def _causal_conv1d_update_kernel_npu_tiled( ) x_base = x_ptr + x_offset + idx_feats * stride_x_dim - for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): - dst_tok = (t0 + tok_vec).to(tl.int32) - src_tok = (dst_tok + shift).to(tl.int32) - m_tok = ( - use_shift - & (dst_tok < keep_shift) - & (src_tok < state_len_run) - & (dst_tok < state_len_run) - ) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_input_coord < num_cache_lines) - & (conv_states_offset < num_cache_lines) - ) - src_ptrs = ( - state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok - ) - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - vals = tl.load(src_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, vals, mask=m) - - for t0 in tl.static_range(0, seqlen, T_CHUNK): - x_tok = (t0 + tok_vec).to(tl.int32) - dst_tok = (keep_shift + x_tok).to(tl.int32) - m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_offset < num_cache_lines) - ) - x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - x_vals = tl.load(x_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, x_vals, mask=m) - - for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): - dst_tok = (t0 + tok_vec).to(tl.int32) - x_tok = (tail_start + dst_tok).to(tl.int32) - m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) - m = ( - (lane_active & m_tok)[:, None] - & mask_w[None, :] - & (conv_states_offset < num_cache_lines) - ) - x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token - dst_ptrs = ( - state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok - ) - x_vals = tl.load(x_ptrs, mask=m, other=0.0) - tl.store(dst_ptrs, x_vals, mask=m) + if IS_CIRCULAR_BUFFER: + # Circular write: write each input token to its circular position. + # No shift needed — positions are derived from cache_seqlens. + for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) + write_tok = (cb_pos + x_tok) % state_len + m = ( + (lane_active & (x_tok < seqlen_run))[:, None] + & mask_w[None, :] + & (conv_states_input_coord < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + conv_states_base[None, :] + write_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + else: + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) + src_tok = (dst_tok + shift).to(tl.int32) + m_tok = ( + use_shift + & (dst_tok < keep_shift) + & (src_tok < state_len_run) + & (dst_tok < state_len_run) + ) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_input_coord < num_cache_lines) + & (conv_states_offset < num_cache_lines) + ) + src_ptrs = ( + state_src_base[None, :] + src_tok[:, None] * stride_conv_state_tok + ) + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + vals = tl.load(src_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, vals, mask=m) + + for t0 in tl.static_range(0, seqlen, T_CHUNK): + x_tok = (t0 + tok_vec).to(tl.int32) + dst_tok = (keep_shift + x_tok).to(tl.int32) + m_tok = use_shift & (x_tok < seqlen_run) & (dst_tok < state_len_run) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_offset < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) + + for t0 in tl.static_range(0, NP2_STATELEN, T_CHUNK): + dst_tok = (t0 + tok_vec).to(tl.int32) + x_tok = (tail_start + dst_tok).to(tl.int32) + m_tok = use_tail & (dst_tok < state_len_run) & (x_tok < seqlen_run) + m = ( + (lane_active & m_tok)[:, None] + & mask_w[None, :] + & (conv_states_offset < num_cache_lines) + ) + x_ptrs = x_base[None, :] + x_tok[:, None] * stride_x_token + dst_ptrs = ( + state_dst_base[None, :] + dst_tok[:, None] * stride_conv_state_tok + ) + x_vals = tl.load(x_ptrs, mask=m, other=0.0) + tl.store(dst_ptrs, x_vals, mask=m) x_base_1d = x_base o_base_1d = o_ptr + o_offset + idx_feats * stride_o_dim @@ -513,6 +625,7 @@ def causal_conv1d_update_npu( pad_slot_id: int = PAD_SLOT_ID, block_idx_last_scheduled_token: Optional[torch.Tensor] = None, initial_state_idx: Optional[torch.Tensor] = None, + cache_seqlens: Optional[torch.Tensor] = None, validate_data=False, ): if validate_data: @@ -557,7 +670,10 @@ def causal_conv1d_update_npu( conv_state_indices.stride(0) if conv_state_indices is not None else 0 ) - if num_accepted_tokens is not None: + if cache_seqlens is not None: + # Circular buffer: use the full allocated state size as the modulus. + eff_state_len = state_len_total + elif num_accepted_tokens is not None: eff_state_len = width - 1 + (seqlen - 1) else: eff_state_len = width - 1 @@ -591,6 +707,7 @@ def grid(META): query_start_loc, block_idx_last_scheduled_token, initial_state_idx, + cache_seqlens, out, batch, dim, @@ -618,6 +735,7 @@ def grid(META): IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, + IS_CIRCULAR_BUFFER=cache_seqlens is not None, BLOCK_N=block_n, B_TILE=b_tile, T_CHUNK=t_chunk, diff --git a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py index e2eea080..e7104dec 100644 --- a/dlinfer/vendor/ascend/triton_ops/fla/__init__.py +++ b/dlinfer/vendor/ascend/triton_ops/fla/__init__.py @@ -2,8 +2,10 @@ from .chunk import chunk_gated_delta_rule from .sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from .fused_recurrent import fused_recurrent_gated_delta_rule __all__ = [ "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", "fused_sigmoid_gating_delta_rule_update", ] diff --git a/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py new file mode 100644 index 00000000..6709bcad --- /dev/null +++ b/dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py @@ -0,0 +1,446 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +import torch +import triton +import triton.language as tl + +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "IS_CIRCULAR_BUFFER": lambda args: args["cache_seqlens_rb"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + cache_seqlens_rb, # [N] history lengths for circular-buffer read/write + state_ids_rb, # [N] per-sequence base slot index (= state_id) + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NUM_STATE: tl.constexpr, # circular buffer size (= state_slots = 1 + num_spec_tokens) + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_CIRCULAR_BUFFER: tl.constexpr, # use NVIDIA-style circular-buffer state indexing + IS_KDA: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_v[:, None] & mask_k[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + # Pre-load circular-buffer addressing params once (before the token loop). + # IS_CIRCULAR_BUFFER is constexpr so the dead branch is eliminated at compile time. + if IS_CIRCULAR_BUFFER: + h_rb = tl.load(cache_seqlens_rb + i_n).to(tl.int64) + s_id_rb = tl.load(state_ids_rb + i_n).to(tl.int64) + # Skip padding sequences: invalid state_ids are clamped to 0 on the host side. + # Letting them proceed would corrupt slot-0 state and cause NaN in graph mode. + if s_id_rb <= 0: + return + + if USE_INITIAL_STATE: + if IS_CIRCULAR_BUFFER: + # Circular-buffer read: slot = cache_seqlens % NUM_STATE + read_slot = s_id_rb * NUM_STATE + h_rb % NUM_STATE + p_h0 = h0 + read_slot * stride_init_state_token + elif IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + # Load state index and check for PAD_SLOT_ID (-1) + state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + # Skip if state index is invalid (PAD_SLOT_ID = -1) + if state_idx <= 0: + return + p_h0 = h0 + state_idx * stride_init_state_token + else: + p_h0 = h0 + bos * HV * V * K + p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BV, BK] + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[None, :]) + # [BV] + b_v -= tl.sum(b_h * b_k[None, :], 1) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BV, BK] + b_h += b_v[:, None] * b_k[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[None, :], 1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + if IS_CIRCULAR_BUFFER: + # Circular-buffer write: slot = (cache_seqlens + i_t + 1) % NUM_STATE + write_slot = s_id_rb * NUM_STATE + (h_rb + i_t + 1) % NUM_STATE + p_ht = ht + write_slot * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + # Load state index and check for PAD_SLOT_ID (-1) + final_state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t + ).to(tl.int64) + # Only store if state index is valid (not PAD_SLOT_ID) + if final_state_idx > 0: + p_ht = ht + final_state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + cache_seqlens_rb=cache_seqlens_rb, + state_ids_rb=state_ids_rb, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + NUM_STATE=num_state, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + cache_seqlens_rb=cache_seqlens_rb, + state_ids_rb=state_ids_rb, + num_state=num_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + cache_seqlens_rb: torch.Tensor | None = None, + state_ids_rb: torch.Tensor | None = None, + num_state: int = 1, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, V, K]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, V, K]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, V, K, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + cache_seqlens_rb, + state_ids_rb, + num_state, + use_qk_l2norm_in_kernel, + ) + return o, final_state