Conversation

Collaborator

@RissyRan RissyRan commented Jan 13, 2026

Description

  1. Implement DeepSeek Sparse Attention (DSA)
  • DSA: inside MLA, an indexer computes qk-product key logits and selects the top-k keys for each query
  • Top-k selection is implemented for dot-product attention: qk product -> add index mask (analogous to the regular attention mask) -> multiply by values (see the sketch after this list)
  • Training only (no prefill / decode)
  • Files: attention_mla.py, attention_op.py
  2. Onboard the DeepSeek 3.2 config: deepseek3.2-671b.yml
  • Compared to DeepSeek 3, the architectural difference is DSA
  • DeepSeek V3.2 vs. V3 config diff:
    "index_head_dim": 128, "index_n_heads": 64, "index_topk": 2048,
  3. Add a unit test against the torch reference code for the indexer and MLA: check_deepseek32_vs_reference.py
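
A minimal JAX sketch of the top-k index-mask flow from item 1 (shapes and names are illustrative assumptions, not the PR's implementation; the indexer scores are taken as given):

```python
import jax
import jax.numpy as jnp

def sparse_attention_sketch(q, k, v, index_scores, index_topk):
  """q, k, v: [B, S, H, D]; index_scores: [B, S, S] indexer logits per (query, key) pair."""
  # 1. Select the top-k keys for every query position from the indexer scores.
  _, topk_idx = jax.lax.top_k(index_scores, index_topk)           # [B, S, K]
  # 2. Turn the indices into an additive mask: 0 for selected keys, -inf otherwise.
  key_pos = jnp.arange(index_scores.shape[-1])                    # [S]
  is_selected = (key_pos == topk_idx[..., None]).any(axis=-2)     # [B, S, S]
  index_mask = jnp.where(is_selected, 0.0, -jnp.inf)              # [B, S, S]
  # 3. Regular dot-product attention with the index mask added to the logits
  #    (a causal mask would be combined in the same way).
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  logits = logits + index_mask[:, None, :, :]                     # broadcast over heads
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```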

Future work: verify end-to-end training logits

Reference

Tests

Unit test against torch code (adapted from reference): indexer, MLA

python3 -m pytest -v --pyargs tests.check_deepseek32_vs_reference -rP -s

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 0% with 78 lines in your changes missing coverage. Please review.

Files with missing lines              Patch %   Lines
src/MaxText/layers/attention_mla.py   0.00%     72 Missing ⚠️
src/MaxText/layers/attention_op.py    0.00%     6 Missing ⚠️


Collaborator Author

@RissyRan RissyRan left a comment

Thanks for the change! I took a look at the indexer part, and overall it looks good for functionality. It also has an indexer logit kernel for performance; I will take a look there.

I will take a look at the MLA part shortly.

in_features=self.q_lora_rank,
out_features=self.n_heads * self.head_dim,
use_bias=False,
dtype=jnp.bfloat16,
Collaborator Author

Let's use self.dtype, and similarly for the others. There may be an exception for self.weights_proj.
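
A rough sketch of that suggestion (class and attribute names here are hypothetical, and the indexer may end up using DenseGeneral instead, per the reply below):

```python
from flax import nnx

class IndexerProj(nnx.Module):  # hypothetical container, for illustration only
  def __init__(self, q_lora_rank, n_heads, head_dim, dtype, weight_dtype, rngs: nnx.Rngs):
    self.wq_b = nnx.Linear(
        in_features=q_lora_rank,
        out_features=n_heads * head_dim,
        use_bias=False,
        dtype=dtype,               # module-level dtype instead of hardcoded jnp.bfloat16
        param_dtype=weight_dtype,  # parameter storage dtype
        rngs=rngs,
    )
```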

Collaborator

I replaced these nnx.Linear layers with MaxText's DenseGeneral and passed in config.dtype/weight/matmul_precision. self.weights_proj is kept as float32.

self.n_heads = config.index_n_heads
self.head_dim = config.index_head_dim
self.index_topk = config.index_topk
self.dim = config.emb_dim
Collaborator Author

Let's keep alignment with the MaxText term? I know the reference calls it dim.

Collaborator

Renamed to emb_dim.


# Internal Indexer Cache (distinct from main MLA KV Cache)
# Shape: [Batch, MaxLen, HeadDim]
self.k_cache = nnx.Variable(jnp.zeros((config.max_target_length, self.head_dim), dtype=jnp.bfloat16))
Collaborator Author

Why is this 2D, while the note says [Batch, MaxLen, HeadDim]? We could remove this if it is not needed for training so far.

k = self._apply_partial_rope(k, positions)
k = k.squeeze(2) # Back to [B, S, D]

# 3. Cache Update (Functional NNX update)
Collaborator Author

Shall we also apply rotate_activation from the reference? It seems related to the quantization strategy, and there is zero accuracy change without it.

self.k_cache.value = updated_cache

# Active Keys: [B, TotalLen, D]
k_active = jax.lax.dynamic_slice(updated_cache, (0, 0, 0), (bsz, end_pos, self.head_dim))
Collaborator Author

Is this the same as slicing start_pos:end_pos for end_pos?

seq_idx = jnp.arange(seqlen)[None, :, None]

# JAX scatter update
bias_mask = bias_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
Collaborator Author

I recall this set will be very inefficient. We will need to consider other operations later, leveraging matmul (one possible sketch is below).
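
One possible dense, matmul-friendly alternative to the scatter update (a sketch under assumed shapes, not the PR's plan): build the top-k membership mask with one_hot + sum instead of .at[...].set().

```python
import jax
import jax.numpy as jnp

def topk_mask_via_one_hot(topk_indices, total_len, default_value):
  # topk_indices: [B, S, K] integer key indices selected for each query position.
  one_hot = jax.nn.one_hot(topk_indices, total_len, dtype=jnp.float32)  # [B, S, K, T]
  is_topk = one_hot.sum(axis=-2) > 0                                    # [B, S, T]
  # 0.0 where a key is selected, default_value (e.g. a large negative) elsewhere.
  return jnp.where(is_topk, 0.0, default_value)
```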


# [CHANGE 1] Initialize Indexer
# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
Collaborator Author

Do you think it's better to pass this value with a default of False?

Collaborator

Yes, I set use_sparse_indexer: False in base.yml and removed this line.

# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
if self.use_sparse_indexer:
indexer_rope = copy.copy(self.rotary_embedding)
Collaborator Author

Why do we need a copy of rotary_embedding?

Collaborator

For the MLA, YaRN is used with interleave=true. For the indexer, it is used with interleave=false. Making a copy keeps the two processes isolated.
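
For context, a sketch of the interleave distinction (illustrative only, not MaxText's RoPE code): interleaved RoPE rotates adjacent pairs (x0, x1), (x2, x3), ..., while non-interleaved RoPE rotates split halves (x0, x_{d/2}), (x1, x_{d/2+1}), ..., which is why the MLA and indexer need separately configured rotary modules.

```python
import jax.numpy as jnp

def rotate_half_interleaved(x):
  # Pairs (x0, x1), (x2, x3), ... -> (-x1, x0, -x3, x2, ...)
  x_even, x_odd = x[..., 0::2], x[..., 1::2]
  return jnp.stack([-x_odd, x_even], axis=-1).reshape(x.shape)

def rotate_half_non_interleaved(x):
  # Split halves (x_first, x_second) -> (-x_second, x_first)
  x1, x2 = jnp.split(x, 2, axis=-1)
  return jnp.concatenate([-x2, x1], axis=-1)
```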

start_pos = inputs_positions[0, 0] # Assuming [B, L] or similar
# Run Indexer
# inputs_q is 'x', low_rank_q is 'qr'
sparse_bias, _, _ = self.indexer(
Collaborator Author

Let's rename it to index_mask instead of bias? Bias suggests adjustment values, while this is a mask for sparsity.

Collaborator

Sounds good. Renamed to index_mask.

out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
# ds3.2, MHA mode for train / prefill, TODO: MQA model for decode (mathematically equivalent but speed faster)?
Collaborator Author

In MQA, the size of the Key-Value (KV) cache is reduced by a factor equal to the number of heads.

You may have this diff already, but just FYI. V3 MLA vs. V3.2: https://diff.googleplex.com/#key=3JSmf20vQG8U
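
A quick size illustration of that reduction (hypothetical shapes, not the PR's cache layout):

```python
# MHA stores a K/V entry per head; MQA shares one K/V across heads,
# shrinking the cache by a factor of num_heads (128x in this example).
batch, max_len, num_heads, head_dim = 8, 4096, 128, 128
mha_kv_cache_shape = (batch, max_len, num_heads, head_dim)  # per-head K/V
mqa_kv_cache_shape = (batch, max_len, 1, head_dim)          # shared K/V
```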

batch_idx = jnp.arange(bsz)[:, None, None]
seq_idx = jnp.arange(seqlen)[None, :, None]
# JAX scatter update
index_mask = index_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
Collaborator Author

Later: we could update this to jnp.where in this case. .set() is a scatter operation, which is often less efficient.

```python
import jax.numpy as jnp

# Assuming:
# topk_indices shape: [B, S, K]
# T is the target dimension size (index_score.shape[-1])

def get_mask_efficient(topk_indices, T, default_value):
    # 1. Create a range [0, 1, ..., T-1]
    # 2. Broadcast-compare against [B, S, K] to get [B, S, K, T]
    # 3. Use .any() to see if a T-index is present in any of the K slots
    is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)

    # 4. Use where to select between 0.0 and the mask value
    return jnp.where(is_topk, 0.0, default_value)
```
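
A possible usage of the helper above, replacing the scatter update (index_score and the -inf default are assumptions from the surrounding discussion):

```python
# Instead of: index_mask = index_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
index_mask = get_mask_efficient(topk_indices, T=index_score.shape[-1], default_value=-jnp.inf)
```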
