Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 203 additions & 70 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from flax.linen import make_attention_mask

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
Expand Down Expand Up @@ -541,6 +540,149 @@ def run_length_fill(segment_ids) -> jnp.ndarray:
return run_length_segment_id_shape.reshape(orig_shape)


def _get_seqlens_offsets_thd(
segment_ids_q,
segment_ids_kv,
segment_pos_q,
segment_pos_kv,
attn_mask_type,
max_segments_per_seq,
):
"""O(T * max_segments_per_seq) replacement for the older O(T^2) mask-based slow path.
Returns (q_seqlen, kv_seqlen, q_offset, kv_offset) values to match the reference older mask-based path:
segment_mask = make_attention_mask(q_ids, kv_ids, equal)
segment_mask_with_id = make_attention_mask(q_ids, kv_ids, equal * q_id)
attn_mask = segment_mask AND (causal_or_brcm_or_none)
attn_mask_with_id = where(attn_mask, segment_mask_with_id, 0)
row_ids = reduce_max(attn_mask_with_id, axis=kv) # [B, T_q]
col_ids = reduce_max(attn_mask_with_id, axis=q) # [B, T_kv]
seqlens/offsets = bincount(...) / find_offsets(...)
The two reductions are expressed equivalently as per-segment aggregates:
- causal: row_ids[q] = q_seg_id iff seg_pos_q[q] >= min(seg_pos_kv over same-seg KV)
- brcm: row_ids[q] = q_seg_id iff (run_len_q - seg_pos_q) >=
min(run_len_kv - seg_pos_kv over same-seg KV)
- padding: row_ids[q] = q_seg_id iff q_seg_id appears in KV
(and symmetrically for col_ids with max/<=).
"""

# Example: For striping P2P causal attention (but this logic also applies for non-CP fused attn)
# pre-striping and sharding: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]]
# post-striping and sharding (striped CP=2, Q from rank 0 × KV from rank 1, max_segments_per_seq=2):
# segment_ids_q = [1 1 2 2] segment_pos_q = [0 2 0 2] → q_key = [0 2 0 2]
# segment_ids_kv = [1 1 2 2] segment_pos_kv = [1 3 1 3] → kv_key = [1 3 1 3]
# Q-side — kv_agg[s] = min(kv_key over same-seg KV), fill = max_fill_val = 5 (assumed to be large enough):
# scatter (rows = kv tokens, cols = segs):
# [5 1 5 / 5 3 5 / 5 5 1 / 5 5 3] → reduce min → kv_agg = [5 1 1]
# q_ok = q_key >= kv_agg[seg_ids_q] = [0 2 0 2] >= [1 1 1 1] = [F T F T]
# KV-side — q_agg[s] = max(q_key over same-seg Q), fill = neg_fill_val = -1 (assumed to be small enough):
# scatter: [-1 0 -1 / -1 2 -1 / -1 -1 0 / -1 -1 2] → reduce max → q_agg = [-1 2 2]
# kv_ok = kv_key <= q_agg[seg_ids_kv] = [1 3 1 3] <= [2 2 2 2] = [T F T F]
# Outer combiner:
# row_ids = [0 1 0 2] col_ids = [1 0 2 0]
# q_seqlen = [1 1] kv_seqlen = [1 1]
# q_offset = [1 3 -1] kv_offset = [0 2 -1]
def _row_and_col_ids():
if attn_mask_type.is_bottom_right():
# BRCM: mask[q][kv] = (same seg) AND (q_key <= kv_key).
rl_q = run_length_fill(segment_ids_q)
rl_kv = run_length_fill(segment_ids_kv)
q_key = (rl_q - segment_pos_q).astype(jnp.int32)
kv_key = (rl_kv - segment_pos_kv).astype(jnp.int32)

