Skip to content

[WIP] Add per-token weight support to DataStream and weighted_causal_lm_ce#177

Open
luciaquirke wants to merge 1 commit intomagic-dtensor-patchfrom
magic-per-token
Open

[WIP] Add per-token weight support to DataStream and weighted_causal_lm_ce#177
luciaquirke wants to merge 1 commit intomagic-dtensor-patchfrom
magic-per-token

Conversation

@luciaquirke
Copy link
Collaborator

@luciaquirke luciaquirke commented Mar 7, 2026

  • Add per_token parameter to DataStream for [n_examples, max_length] weight tensors
  • Support 2D [B, T] per-token weights in weighted_causal_lm_ce

More:

  • Per-token weights: DataStream gains a per_token flag for 2D [n_examples, max_length] weight tensors, and weighted_causal_lm_ce now accepts [B, T] weights alongside the existing [B] shape.

- Add per_token parameter to DataStream for [n_examples, max_length]
  weight tensors
- Support 2D [B, T] per-token weights in weighted_causal_lm_ce

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant