Skip to content
101 changes: 76 additions & 25 deletions dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update():


# AscendCudaGraphMixin methods for cudagraph buffer management.
def AscendCudaGraphMixin_support_cuda_graph(
self,
input_ids: Tensor,
position_ids: Tensor,
past_key_values: List[List[Tensor]],
attn_metadata: Any = None,
inputs_embeds: Tensor = None,
**kwargs,
):
"""Allow multi-token decode graph only when runtime length updates exist."""
if attn_metadata is None:
return False

is_decoding = getattr(attn_metadata, "is_decoding", False)
is_multi_token = getattr(attn_metadata, "is_multi_token_decoding", False)
if is_multi_token and not aclgraph_use_torch_npu_update():
return False
return is_decoding or is_multi_token


def AscendCudaGraphMixin_make_buffers_cudagraph(
self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
Expand All @@ -58,9 +78,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
(max_batches, num_blocks), dtype=torch.int32, device=device
)

input_buffers["q_seqlens"] = torch.ones(
max_batches, dtype=torch.int32, device=device
)
input_buffers["q_seqlens"] = torch.ones(max_batches, dtype=torch.int32)

input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32)

Expand All @@ -69,18 +87,23 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
)

input_buffers["kv_start_indices"] = -torch.ones(
(max_batches), dtype=torch.int32, device=device
(max_tokens), dtype=torch.int32, device=device
)

input_buffers["x_active_mask"] = torch.zeros(
(max_batches), dtype=torch.bool, device=device
(max_tokens), dtype=torch.bool, device=device
)

input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1)

# ssm
if graph_meta.is_ssm:
input_buffers["state_ids"] = torch.full(
(max_batches,), -1, dtype=torch.int64, device=device
)
input_buffers["cache_seqlens"] = torch.zeros(
max_batches, dtype=torch.int32, device=device
)

# mrope
if graph_meta.use_mrope:
Expand Down Expand Up @@ -108,11 +131,13 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
moe_metadata = get_step_ctx_manager().current_context().moe_metadata
x_active_mask: Tensor = moe_metadata.x_active_mask
q_start_loc: Tensor = attn_metadata.q_start_loc
cache_seqlens: Tensor = attn_metadata.cache_seqlens

input_buffers: BuffType = graph_meta.input_buffers

batch_size, num_blocks = block_offsets.size()
num_tokens = input_ids.size(-1)
q_seqlens: Tensor = attn_metadata.q_seqlens

# fill buffer
max_num_tokens = input_buffers["input_ids"].size(-1)
Expand All @@ -126,22 +151,32 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
input_buffers["position_ids"][:, :num_tokens] = position_ids
input_buffers["block_offsets"].zero_()
input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
input_buffers["q_seqlens"].fill_(0)
input_buffers["q_seqlens"][: batch_size] = q_seqlens
input_buffers["kv_seqlens"].fill_(0)
input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
input_buffers["kv_start_indices"].fill_(-1)
input_buffers["kv_start_indices"][:batch_size] = kv_start_indices
input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices
if x_active_mask is not None:
input_buffers["x_active_mask"].fill_(0)
input_buffers["x_active_mask"][:batch_size] = x_active_mask
input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask

# ssm
if graph_meta.is_ssm:
input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1]
bs = input_buffers["q_start_loc"].size(0)
max_q_seq_len = attn_metadata.max_q_seq_len
padding_tensor = torch.arange(0, bs) * max_q_seq_len
input_buffers["q_start_loc"].copy_(padding_tensor)
input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc

state_ids = kwargs["state_ids"]
input_buffers["state_ids"].fill_(-1)
input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids)
input_buffers["state_ids"].fill_(0)
input_buffers["state_ids"][: batch_size].copy_(state_ids)

input_buffers["cache_seqlens"].fill_(0)
input_buffers["cache_seqlens"][: batch_size].copy_(cache_seqlens)

attn_metadata.cache_seqlens = input_buffers["cache_seqlens"]
attn_metadata.attention_mask = [input_buffers["attention_mask"]]

if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
Expand All @@ -151,10 +186,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
1, max_num_tokens, emb_size
)
input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
# create inputs
# Use compatible size but cap at graph's max_batchs to avoid buffer overflow
new_batch_size = min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs)

attn_metadata.q_seqlens = input_buffers["q_seqlens"]
attn_metadata.block_offsets = input_buffers["block_offsets"]
attn_metadata.kv_seqlens = input_buffers["kv_seqlens"]
attn_metadata.kv_start_indices = input_buffers["kv_start_indices"]
Expand All @@ -175,7 +207,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(

new_inputs.update(kwargs)

# ssm: override kwargs' variable-length state_ids with the fixed-size buffer
if graph_meta.is_ssm:
new_inputs["state_ids"] = input_buffers["state_ids"]

Expand Down Expand Up @@ -209,6 +240,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
context.mrope_position_ids = input_buffers["mrope_position_ids"]


CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph
CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph
Expand Down Expand Up @@ -358,7 +390,7 @@ def forward(self, **kwargs):
]
)
else:
update_attn_params(self.update_stream, self.meta, self.max_tokens)
update_attn_params(self.update_stream, self.meta, self.max_batches)
self._graph.replay()
output_buffers = self.meta.output_buffers
output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
Expand Down Expand Up @@ -427,19 +459,33 @@ def _get_capture_tokens(self, batch_size: int):
def get_graph_key(
self,
input_ids: torch.Tensor,
attn_metadata: Any,
**kwargs,
):
"""Get graph key."""
context = self.ctx_mgr.current_context()
is_decoding = context.is_decoding
num_tokens = input_ids.numel()
is_decoding = attn_metadata.is_decoding
is_multi_token_decoding = attn_metadata.is_multi_token_decoding
meta = self.get_meta()
enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch

if is_multi_token_decoding:
q_seqlens = attn_metadata.q_seqlens
max_q_seq_len = attn_metadata.max_q_seq_len
batch_size = q_seqlens.size(0)
if meta.padding_batch_size is None:
new_batch_size = self._get_capture_tokens(batch_size)
else:
padding_num_tokens = meta.padding_batch_size
padding_batch_size = (padding_num_tokens + max_q_seq_len - 1) // max_q_seq_len
new_batch_size = self._get_capture_tokens(padding_batch_size)
return (new_batch_size, is_multi_token_decoding, enable_microbatch, max_q_seq_len)

num_tokens = input_ids.numel()
if meta.padding_batch_size is None:
new_num_tokens = self._get_capture_tokens(num_tokens)
else:
new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
return (new_num_tokens, is_decoding, enable_microbatch)
return (new_num_tokens, is_decoding, enable_microbatch, 1)

def __call__(self, **kwargs):
"""call."""
Expand All @@ -451,10 +497,15 @@ def __call__(self, **kwargs):
return self.model.make_output_buffers(ret)

graph_key = self.get_graph_key(**kwargs)
max_tokens = graph_key[0]
is_decoding = graph_key[1]
max_batches = graph_key[0]
is_decoding_or_multi_token_decoding = graph_key[1]
max_q_seq_len = graph_key[3]
if is_decoding_or_multi_token_decoding:
max_tokens = max_batches * max_q_seq_len
else:
max_tokens = max_batches
max_batches = self.max_batches
if graph_key not in self._runner_map:
max_batches = max_tokens if is_decoding else self.max_batches
runner = AscendSingleGraphRunner(
self.model,
max_batches=max_batches,
Expand Down
Loading