# Use large positive and negative values as fill values for the KV keys and Q keys respectively
max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
neg_fill_val = jnp.asarray(-1, dtype=jnp.int32)
# Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1])
# i.e. each row has only one True value, which is the segment id of the row.
kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_)
# Mask the KV keys with the valid segment ids (size [B, T_kv, 1])
kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, neg_fill_val)[..., None]
# Scatter each KV key (i.e. seg pos) into it's own segment column
kv_agg = jnp.where(kv_oh, kv_key_masked, neg_fill_val)
kv_agg = jnp.max(kv_agg, axis=-2)
# Define causal relationship: Q is attended iff q_key <= max(kv_key over same-seg KV)
q_has_match = q_key <= jnp.take_along_axis(
kv_agg, segment_ids_q.astype(jnp.int32), axis=-1
)

# Symmetric to the Q case, but with KV and Q swapped
q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_)
q_key_masked = jnp.where(segment_ids_q != 0, q_key, max_fill_val)[..., None]
q_agg = jnp.where(q_oh, q_key_masked, max_fill_val)
q_agg = jnp.min(q_agg, axis=-2)
# Define causal relationship: KV is attended iff kv_key >= min(q_key over same-seg Q)
kv_has_match = kv_key >= jnp.take_along_axis(
q_agg, segment_ids_kv.astype(jnp.int32), axis=-1
)
elif attn_mask_type.is_causal():
# CM: mask[q][kv] = (same_seg) AND (q_pos >= kv_pos).
q_key = segment_pos_q.astype(jnp.int32)
kv_key = segment_pos_kv.astype(jnp.int32)

# Use large positive and negative values as a fill value for the KV keys and Q keys respectively
max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
neg_fill_val = jnp.asarray(-1, dtype=jnp.int32)

# Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1])
# i.e. each row has only one True value, which is the segment id of the row.
kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_)
# Mask the KV keys with the valid segment ids (size [B, T_kv, 1])
kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, max_fill_val)[..., None]
# Scatter each KV key (i.e. seg pos) into it's own segment column
kv_agg = jnp.where(kv_oh, kv_key_masked, max_fill_val)
kv_agg = jnp.min(kv_agg, axis=-2)
# Define causal relationship: Q is attended iff q_key >= min(kv_key over same-seg KV)
q_has_match = q_key >= jnp.take_along_axis(
kv_agg, segment_ids_q.astype(jnp.int32), axis=-1
)

# Symmetric to the Q case, but with KV and Q swapped
q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_)
q_key_masked = jnp.where(segment_ids_q != 0, q_key, neg_fill_val)[..., None]
q_agg = jnp.where(q_oh, q_key_masked, neg_fill_val)
q_agg = jnp.max(q_agg, axis=-2)
# Define causal relationship: KV is attended iff kv_key <= max(q_key over same-seg Q)
kv_has_match = kv_key <= jnp.take_along_axis(
q_agg, segment_ids_kv.astype(jnp.int32), axis=-1
)
else:
# Padding-only: row_ids[q] = q_seg_id iff q_seg_id is present in KV (and q not pad).
kv_seg_ids_present = jax.nn.one_hot(
segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_
).any(axis=-2)
q_seg_ids_present = jax.nn.one_hot(
segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_
).any(axis=-2)
q_has_match = jnp.take_along_axis(
kv_seg_ids_present, segment_ids_q.astype(jnp.int32), axis=-1
) & (segment_ids_q != 0)
kv_has_match = jnp.take_along_axis(
q_seg_ids_present, segment_ids_kv.astype(jnp.int32), axis=-1
) & (segment_ids_kv != 0)

row_ids = jnp.where(q_has_match, segment_ids_q, 0).astype(jnp.int32)
col_ids = jnp.where(kv_has_match, segment_ids_kv, 0).astype(jnp.int32)
return row_ids, col_ids

row_ids, col_ids = _row_and_col_ids()

bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
q_seqlen = bincount_vmap(row_ids)[..., 1:]
kv_seqlen = bincount_vmap(col_ids)[..., 1:]

def _find_offsets(x):
same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
first_column = x[..., :1] != 0
boundaries = jnp.concatenate([first_column, same_as_previous], axis=-1)
return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
boundaries
).squeeze(-1)

q_offset = _find_offsets(row_ids)
kv_offset = _find_offsets(col_ids)
return q_seqlen, kv_seqlen, q_offset, kv_offset


