[DO NOT MERGE] Draft for sparse #2933
base: main
Conversation
RissyRan left a comment:
Thanks for the change! I took a look at the indexer part, and overall it looks good functionally. It also has an indexer logit kernel for performance; I will take a look at that.
I will take a look at the MLA part shortly.
src/MaxText/layers/attention_mla.py (Outdated)
```python
in_features=self.q_lora_rank,
out_features=self.n_heads * self.head_dim,
use_bias=False,
dtype=jnp.bfloat16,
```
Let's use self.dtype, and similarly for the others. There may be an exception for self.weights_proj.
I replaced these nnx.Linear layers with MaxText's DenseGeneral and passed in config.dtype / weight_dtype / matmul_precision. self.weights_proj is kept as float32.
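For illustration, a minimal sketch of the dtype-threading idea using flax's nnx.Linear (the PR itself uses MaxText's DenseGeneral, whose exact signature is not shown here; the config field names below are assumptions):

```python
from flax import nnx

def make_indexer_proj(config, in_features, out_features, rngs):
  # Thread dtypes/precision from the config instead of hard-coding bfloat16.
  return nnx.Linear(
      in_features,
      out_features,
      use_bias=False,
      dtype=config.dtype,                 # activation dtype
      param_dtype=config.weight_dtype,    # weight dtype
      precision=config.matmul_precision,  # matmul precision
      rngs=rngs,
  )
```

The exception noted above (self.weights_proj) would keep float32 explicitly.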
src/MaxText/layers/attention_mla.py (Outdated)
```python
self.n_heads = config.index_n_heads
self.head_dim = config.index_head_dim
self.index_topk = config.index_topk
self.dim = config.emb_dim
```
Let's keep this aligned with MaxText terminology? I know the reference calls it dim.
renamed to emb_dim
src/MaxText/layers/attention_mla.py (Outdated)
```python
# Internal Indexer Cache (distinct from main MLA KV Cache)
# Shape: [Batch, MaxLen, HeadDim]
self.k_cache = nnx.Variable(jnp.zeros((config.max_target_length, self.head_dim), dtype=jnp.bfloat16))
```
Why is this 2D, while the comment says [Batch, MaxLen, HeadDim]? We could remove this if it is not needed for training so far.
src/MaxText/layers/attention_mla.py (Outdated)
```python
k = self._apply_partial_rope(k, positions)
k = k.squeeze(2)  # Back to [B, S, D]

# 3. Cache Update (Functional NNX update)
```
Shall we also apply rotate_activation from the reference? It seems related to the quantization strategy, and there is zero accuracy change without it.
src/MaxText/layers/attention_mla.py (Outdated)
```python
self.k_cache.value = updated_cache

# Active Keys: [B, TotalLen, D]
k_active = jax.lax.dynamic_slice(updated_cache, (0, 0, 0), (bsz, end_pos, self.head_dim))
```
Is this the same as slicing start_pos:end_pos for end_pos?
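For reference, a minimal check (shapes and values made up for the example) showing that a dynamic_slice starting at the origin matches basic slicing up to end_pos:

```python
import jax
import jax.numpy as jnp

bsz, max_len, head_dim, end_pos = 2, 16, 4, 5
cache = jnp.arange(bsz * max_len * head_dim, dtype=jnp.float32).reshape(bsz, max_len, head_dim)

# dynamic_slice with start (0, 0, 0) and size (bsz, end_pos, head_dim)
a = jax.lax.dynamic_slice(cache, (0, 0, 0), (bsz, end_pos, head_dim))
# equivalent basic slice
b = cache[:, :end_pos, :]
assert jnp.array_equal(a, b)
# Under jit, the slice sizes must be static in both forms; only the start
# indices of dynamic_slice may be traced values.
```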
src/MaxText/layers/attention_mla.py (Outdated)
```python
seq_idx = jnp.arange(seqlen)[None, :, None]

# JAX scatter update
bias_mask = bias_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
```
I recall this set will be very inefficient. We will need to consider other operations later, leveraging matmul.
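As a hedged sketch of the matmul-friendly direction (names are illustrative; the fuller jnp.where suggestion appears later in this thread), the scatter can be replaced with a one-hot encode plus a reduction over the K slots:

```python
import jax
import jax.numpy as jnp

def topk_bias(topk_indices, num_keys, masked_value):
  # topk_indices: [B, S, K]; num_keys: number of key positions being masked.
  one_hot = jax.nn.one_hot(topk_indices, num_keys, dtype=jnp.float32)  # [B, S, K, num_keys]
  is_topk = one_hot.sum(axis=-2) > 0                                   # [B, S, num_keys]
  # 0.0 at selected positions, masked_value elsewhere.
  return jnp.where(is_topk, 0.0, masked_value)
```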
src/MaxText/layers/attention_mla.py (Outdated)
```python
# [CHANGE 1] Initialize Indexer
# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
```
Do you think it would be better to pass this value with a default of False?
Yes, I set it as use_sparse_indexer: False in base.yml and removed this line.
```python
# We check a config flag to see if we are in Sparse/DeepSeek3.2 mode
self.use_sparse_indexer = getattr(config, "use_sparse_indexer", False)
if self.use_sparse_indexer:
  indexer_rope = copy.copy(self.rotary_embedding)
```
Why do we need a copy of rotary_embedding?
For MLA, YaRN is used with interleave=true; for the indexer, it is used with interleave=false. Making a copy keeps the two processes isolated.
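For context, a minimal sketch of the two pairing conventions an interleave flag typically selects between (this illustrates the general RoPE layouts, not MaxText's exact YaRN implementation):

```python
import jax.numpy as jnp

def rope_pairs_interleaved(x):
  # Pair adjacent features: (x0, x1), (x2, x3), ...
  return x[..., 0::2], x[..., 1::2]

def rope_pairs_half_split(x):
  # Pair first half with second half: (x0, x_{d/2}), (x1, x_{d/2+1}), ...
  d = x.shape[-1]
  return x[..., : d // 2], x[..., d // 2 :]

x = jnp.arange(8.0)
print(rope_pairs_interleaved(x))  # ([0, 2, 4, 6], [1, 3, 5, 7])
print(rope_pairs_half_split(x))   # ([0, 1, 2, 3], [4, 5, 6, 7])
```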
src/MaxText/layers/attention_mla.py (Outdated)
```python
start_pos = inputs_positions[0, 0]  # Assuming [B, L] or similar
# Run Indexer
# inputs_q is 'x', low_rank_q is 'qr'
sparse_bias, _, _ = self.indexer(
```
Let's rename it to index_mask instead of bias? A bias suggests adjustment values, while this is a mask for sparsity.
Sounds good. Renamed to index_mask.
src/MaxText/layers/attention_mla.py (Outdated)
```python
  out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
  out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
  # ds3.2: MHA mode for train / prefill; TODO: MQA mode for decode (mathematically equivalent but faster)?
```
In MQA, the size of the Key-Value (KV) cache is reduced by a factor equal to the number of heads.
You may have this diff already, but just FYI. V3 MLA vs. V3.2: https://diff.googleplex.com/#key=3JSmf20vQG8U
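To make the factor concrete, a back-of-the-envelope sketch (all numbers are made up for the example, not taken from the model config):

```python
# bf16 KV-cache size for MHA vs. MQA at the same batch/sequence/head_dim.
batch, seq_len, n_heads, head_dim, bytes_per_elem = 8, 4096, 128, 128, 2

mha_kv_bytes = 2 * batch * seq_len * n_heads * head_dim * bytes_per_elem  # K and V, one per head
mqa_kv_bytes = 2 * batch * seq_len * 1 * head_dim * bytes_per_elem        # single shared KV head

print(mha_kv_bytes / mqa_kv_bytes)  # == n_heads, i.e. a 128x smaller cache here
```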
```python
batch_idx = jnp.arange(bsz)[:, None, None]
seq_idx = jnp.arange(seqlen)[None, :, None]
# JAX scatter update
index_mask = index_mask.at[batch_idx, seq_idx, topk_indices].set(0.0)
```
Later: we could update this to use jnp.where. set() is a scatter operation, which is often less efficient.
```python
# Assuming:
#   topk_indices shape: [B, S, K]
#   T is the target dimension size (index_score.shape[-1])
def get_mask_efficient(topk_indices, T, default_value):
  # 1. Create a range [0, 1, ..., T-1]
  # 2. Broadcast compare against [B, S, K] to get [B, S, K, T]
  # 3. Use .any() to see if a T-index is present in any of the K slots
  is_topk = (jnp.arange(T) == topk_indices[..., None]).any(axis=-2)
  # 4. Use where to select between 0.0 and the mask value
  return jnp.where(is_topk, 0.0, default_value)
```
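A possible call site for the sketch above (the mask value and surrounding names are assumptions based on the diff, not the actual code):

```python
# index_mask: 0.0 at the selected top-k key positions, large negative elsewhere.
index_mask = get_mask_efficient(topk_indices, index_score.shape[-1], -jnp.inf)
```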
Description
- attention_mla.py, attention_op.py
- deepseek3.2-671b.yml
- check_deepseek32_vs_reference.py
- Future work: verify end-to-end training logits
Reference
Tests
Unit tests against torch code (adapted from the reference): indexer, MLA
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] gemini-review label.