Implement Bits Per Byte Metric Computation #61
Conversation
if remainder > 0:
    rank = self.accelerator.process_index
    return max(0, min(nominal_count, remainder - rank * nominal_count))
return nominal_count
Distributed last-batch trimming wrong for non-dispatch mode
High Severity
_real_batch_count uses remainder - rank * nominal_count to compute how many real samples each rank has in the last batch. This formula is only correct when dispatch_batches=True (where remainder is the global count of leftover samples). In the default dispatch_batches=False mode, each rank's DataLoaderShard sets remainder to the per-rank count of real samples in its own last batch — the same value on every rank. The subtraction of rank * nominal_count causes all ranks except rank 0 to return 0, silently discarding their last-batch contributions to accuracy, BPB, generation predictions, and sample counts.
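The behavior this finding asks for can be sketched as a standalone helper; this is a minimal sketch with a hypothetical free function `real_batch_count` (the real `_real_batch_count` is a trainer method that reads `self.accelerator.process_index`), not the repo's implementation:

```python
def real_batch_count(nominal_count: int, remainder: int, rank: int,
                     dispatch_batches: bool) -> int:
    """Real (non-padded) samples this rank holds in the last batch.

    dispatch_batches=True: `remainder` is the GLOBAL leftover count, so
    rank r owns the slice [r*nominal_count, (r+1)*nominal_count) of it.
    dispatch_batches=False: each rank's DataLoaderShard reports its OWN
    per-rank remainder, so use it directly, without the rank offset.
    """
    if remainder <= 0:
        return nominal_count
    if dispatch_batches:
        return max(0, min(nominal_count, remainder - rank * nominal_count))
    return min(nominal_count, remainder)
```

With nominal_count=4 and a global remainder of 6 in dispatch mode, ranks 0/1/2 report 4/2/0; in non-dispatch mode a per-rank remainder of 3 yields 3 on every rank, instead of the buggy 3/0/0.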
📝 Walkthrough: Adds a bits-per-byte metric and its wiring.
Sequence Diagram
sequenceDiagram
participant DL as DataLoader
participant Trainer
participant Model
participant Decoder
participant Metrics
participant Dist as DistributedSync
participant Logger
DL->>Trainer: yield eval batch (input_ids, labels)
Trainer->>Model: forward(inputs)
Model-->>Trainer: logits, loss
Trainer->>Trainer: _accumulate_accuracy_and_bpb(logits, labels)
Trainer->>Decoder: decode(input_ids)
Decoder-->>Metrics: decoded_text -> compute num_bytes, num_tokens
Trainer->>Metrics: compute_bits_per_byte(eval_loss, num_tokens, num_bytes)
Metrics-->>Trainer: eval_bits_per_byte
Trainer->>Dist: _sync_eval_state() (reduce counters, gather preds)
Dist-->>Trainer: synchronized accumulators
Trainer->>Logger: log eval_bits_per_byte, accuracies, gen metrics
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/test_metrics.py`:
- Around line 128-130: The test in tests/test_metrics.py instantiates a real
HuggingFace tokenizer (tokenizer = AutoTokenizer.from_pretrained("gpt2")) which
requires network access and can flake in CI; mark the test as non-hermetic by
adding `@pytest.mark.integration` to the test function or ensure the tokenizer is
pre-cached by committing the gpt2 tokenizer files and changing the loader to
point to the local directory (or add setup code to download/cache the tokenizer
before tests run). Update the test signature or test setup to reference the
tokenizer variable/AutoTokenizer.from_pretrained invocation accordingly.
In `@welt_training/experiments/machine-translation/run_clm.py`:
- Around line 760-771: The computed num_eval_bytes is biased by tokenizer.decode
defaults; change the decoding call used to compute num_eval_bytes so it
excludes special tokens and preserves original spacing by calling
tokenizer.decode(example["input_ids"][1:], skip_special_tokens=True,
clean_up_tokenization_spaces=False) (or an equivalent tokenizer.batch_decode
with those flags) when iterating over eval_dataset; keep the surrounding logic
(data_args.streaming check, num_eval_tokens, compute_bits_per_byte and
metrics["eval_bits_per_byte"]) unchanged so eval_bits_per_byte uses a byte
denominator independent of tokenizer cleanup defaults.
In `@welt_training/trainer.py`:
- Around line 509-511: When computing batch_sample_count before updating
self._eval_sample_count, fall back to counting samples from the batch payload
when prefixes is None instead of zero: if prefixes is not None use
len(prefixes), otherwise derive the count from another batch element (e.g., the
first non-scalar entry such as inputs/targets/labels) and then call
self._real_batch_count(...) as before; ensure this change is applied where
_eval_sample_count is updated and that _add_custom_metrics() will then report
correct eval_samples for datasets without __len__.
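The fallback described in this prompt can be sketched as follows; `batch_sample_count` and the dict-shaped batch are illustrative assumptions, not the trainer's actual code:

```python
def batch_sample_count(prefixes, batch: dict) -> int:
    """Count samples in an eval batch.

    Prefer len(prefixes) when available; otherwise (e.g. prefixes is None
    for datasets without __len__) derive the count from the first batch
    entry that has a length, instead of silently returning 0.
    """
    if prefixes is not None:
        return len(prefixes)
    for value in batch.values():
        if hasattr(value, "__len__"):  # skip scalar entries such as floats
            return len(value)
    return 0
```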
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d4e4b075-cfeb-4318-a409-50d98d1f3e9c
📒 Files selected for processing (7)
tests/test_metrics.py, tests/test_train.py, tests/test_trainer.py, welt/processor.py, welt_training/experiments/machine-translation/run_clm.py, welt_training/metrics.py, welt_training/trainer.py
# Compute bits per byte from the evaluated subset.
# Decode tokens back to text to count the corresponding UTF-8 bytes.
if not data_args.streaming:
    num_eval_tokens = len(eval_dataset) * (block_size - 1)
    num_eval_bytes = sum(
        len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8"))
        for example in eval_dataset
    )
    if num_eval_bytes > 0:
        metrics["eval_bits_per_byte"] = compute_bits_per_byte(
            metrics["eval_loss"], num_eval_tokens, num_eval_bytes
        )
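The repo's `compute_bits_per_byte` is not shown in this thread; a standard definition consistent with the call above — a mean per-token loss in nats, converted to total bits and divided by the byte count — would be:

```python
import math

def compute_bits_per_byte(eval_loss: float, num_tokens: int, num_bytes: int) -> float:
    """Bits per byte from a mean per-token cross-entropy in nats.

    total nats = eval_loss * num_tokens; divide by ln(2) to convert to
    bits, then by the number of UTF-8 bytes the tokens decode to.
    """
    total_nats = eval_loss * num_tokens
    return total_nats / (num_bytes * math.log(2))
```

This is a sketch of the conventional formula, not necessarily the exact signature in welt_training/metrics.py.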
🧩 Analysis chain
🏁 Script executed:
find . -name "run_clm.py" -type f
🏁 Script executed:
git ls-files | grep -E "(run_clm|machine.*translation)" | head -20
🏁 Script executed:
wc -l ./welt_training/experiments/machine-translation/run_clm.py
🏁 Script executed:
sed -n '760,771p' ./welt_training/experiments/machine-translation/run_clm.py
🏁 Script executed:
sed -n '1,100p' ./welt_training/experiments/machine-translation/run_clm.py | grep -E "(import|from).*tokenizer|compute_bits_per_byte"
🏁 Script executed:
rg "compute_bits_per_byte" ./welt_training/experiments/machine-translation/run_clm.py -A 5 -B 2
🏁 Script executed:
rg "def compute_bits_per_byte" . -A 10
🏁 Script executed:
rg "def compute_bits_per_byte" . -A 20
🏁 Script executed:
sed -n '1,50p' ./welt_training/experiments/machine-translation/run_clm.py | head -40
🏁 Script executed:
sed -n '750,780p' ./welt_training/experiments/machine-translation/run_clm.py
🏁 Script executed:
sed -n '1,200p' ./welt_training/experiments/machine-translation/run_clm.py | grep -E "tokenizer.*=" | head -10
🏁 Script executed:
rg "skip_special_tokens|clean_up_tokenization" ./welt_training/experiments/machine-translation/run_clm.py
🏁 Script executed:
sed -n '600,700p' ./welt_training/experiments/machine-translation/run_clm.py | grep -A 20 "tokenizer"
🌐 Web query:
huggingface transformers tokenizer.decode() default parameters skip_special_tokens clean_up_tokenization_spaces
💡 Result:
In 🤗 Transformers, the method signature (and defaults) are:
tokenizer.decode(
    token_ids,
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool | None = None,
    **kwargs
) -> str  [1][3]
So by default:
- skip_special_tokens defaults to False (special tokens are kept). [1][3]
- clean_up_tokenization_spaces defaults to None, which means it falls back to tokenizer.clean_up_tokenization_spaces. [1][3]
- The tokenizer attribute clean_up_tokenization_spaces itself defaults to True unless overridden when creating/loading the tokenizer. [2]
(The same defaults apply to batch_decode(...).) [1]
🏁 Script executed:
sed -n '200,400p' ./welt_training/experiments/machine-translation/run_clm.py | grep -B 5 -A 10 "load_model_and_tokenizer\|AutoTokenizer"
🏁 Script executed:
rg "tokenizer.*=" ./welt_training/experiments/machine-translation/run_clm.py -B 2 -A 2 | head -40
🏁 Script executed:
sed -n '750,780p' ./welt_training/experiments/machine-translation/run_clm.py | cat -n
Make the byte denominator independent of tokenizer decode defaults.
num_eval_bytes is currently derived via tokenizer.decode(...) with the tokenizer's defaults (skip_special_tokens=False, clean_up_tokenization_spaces=True). With those defaults the decoded text includes special-token strings and has its spacing normalized, so it no longer matches the original input bytes, causing eval_bits_per_byte to be biased.
Proposed fix
- num_eval_bytes = sum(
- len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8"))
- for example in eval_dataset
- )
+ num_eval_bytes = sum(
+ len(
+ tokenizer.decode(
+ example["input_ids"][1:],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ ).encode("utf-8")
+ )
+ for example in eval_dataset
+ )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@welt_training/experiments/machine-translation/run_clm.py` around lines 760 -
771, The computed num_eval_bytes is biased by tokenizer.decode defaults; change
the decoding call used to compute num_eval_bytes so it excludes special tokens
and preserves original spacing by calling tokenizer.decode(example["input_ids"][1:],
skip_special_tokens=True, clean_up_tokenization_spaces=False) (or an equivalent
tokenizer.batch_decode with those flags) when iterating over eval_dataset; keep
the surrounding logic (data_args.streaming check, num_eval_tokens,
compute_bits_per_byte and metrics["eval_bits_per_byte"]) unchanged so
eval_bits_per_byte uses a byte denominator independent of tokenizer cleanup
defaults.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
welt_training/trainer.py
Outdated
batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels)

if torch.isfinite(batch_loss):
    self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes
Nats accumulator multiplied by spurious bytes_per_token factor
Medium Severity
_eval_total_nats accumulates batch_loss * batch_non_pad_bytes where batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token. Since batch_loss is a per-token average, the correct total nats is batch_loss * batch_non_pad_tokens, not batch_loss * batch_non_pad_bytes. The extra bytes_per_token factor means _eval_total_nats stores a value that is bytes_per_token × the actual total nats. For UTF-8 (bytes_per_token=1) the result is correct, and for UTF-16/UTF-32 the factor happens to cancel with the same inflation in _eval_total_content_bytes — but only when the model is byte-level. If a character-level model is ever used with UTF-16/UTF-32, the BPB will be wrong by a factor of bytes_per_token.
Additional Locations (1)
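The factor-of-bytes_per_token inflation described above is easy to check numerically; this sketch uses hypothetical names, not the trainer's code:

```python
import math

per_token_loss = math.log(2)        # mean eval loss, nats per token
non_pad_tokens = 8
bytes_per_token = 2                 # e.g. UTF-16 encoding width
non_pad_bytes = non_pad_tokens * bytes_per_token

# Correct: a per-token average times the token count gives total nats.
correct_total_nats = per_token_loss * non_pad_tokens

# Buggy: multiplying by the byte count inflates the total by bytes_per_token.
buggy_total_nats = per_token_loss * non_pad_bytes
assert buggy_total_nats == correct_total_nats * bytes_per_token
```

As the finding notes, a matching inflation in the byte denominator can mask this for byte-level models, which is why it only surfaces with character-level models on multi-byte encodings.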
AmitMY
left a comment
looks good! just one note regarding the streaming dataset
self._dataset = hf_dataset

def __iter__(self):
    yield from self._dataset
i feel like this should be in streaming.py
and tests fail
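The snippet under discussion is a thin delegation wrapper; stripped of the torch base class, its behavior reduces to this illustrative stand-in (not the repo's `_TorchIterableAdapter`, which subclasses torch.utils.data.IterableDataset):

```python
class IterableAdapter:
    """Delegate iteration to a wrapped (e.g. HuggingFace) dataset so a
    consumer that only needs __iter__ can treat it as a plain iterable."""

    def __init__(self, hf_dataset):
        self._dataset = hf_dataset

    def __iter__(self):
        yield from self._dataset

# Any iterable passes straight through unchanged:
wrapped = IterableAdapter([{"input_ids": [1]}, {"input_ids": [2]}])
```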
…hings on no-internet-access clusters.
Actionable comments posted: 5
🧹 Nitpick comments (1)
welt/model.py (1)
675-800: Address pipeline lint failure: generate is too complex (12 > 10).
The generate method exceeds the complexity threshold due to the added entropy computation logic. Consider extracting the entropy-related code (lines 782-798) into a helper method like _compute_entropy_for_generation to reduce cyclomatic complexity.
♻️ Suggested refactoring approach
+ def _compute_entropy_for_generation(
+     self,
+     prefill_logits: torch.Tensor,
+     initial_num_words: torch.Tensor,
+     word_latents: list[torch.Tensor],
+     all_generated_words: list[list[str]],
+     prompt_words: list[str] | None,
+     tokenizer,
+     device: torch.device,
+ ) -> tuple[list[str], list[float], list[str], int]:
+     """Extract entropy computation from generate() to reduce complexity."""
+     prompt_entropies, prompt_byte_labels = [], []
+     if prompt_words is not None and len(prompt_words) > 1:
+         num_prompt = min(initial_num_words[0].item(), len(prompt_words))
+         prompt_latents = [prefill_logits[:, i:i+1, :] for i in range(num_prompt - 1)]
+         prompt_target_words = prompt_words[1:num_prompt]
+         prompt_entropies, prompt_byte_labels = self._compute_generation_entropy(
+             prompt_latents, prompt_target_words, tokenizer, device)
+
+     gen_entropies, gen_byte_labels = self._compute_generation_entropy(
+         word_latents, all_generated_words[0], tokenizer, device)
+
+     texts = ["".join(words) for words in all_generated_words]
+     return (texts,
+             prompt_entropies + gen_entropies,
+             prompt_byte_labels + gen_byte_labels,
+             len(prompt_entropies))
Then call this helper from generate() when return_entropy=True.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@welt/model.py` around lines 675 - 800, The generate method is over-complex due to inline entropy post-processing; extract the block that computes prompt_entropies/prompt_byte_labels and gen_entropies/gen_byte_labels (the code that calls _compute_generation_entropy using prefill_logits, prompt_words, word_latents, all_generated_words and tokenizer) into a new helper method named _compute_entropy_for_generation that accepts (prefill_logits, initial_num_words, prompt_words, word_latents, all_generated_words, tokenizer, device) and returns (combined_entropies, combined_byte_labels, num_prompt_entropies). Replace the extracted block in generate() with a single call to _compute_entropy_for_generation and use its returned tuple to build the final return; ensure behavior (including handling when prompt_words is None or short) is preserved and that types/ordering of returned values match the original (prompt_entropies + gen_entropies, prompt_byte_labels + gen_byte_labels, len(prompt_entropies)).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@welt_training/demo_clm.py`:
- Around line 29-34: The tokenizer is loaded from model_path while the model is
loaded from checkpoint_path, which can cause tokenizer/model mismatch; update
the AutoTokenizer.from_pretrained call to use checkpoint_path (or otherwise
verify and document that tokenizer and checkpoint share the same vocab/config)
so tokenizer and model are loaded from the same source; locate the lines using
AutoTokenizer.from_pretrained (tokenizer) and
AutoModelForCausalLM.from_pretrained (model) and make the tokenizer use
checkpoint_path or add an explicit check that model_path and checkpoint_path are
compatible.
- Line 64: The function signature for generate(...) is too long and also the
call at the second occurrence (the generate( invocation around line 96) exceeds
120 chars; split the long function definition and any long call into multi-line
form to satisfy line-length linting—break the parameter list of generate into
multiple lines (one parameter per line or logical groups) and update the
corresponding generate(...) invocation to pass arguments each on its own line or
use keyword args on separate lines so both the def generate and its call are
under 120 characters; ensure you preserve parameter order and names
(max_new_tokens, strategy, num_beams, top_k, top_p, temperature,
repetition_penalty, model, tokenizer).
In `@welt_training/demo.py`:
- Around line 89-90: The loop currently does "for i, (xi, h, c, a) in
enumerate(zip(x, entropies, colors, alphas)):" which leaves the index i unused
and calls zip without strict, triggering Ruff errors; change the loop to drop
the unused enumerate/index and call zip with strict=True so it becomes a direct
unpack: iterate over (xi, h, c, a) from zip(x, entropies, colors, alphas,
strict=True), keeping the alphas construction (alphas = [0.4] *
prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)) intact.
- Around line 18-20: The code defines DEVICE and AUTOCAST_DTYPE but the model is
still loaded with a hard-coded torch.bfloat16; update AUTOCAST_DTYPE to choose
bfloat16 for CUDA, float16 for MPS, and float32 for CPU (i.e., torch.bfloat16 if
DEVICE == "cuda" else torch.float16 if DEVICE == "mps" else torch.float32), then
change the model-loading call (the from_pretrained/load_model or similar call
that currently passes torch.bfloat16) to use AUTOCAST_DTYPE so the loaded model
dtype matches the device-aware autocast logic.
In `@welt_training/trainer.py`:
- Around line 208-252: The outer isinstance currently accepts
torch.utils.data.IterableDataset but the code only shards
datasets.IterableDataset and CustomIterableDataset and expects
with_transform()/set_transform() later in _prepare_eval_dataset; fix by either
narrowing the isinstance check to only HF/custom iterables
(datasets.IterableDataset or CustomIterableDataset) or add proper handling for
plain torch IterableDataset: when eval_dataset is a torch IterableDataset, shard
it across ranks using split_dataset_by_node (wrap the dataset in an adapter that
exposes set_transform/with_transform or reuse _TorchIterableAdapter extended to
implement those methods), and ensure the sharded wrapper is used so later
_prepare_eval_dataset() can call with_transform() without AttributeError;
reference symbols: eval_dataset, _TorchIterableAdapter, CustomIterableDataset,
split_dataset_by_node, and _prepare_eval_dataset/with_transform.
---
Nitpick comments:
In `@welt/model.py`:
- Around line 675-800: The generate method is over-complex due to inline entropy
post-processing; extract the block that computes
prompt_entropies/prompt_byte_labels and gen_entropies/gen_byte_labels (the code
that calls _compute_generation_entropy using prefill_logits, prompt_words,
word_latents, all_generated_words and tokenizer) into a new helper method named
_compute_entropy_for_generation that accepts (prefill_logits, initial_num_words,
prompt_words, word_latents, all_generated_words, tokenizer, device) and returns
(combined_entropies, combined_byte_labels, num_prompt_entropies). Replace the
extracted block in generate() with a single call to
_compute_entropy_for_generation and use its returned tuple to build the final
return; ensure behavior (including handling when prompt_words is None or short)
is preserved and that types/ordering of returned values match the original
(prompt_entropies + gen_entropies, prompt_byte_labels + gen_byte_labels,
len(prompt_entropies)).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4953cffd-9986-41cf-8da4-50bf3c577704
📒 Files selected for processing (11)
tests/test_train.py, tests/test_trainer.py, welt/model.py, welt/model_utils.py, welt/vision/navit.py, welt_training/args_data.py, welt_training/demo.py, welt_training/demo_clm.py, welt_training/experiments/machine-translation/run_clm.py, welt_training/train.py, welt_training/trainer.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/test_trainer.py
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
Potential inconsistency between tokenizer and model paths.
The tokenizer is loaded from model_path (line 29) while the model is loaded from checkpoint_path (lines 30-34). If checkpoint_path is a later checkpoint with a different tokenizer configuration, this could cause mismatches.
Consider loading the tokenizer from checkpoint_path as well, or verify that checkpoints always share the same tokenizer as the base model path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@welt_training/demo_clm.py` around lines 29 - 34, The tokenizer is loaded from
model_path while the model is loaded from checkpoint_path, which can cause
tokenizer/model mismatch; update the AutoTokenizer.from_pretrained call to use
checkpoint_path (or otherwise verify and document that tokenizer and checkpoint
share the same vocab/config) so tokenizer and model are loaded from the same
source; locate the lines using AutoTokenizer.from_pretrained (tokenizer) and
AutoModelForCausalLM.from_pretrained (model) and make the tokenizer use
checkpoint_path or add an explicit check that model_path and checkpoint_path are
compatible.
@torch.inference_mode()
@torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu")
def generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer):
Fix line length lint errors.
Lines 64 and 96 exceed the 120 character limit.
🔧 Proposed fix
 @torch.inference_mode()
 @torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu")
-def generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer):
+def generate(
+    prompt, max_new_tokens, strategy, num_beams, top_k, top_p,
+    temperature, repetition_penalty, model, tokenizer
+):
     if not prompt.strip():

 def on_generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty):
-    return generate(
-        prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer)
+    return generate(
+        prompt, max_new_tokens, strategy, num_beams, top_k, top_p,
+        temperature, repetition_penalty, model, tokenizer
+    )
Also applies to: 96-96
🧰 Tools
🪛 GitHub Actions: Lint
[error] 64-64: ruff check (E501): Line too long (123 > 120).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@welt_training/demo_clm.py` at line 64, The function signature for
generate(...) is too long and also the call at the second occurrence (the
generate( invocation around line 96) exceeds 120 chars; split the long function
definition and any long call into multi-line form to satisfy line-length
linting—break the parameter list of generate into multiple lines (one parameter
per line or logical groups) and update the corresponding generate(...)
invocation to pass arguments each on its own line or use keyword args on
separate lines so both the def generate and its call are under 120 characters;
ensure you preserve parameter order and names (max_new_tokens, strategy,
num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer).
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
AUTOCAST_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float16
🧩 Analysis chain
🏁 Script executed:
fd -t f "demo.py" | head -20
🏁 Script executed:
cat -n welt_training/demo.py
🌐 Web query:
PyTorch bfloat16 support CPU compatibility
💡 Result:
PyTorch supports torch.bfloat16 (bfloat16) on CPU for all x86_64 processors. The dtype is fully functional universally, allowing tensor creation, model conversion (model.to(torch.bfloat16)), and operations via torch.autocast(device_type="cpu", dtype=torch.bfloat16) for mixed precision. Native hardware acceleration requires Intel CPUs with AVX512_BF16 (3rd Gen Xeon Scalable "Cooper Lake" and later) or AMX for optimal performance on compute-intensive ops like convolution/linear via oneDNN. On older CPUs lacking these ISAs, PyTorch falls back to software emulation: BF16 tensors convert to FP32 for compute, then back, resulting in correct results but reduced performance (often slower than FP32). No explicit torch.cpu.is_bfloat16_supported exists (unlike CUDA), and dtype checks return true for CPU. Compatibility issues (e.g., hangs in old versions) are resolved in modern PyTorch (2.x). ARM CPUs work similarly with emulation, though less optimized.
Citations:
- 1: https://pytorch.org/blog/empowering-pytorch-on-intel-xeon-scalable-processors-with-bfloat16/
- 2: https://www.intel.com/content/www/us/en/developer/articles/technical/pytorch-on-xeon-processors-with-bfloat16.html
- 3: Allow BF16 dtype support on CPU meta-pytorch/torchtune#1218
- 4: torch.bmm is slow on non-contiguous BF16 CPU tensors pytorch/pytorch#151934
- 5: https://intel.github.io/neural-compressor/latest/docs/source/3x/PT_MixedPrecision.html
🌐 Web query:
PyTorch MPS Metal Performance Shaders bfloat16 support
💡 Result:
PyTorch's MPS (Metal Performance Shaders) backend supports bfloat16 on Apple Silicon (arm64) Macs running macOS 14.0 (Sonoma) or later, starting from PyTorch nightly builds around version 2.6 (late 2024). Full stable support is available in recent versions like 2.10+ as of 2026. Apple Metal added bfloat16 via MPSDataType.bFloat16 in macOS 14+. On older M1/M2 chips, it is emulated in software (using float32 internally), providing functional support but no hardware acceleration speedups—use float16 for better performance there. Intel/AMD (x86_64) Macs lack support due to hardware limitations. Some operations may still fallback to CPU or require recent PyTorch nightlies/PyTorch_ENABLE_MPS_FALLBACK=1 for full coverage. Official PyTorch docs do not explicitly list supported dtypes, but GitHub issues/PRs confirm ongoing kernel compilation and fixes for bfloat16 in MPS (e.g., arange_mps, shader codegen). Errors like "BFloat16 is not supported on MPS" occur on unsupported hardware/OS or older PyTorch versions—update PyTorch and verify torch.backends.mps.is_available.
Citations:
- 1: BFloat16 Unsupported on MPS (macOS Sequoia 15.1.1, Intel, AMD) pytorch/pytorch#141864
- 2: https://stackoverflow.com/questions/77359161/bfloat16-is-not-supported-on-mps-macos
- 3: RuntimeError: "arange_mps" not implemented for 'BFloat16' pytorch/pytorch#136624
- 4: [MPS] Compile kernels into Metallib pytorch/pytorch#138636
- 5: [Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen pytorch/pytorch#176436
- 6: Add bfloat16 support detection for MPS in is_torch_bf16_gpu_available() huggingface/transformers#40458
- 7: https://pytorch.org/docs/stable/_modules/torch/backends/mps.html
- 8: https://docs.pytorch.org/docs/stable/backends.html
Use device-aware dtype for model loading to match AUTOCAST_DTYPE logic.
Line 38 hard-codes torch.bfloat16 despite the script explicitly defining AUTOCAST_DTYPE with device-specific values at line 19 and using it in the autocast decorator (line 109). This inconsistency means the model is loaded with a dtype that may not be optimal for the selected backend: bfloat16 on MPS requires macOS 14.0+ and newer PyTorch versions, and fp32 is more appropriate for CPU inference.
Apply the suggested fix to align model loading dtype with the device selection logic:
Suggested fix
+ load_dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float16 if DEVICE == "mps" else torch.float32
  model: WordLatentTransformerForCausalLM = (
      WordLatentTransformerForCausalLM.from_pretrained(
          checkpoint_path,
-         dtype=torch.bfloat16,
+         dtype=load_dtype,
          device_map=DEVICE,
      )
  )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@welt_training/demo.py` around lines 18 - 20, The code defines DEVICE and
AUTOCAST_DTYPE but the model is still loaded with a hard-coded torch.bfloat16;
update AUTOCAST_DTYPE to choose bfloat16 for CUDA, float16 for MPS, and float32
for CPU (i.e., torch.bfloat16 if DEVICE == "cuda" else torch.float16 if DEVICE
== "mps" else torch.float32), then change the model-loading call (the
from_pretrained/load_model or similar call that currently passes torch.bfloat16)
to use AUTOCAST_DTYPE so the loaded model dtype matches the device-aware
autocast logic.
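The device-to-dtype policy argued for in this comment can be summarized as a small selector; the string return values here are illustrative (real code would return torch.bfloat16 / torch.float16 / torch.float32):

```python
def pick_dtype(device: str) -> str:
    """Device-aware precision choice distilled from the review discussion:
    bf16 is native on modern CUDA GPUs, fp16 is the safer half precision
    on MPS (bf16 there needs macOS 14+ and a recent PyTorch), and CPU
    inference is best left in fp32 since bf16 on CPU is often emulated."""
    if device == "cuda":
        return "bfloat16"
    if device == "mps":
        return "float16"
    return "float32"
```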
alphas = [0.4] * prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)
for i, (xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):
Fix the CI-blocking Ruff errors in the bar loop.
GitHub Actions is already red here: i is unused, and zip() needs an explicit strict=.
Suggested fix
- for i, (xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):
+ for xi, h, c, a in zip(x, entropies, colors, alphas, strict=True):
      ax.bar(xi, h, color=c, alpha=a, edgecolor="none", width=0.8)
ax.bar(xi, h, color=c, alpha=a, edgecolor="none", width=0.8)🧰 Tools
🪛 GitHub Actions: Lint
[error] 90-90: ruff check (B007): Loop control variable i not used within loop body. Help: Rename unused i to _i.
[error] 90-90: ruff check (B905): zip() without an explicit strict= parameter. Help: Add explicit value for parameter strict=.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@welt_training/demo.py` around lines 89 - 90, The loop currently does "for i,
(xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):" which leaves the
index i unused and calls zip without strict, triggering Ruff errors; change the
loop to drop the unused enumerate/index and call zip with strict=True so it
becomes a direct unpack: iterate over (xi, h, c, a) from zip(x, entropies,
colors, alphas, strict=True), keeping the alphas construction (alphas = [0.4] *
prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)) intact.
| For IterableDataset (both torch and HuggingFace datasets variants): | ||
| - Shards across distributed ranks via split_dataset_by_node | ||
```diff
         - Wraps HF IterableDataset in a torch-compatible adapter so PyTorch's
           DataLoader treats it as iterable (not map-style)
         - Creates a DataLoader without accelerate's prepare() to avoid
           string field concatenation errors

         For regular Dataset:
         - Falls through to the base Trainer (handles DistributedSampler)
         """
-        from torch.utils.data import DataLoader, IterableDataset
+        import datasets
+        from torch.utils.data import DataLoader
+        from torch.utils.data import IterableDataset as TorchIterableDataset

         eval_dataset = eval_dataset or self.eval_dataset

-        # For IterableDataset, create dataloader without accelerate's prepare
-        # to avoid string field concatenation errors
-        if isinstance(eval_dataset, IterableDataset):
+        # Check both torch and HF IterableDataset (they are unrelated classes;
+        # CustomIterableDataset inherits from datasets.IterableDataset only)
+        if isinstance(eval_dataset, TorchIterableDataset | datasets.IterableDataset):
             # Shard across ranks for distributed evaluation
             if self.accelerator.num_processes > 1:
                 from datasets.distributed import split_dataset_by_node

                 from welt_training.streaming import CustomIterableDataset

                 rank = self.accelerator.process_index
                 world_size = self.accelerator.num_processes

                 if isinstance(eval_dataset, CustomIterableDataset):
                     old_transform = eval_dataset._transform
                     sharded_inner = split_dataset_by_node(
                         eval_dataset._dataset, rank=rank, world_size=world_size)
                     eval_dataset = CustomIterableDataset(sharded_inner)
                     if old_transform is not None:
                         eval_dataset.set_transform(old_transform)
                 elif isinstance(eval_dataset, datasets.IterableDataset):
                     eval_dataset = split_dataset_by_node(
                         eval_dataset, rank=rank, world_size=world_size)

+            # datasets.IterableDataset does NOT inherit from
+            # torch.utils.data.IterableDataset, so PyTorch's DataLoader would
+            # treat it as map-style and try len()/__getitem__, which fails.
+            # Wrap in a thin torch-compatible adapter.
+            if not isinstance(eval_dataset, TorchIterableDataset):
+                eval_dataset = _TorchIterableAdapter(eval_dataset)
```
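The sharding step above delegates to `datasets.distributed.split_dataset_by_node`. When the underlying stream has no shard structure that divides evenly across ranks, that helper falls back to example-level round-robin assignment: rank `r` keeps every `world_size`-th example starting at offset `r`. A minimal pure-Python sketch of that fallback (the name `split_by_node` is a hypothetical stand-in; the real helper also handles contiguous shard-level splitting):

```python
def split_by_node(examples, rank, world_size):
    """Round-robin split: rank r keeps every world_size-th example,
    starting at offset r. Mirrors the keep-every-Nth fallback of
    datasets.distributed.split_dataset_by_node for unsharded streams."""
    return [ex for i, ex in enumerate(examples) if i % world_size == rank]

# Two ranks splitting a ten-example stream: no overlap, full coverage.
shard0 = split_by_node(list(range(10)), rank=0, world_size=2)
shard1 = split_by_node(list(range(10)), rank=1, world_size=2)
print(shard0)  # [0, 2, 4, 6, 8]
print(shard1)  # [1, 3, 5, 7, 9]
```

Because each rank sees a disjoint slice, the per-rank last batch can be shorter than the nominal batch size, which is exactly why the trimming logic in `_real_batch_count` matters.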
**Plain `torch.utils.data.IterableDataset` is not properly handled and will fail at runtime.**

The code catches arbitrary `torch.utils.data.IterableDataset` in the outer `isinstance` check (line 226), but the implementation only handles `CustomIterableDataset` and `datasets.IterableDataset`. A plain torch iterable will:

- Skip the sharding block (lines 236-245) entirely, so it won't be sharded in distributed evaluation
- Fail at line 606 when `_prepare_eval_dataset()` calls `with_transform()`, since `torch.utils.data.IterableDataset` has no such method (`AttributeError`)

The docstring claims to shard "both torch and HuggingFace datasets variants", but that is only true for the HF/custom variants. Either narrow the `isinstance` check to only HF/custom iterables, or add proper sharding and transform handling for plain torch iterables.
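The second fix direction amounts to making the adapter forward the transform hooks that later preparation code expects. A dependency-free sketch of that delegation pattern (the real `_TorchIterableAdapter` subclasses `torch.utils.data.IterableDataset`; here `IterableAdapter` and `FakeHFDataset` are hypothetical stand-ins so the pattern is runnable without `torch` or `datasets`):

```python
class FakeHFDataset:
    """Stand-in for an HF IterableDataset: iterable, with transform hooks."""
    def __init__(self, rows, fn=None):
        self.rows, self.fn = rows, fn

    def __iter__(self):
        return iter([self.fn(r) if self.fn else r for r in self.rows])

    def set_transform(self, fn):
        self.fn = fn

    def with_transform(self, fn):
        return FakeHFDataset(self.rows, fn)


class IterableAdapter:
    """Wraps an HF-style iterable so a torch DataLoader can treat it as
    iterable, while forwarding with_transform/set_transform so later
    preparation code does not hit AttributeError."""
    def __init__(self, hf_dataset):
        self._dataset = hf_dataset

    def __iter__(self):
        # Delegate iteration to the wrapped dataset.
        return iter(self._dataset)

    def set_transform(self, fn):
        self._dataset.set_transform(fn)

    def with_transform(self, fn):
        return IterableAdapter(self._dataset.with_transform(fn))


adapted = IterableAdapter(FakeHFDataset([1, 2, 3])).with_transform(lambda r: r * 2)
print(list(adapted))  # [2, 4, 6]
```

With the hooks forwarded, a later `with_transform()` call on the wrapper behaves the same as on the underlying HF dataset.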
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `welt_training/trainer.py` around lines 208-252, the outer `isinstance` currently accepts `torch.utils.data.IterableDataset`, but the code only shards `datasets.IterableDataset` and `CustomIterableDataset`, and expects `with_transform()`/`set_transform()` later in `_prepare_eval_dataset`. Fix by either narrowing the `isinstance` check to only HF/custom iterables (`datasets.IterableDataset` or `CustomIterableDataset`), or by adding proper handling for plain torch `IterableDataset`: shard it across ranks using `split_dataset_by_node`, wrap the dataset in an adapter that exposes `set_transform`/`with_transform` (or reuse `_TorchIterableAdapter`, extended to implement those methods), and ensure the sharded wrapper is used so that the later `_prepare_eval_dataset()` call to `with_transform()` does not raise `AttributeError`. Reference symbols: `eval_dataset`, `_TorchIterableAdapter`, `CustomIterableDataset`, `split_dataset_by_node`, `_prepare_eval_dataset`/`with_transform`.


Fixes
Fixes #60 by @ilkerkesen
Description
Add bits-per-byte (BPB) metric to WeLT training and evaluation.

- `welt_training/metrics.py`: new `compute_bits_per_byte(loss, num_tokens, num_bytes)` utility
- `WeLTTrainer`: accumulates per-batch nats and content byte counts during evaluation, computing BPB from exact token/byte counts (supports UTF-8, UTF-16, and UTF-32 encodings)
- Accuracy accumulation moved into `prediction_step`, reducing memory overhead
- `_sync_eval_state()`: all-reduces scalar counters and gathers string predictions across ranks; handles last-batch padding via `_real_batch_count()`
- `_TorchIterableAdapter`: correctly wraps HF `IterableDataset` for PyTorch `DataLoader`
- `eval_bits_per_byte` added in `run_clm.py` for non-streaming CLM evaluation
- Reports `eval_bits_per_byte` alongside existing `eval_loss`, `perplexity`, `eval_byte_accuracy`, and `eval_word_accuracy`.
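The BPB formula itself is small: the mean cross-entropy loss is in nats per token, so the total information content is `loss * num_tokens` nats, converted to bits by dividing by ln 2, then normalized by the number of content bytes. A sketch consistent with that signature (the exact implementation in `welt_training/metrics.py` may differ in edge-case handling; returning NaN for zero bytes is an assumption):

```python
import math

def compute_bits_per_byte(loss: float, num_tokens: int, num_bytes: int) -> float:
    """loss: mean cross-entropy in nats per token. Returns bits per byte."""
    if num_bytes == 0:
        return float("nan")  # assumption: BPB is undefined without content bytes
    total_bits = loss * num_tokens / math.log(2)  # nats -> bits
    return total_bits / num_bytes

# A loss of ln(2) nats/token over equal token and byte counts gives 1.0 BPB.
print(compute_bits_per_byte(math.log(2), 10, 10))  # 1.0
```

Because the numerator and denominator are accumulated exactly across batches before the division, the metric stays correct even when batch sizes vary.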
Tests

- `pytest tests/test_metrics.py -v`
- `pytest tests/test_trainer.py -v`
- `pytest tests/test_train.py -v`
- `pytest tests/test_metrics.py tests/test_trainer.py tests/test_train.py -v`
Note: Medium Risk

Touches `WeLTTrainer` evaluation internals (accuracy/BPB accumulation, distributed syncing, iterable dataloading), which could subtly change reported metrics or behavior under multi-GPU/streaming setups. Functionality is well-covered by new unit/integration tests, lowering regression risk.

Overview

Adds a new bits per byte (BPB) evaluation metric via `compute_bits_per_byte()` and reports it as `eval_bits_per_byte` alongside loss/perplexity.

Refactors `WeLTTrainer` evaluation to incrementally accumulate byte/word accuracy and exact BPB numerators/denominators during `prediction_step` (including UTF-8/16/32 handling), then synchronizes counters and generated strings across distributed ranks before computing final metrics; also improves iterable/streaming eval dataloader handling (HF `IterableDataset` adapter + sharding) and fixes `eval_samples` counting.

Updates `run_clm.py` to compute `bits_per_byte` (and perplexity) during eval using per-token losses plus precomputed token/byte counts, and extends training/trainer tests to assert BPB presence/behavior and evaluation consistency.

Written by Cursor Bugbot for commit 518d5bc. This will update automatically on new commits.
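Since BPB divides by content bytes, the UTF-8/16/32 handling mentioned above directly scales the denominator and hence the reported number. A quick check of the byte counts involved (explicit little-endian codecs are used here so Python adds no BOM bytes; whether WeLT's counters include a BOM for the wider encodings is not settled by this PR):

```python
text = "héllo"  # 5 code points, one of them non-ASCII

utf8 = len(text.encode("utf-8"))       # 'é' takes two bytes in UTF-8
utf16 = len(text.encode("utf-16-le"))  # two bytes per BMP code unit
utf32 = len(text.encode("utf-32-le"))  # four bytes per code point
print(utf8, utf16, utf32)  # 6 10 20
```

The same loss therefore yields a BPB under UTF-32 roughly a third of the UTF-8 value for mostly-ASCII text, so the encoding must be held fixed when comparing runs.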