def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
Expand All @@ -550,9 +692,52 @@ def _segment_ids_pos_to_seqlens_offsets(
window_size,
max_segments_per_seq,
):
"""Compute per-segment seqlens and start offsets(currently only used for THD)
Given segment-id and segment-position tensors for Q and KV,
returns the four metadata tensors cuDNN needs for variable-length attention:
q_seqlen : [..., max_segments_per_seq] # valid Q tokens per segment
kv_seqlen : [..., max_segments_per_seq] # valid KV tokens per segment
q_offset : [..., max_segments_per_seq + 1] # start index of each Q segment
kv_offset : [..., max_segments_per_seq + 1] # start index of each KV segment

Args:
segment_ids_q: int32 [..., T_q] per-token segment id; 0 == padding
segment_ids_kv: int32 [..., T_kv] same convention as segment_ids_q
segment_pos_q: int32 [..., T_q] per-token position inside its segment
segment_pos_kv: int32 [..., T_kv] same convention as segment_pos_q
attn_mask_type: AttnMaskType. Selects the mask predicate used to decide
which positions are valid (top-left causal vs
bottom-right causal vs. padding-only)
window_size: Optional sliding-window tuple ``(left, right)`` or None
Used here only as a fast-path eligibility hint
max_segments_per_seq: maximum number of segments expected per row
Used to size the bincount / argwhere outputs

Routing (only invoked for THD qkv_layout):
1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
O(T) per row. Counts all segment tokens via bincount on
segment_ids and trims at most one token per segment at the
boundary. Used for:
- top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
- SWA with ``window_size == (-1, -1)`` and not bottom-right
Bottom-right causal cross-attention is excluded: the boundary
trim leaves kv_seqlen short by one per active segment, which
shifts the BRCM bottom-right alignment by one KV per Q row.

2. Slow path -- ``_get_seqlens_offsets_thd``.
O(T * max_segments_per_seq) per row. Per-segment min/max
aggregation that is equivalent to the older O(T^2)
mask-based reference for top-left causal, bottom-right causal,
and padding-only masks. Required under ring attention where
``segment_ids_q != segment_ids_kv`` in rotated steps.

Returns:
Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` with shapes as
above. Inactive segment slots are filled with 0 in seqlens and -1
in offsets.
"""
# TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
# Computing the full mask is expensive due to quadratic expansion of Q * KV masking.

# Assumptions for cudnn causal mask correctness.
# 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
# 2. No intra-segment padding, only inter-segment padding allowed
Expand All @@ -561,82 +746,30 @@ def _segment_ids_pos_to_seqlens_offsets(
# 0 x x
# 4 x x x x x
# 8 x x x x x x x x
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.

# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation

# Currently, this function is only exercised for THD qkv_layout.

# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
):
# The fast causal path encodes TOP-LEFT causal semantics via
# valid[q][kv] = (segment_pos_q >= segment_pos_kv)
# which is only equivalent to BRCM when s_q == s_kv (self-attention). For
# cross-attention (s_q != s_kv), BRCM diverges from top-left causal, so we
# must route bottom-right masks to the slow path.

# Fast path: O(T) per row.
if (
attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None
) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)

# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
segment_ids_kv,
jnp.equal,
)
segment_mask_with_id = make_attention_mask(
# Slow path: O(T * max_segments_per_seq) per row.
return _get_seqlens_offsets_thd(
segment_ids_q,
segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x,
)
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
attn_mask = segment_mask
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
# Example for brcm:
# run_length_out_q: [3 3 3 0 4 4 4 4]
# segment_pos_q: [0 1 2 3 0 1 2 3]
# segment_ids_q: [1 1 1 0 2 2 2 2]
# run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
# segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
# segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
# brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
# attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
elif attn_mask_type.is_causal():
causal_mask = make_attention_mask(
segment_pos_q,
segment_pos_kv,
jnp.greater_equal,
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)

attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
segment_pos_q,
segment_pos_kv,
attn_mask_type,
max_segments_per_seq,
)
return q_seqlen, kv_seqlen, q_offset, kv_offset


def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
Expand Down
Loading