
Implement Bits Per Byte Metric Computation #61

Open
ilkerkesen wants to merge 12 commits into sign:main from ilkerkesen:bits-per-byte

Conversation

Contributor

@ilkerkesen ilkerkesen commented Mar 6, 2026

Fixes

Fixes #60 by @ilkerkesen

Description

Add bits-per-byte (BPB) metric to WeLT training and evaluation.

  • Introduce welt_training/metrics.py with a compute_bits_per_byte(loss, num_tokens, num_bytes) utility (see the sketch after this list)
  • Extend WeLTTrainer to accumulate per-batch nats and content byte counts during evaluation, computing BPB from exact token/byte counts (supports UTF-8, UTF-16, and UTF-32 encodings)
  • Refactor accuracy computation from post-hoc batch-level reprocessing to incremental on-device accumulation during prediction_step, reducing memory overhead
  • Add distributed evaluation support: _sync_eval_state() all-reduces scalar counters and gathers string predictions across ranks; handle last-batch padding via _real_batch_count()
  • Add _TorchIterableAdapter to correctly wrap HF IterableDataset for PyTorch DataLoader
  • Compute eval_bits_per_byte in run_clm.py for non-streaming CLM evaluation
  • Report eval_bits_per_byte alongside existing eval_loss, perplexity, eval_byte_accuracy, and eval_word_accuracy.
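
For reference, bits per byte converts the mean per-token cross-entropy loss (in nats) to bits and normalizes by the number of content bytes. A minimal sketch of such a utility, assuming the standard BPB definition — the signature comes from the PR description, the body is an assumption rather than the PR's actual code:

import math

def compute_bits_per_byte(loss: float, num_tokens: int, num_bytes: int) -> float:
    # Total nats = loss * num_tokens; divide by ln(2) to convert nats
    # to bits, then by num_bytes to normalize per byte of content.
    if num_bytes <= 0:
        raise ValueError("num_bytes must be positive")
    return loss * num_tokens / (math.log(2) * num_bytes)

For a byte-level model evaluated on UTF-8 text, num_tokens equals num_bytes and this reduces to loss / ln(2).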

Tests

  1. Unit tests for the BPB metric utility: pytest tests/test_metrics.py -v
  2. Trainer integration tests (accuracy, BPB, distributed padding): pytest tests/test_trainer.py -v
  3. Existing training tests (regression check): pytest tests/test_train.py -v
  4. Run all tests together: pytest tests/test_metrics.py tests/test_trainer.py tests/test_train.py -v

Checklist

  • My pull request has a descriptive title (not a vague title like Update index.md).
  • My pull request targets the default branch of the repository (main or master).
  • My commit messages follow the contribution guidelines.
  • My code follows the established code style of the repository.
  • I added or updated tests for the changes I made (if applicable).
  • I added or updated documentation (if applicable).
  • I tried running the project locally and verified that there are no visible errors.

Note

Medium Risk
Touches WeLTTrainer evaluation internals (accuracy/BPB accumulation, distributed syncing, iterable dataloading), which could subtly change reported metrics or behavior under multi-GPU/streaming setups. Functionality is well-covered by new unit/integration tests, lowering regression risk.

Overview
Adds a new bits per byte (BPB) evaluation metric via compute_bits_per_byte() and reports it as eval_bits_per_byte alongside loss/perplexity.

Refactors WeLTTrainer evaluation to incrementally accumulate byte/word accuracy and exact BPB numerators/denominators during prediction_step (including UTF-8/16/32 handling), then synchronizes counters and generated strings across distributed ranks before computing final metrics; also improves iterable/streaming eval dataloader handling (HF IterableDataset adapter + sharding) and fixes eval_samples counting.

Updates run_clm.py to compute bits_per_byte (and perplexity) during eval using per-token losses plus precomputed token/byte counts, and extends training/trainer tests to assert BPB presence/behavior and evaluation consistency.
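
To illustrate the counter synchronization the overview describes, a hedged sketch of what a _sync_eval_state step might look like; the method and attribute names come from the PR description, while the tensor packing and the accelerate calls (Accelerator.reduce, gather_object) are assumptions about the implementation:

import torch
from accelerate.utils import gather_object

def _sync_eval_state(self):
    # Pack scalar accumulators into one tensor so a single all-reduce
    # sums them across all ranks.
    counters = torch.tensor(
        [self._eval_total_nats, self._eval_total_content_bytes, self._eval_sample_count],
        dtype=torch.float64,
        device=self.accelerator.device,
    )
    counters = self.accelerator.reduce(counters, reduction="sum")
    (self._eval_total_nats,
     self._eval_total_content_bytes,
     self._eval_sample_count) = counters.tolist()
    # String predictions cannot be all-reduced; gather them as Python
    # objects across ranks instead.
    self._eval_predictions = gather_object(self._eval_predictions)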

Written by Cursor Bugbot for commit 518d5bc. This will update automatically on new commits.

Summary by CodeRabbit

  • New Features

    • Added a bits-per-byte evaluation metric, byte- and word-level accuracy, and an option to report BPB alongside loss/perplexity.
    • New Gradio demo apps for interactive generation and entropy visualization; generation can optionally return per-byte entropy for plotting.
    • Evaluation: optional packed-eval mode and a preserve-document-boundaries option for chunked evaluation; improved iterable-dataset and distributed eval handling.
  • Refactor

    • Updated internal type-check syntax.
  • Tests

    • Extensive tests covering BPB math, tokenization/encoding effects, edge cases, and distributed consistency.

if remainder > 0:
    rank = self.accelerator.process_index
    return max(0, min(nominal_count, remainder - rank * nominal_count))
return nominal_count

Distributed last-batch trimming wrong for non-dispatch mode

High Severity

_real_batch_count uses remainder - rank * nominal_count to compute how many real samples each rank has in the last batch. This formula is only correct when dispatch_batches=True (where remainder is the global count of leftover samples). In the default dispatch_batches=False mode, each rank's DataLoaderShard sets remainder to the per-rank count of real samples in its own last batch — the same value on every rank. The subtraction of rank * nominal_count causes all ranks except rank 0 to return 0, silently discarding their last-batch contributions to accuracy, BPB, generation predictions, and sample counts.
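
A hedged sketch of one possible fix, branching on the dispatch mode; the helper signature and the dispatch_batches flag are hypothetical, and only remainder, nominal_count, and the accelerator come from the quoted code:

def _real_batch_count(self, nominal_count: int, remainder: int, dispatch_batches: bool) -> int:
    if remainder > 0:
        if dispatch_batches:
            # remainder is the global leftover count: give each rank its slice.
            rank = self.accelerator.process_index
            return max(0, min(nominal_count, remainder - rank * nominal_count))
        # dispatch_batches=False: each rank's DataLoaderShard reports its
        # own per-rank remainder, so use it directly.
        return min(nominal_count, remainder)
    return nominal_count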



coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Adds a bits-per-byte metric and wiring: new compute_bits_per_byte, trainer-side accumulators and distributed sync for BPB and byte/word accuracy, IterableDataset adapter, evaluation packing options, CLM experiment/eval integration, model entropy-returning generation, demos, processor and vision config tweaks, and comprehensive tests validating BPB behavior.

Changes

Cohort / File(s) — Summary

  • Metric core & Trainer (welt_training/metrics.py, welt_training/trainer.py): Added compute_bits_per_byte(loss, num_tokens, num_bytes). The trainer now accumulates accuracy/BPB counters, syncs them across ranks (_sync_eval_state), computes and exports eval_bits_per_byte, and adds _accumulate_accuracy_and_bpb, _real_batch_count, _generate_predictions, and _TorchIterableAdapter. Removed legacy logits storage.
  • Model generation / entropy (welt/model.py): _prefill now returns prefill_logits; generate gains return_entropy and prompt_words params and can return per-byte entropies, byte labels, and a prompt offset when enabled.
  • Experiment & CLM eval (welt_training/experiments/machine-translation/run_clm.py, welt_training/train.py): Wired BPB into CLM eval (compute_metrics/preprocess_logits_for_metrics) and changed truncation/packing semantics: optional eval_preserve_document_boundaries, pre-tokenization truncation, and a shared pretokenize_and_pack helper; added pack_eval_dataset handling.
  • Args / Config (welt_training/args_data.py): Added pack_eval_dataset: bool to DataTrainingArguments with compatibility checks against generation metrics.
  • Processor & model utils (welt/processor.py, welt/model_utils.py): Minor type-check syntax update to PEP 604 unions; image processor selection simplified to instantiate ViTImageProcessorFast() for custom models; expanded CUSTOM_MODELS configs.
  • Vision NaViT (welt/vision/navit.py): NaViTConfig.image_size accepts `int
  • Demos (new) (welt_training/demo.py, welt_training/demo_clm.py): Added two Gradio demo scripts for WeLT word-latent models and CLM models, including device/autocast selection, generation config builders, and UI wiring.
  • Tests (tests/test_metrics.py, tests/test_train.py, tests/test_trainer.py): New tests/test_metrics.py with extensive unit/integration tests for BPB; updated training/trainer tests to assert eval_bits_per_byte presence and correctness across encodings, batch sizes, iterable datasets, and sync behavior.

Sequence Diagram

sequenceDiagram
    participant DL as DataLoader
    participant Trainer
    participant Model
    participant Decoder
    participant Metrics
    participant Dist as DistributedSync
    participant Logger

    DL->>Trainer: yield eval batch (input_ids, labels)
    Trainer->>Model: forward(inputs)
    Model-->>Trainer: logits, loss
    Trainer->>Trainer: _accumulate_accuracy_and_bpb(logits, labels)
    Trainer->>Decoder: decode(input_ids) 
    Decoder-->>Metrics: decoded_text -> compute num_bytes,num_tokens
    Trainer->>Metrics: compute_bits_per_byte(eval_loss, num_tokens, num_bytes)
    Metrics-->>Trainer: eval_bits_per_byte
    Trainer->>Dist: _sync_eval_state() (reduce counters, gather preds)
    Dist-->>Trainer: synchronized accumulators
    Trainer->>Logger: log eval_bits_per_byte, accuracies, gen metrics

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇✨
I counted bits from byte to byte,
Across the ranks I hopped in flight.
Loss to BPB, numbers bright,
Tests hum softly through the night.

🚥 Pre-merge checks | ✅ 3 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 63.29%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Out of Scope Changes check — ❓ Inconclusive. Most changes are scope-aligned with the BPB implementation, but some appear tangential: demo scripts, image-size config adjustments, and trainer refactors for IterableDataset handling extend beyond the BPB core requirements. Resolution: clarify whether the demo scripts, NaViT config updates, and IterableDataset adapter are essential dependencies for BPB or should be separated into follow-up PRs.

✅ Passed checks (3 passed)

  • Description check — ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed. The title 'Implement Bits Per Byte Metric Computation' directly and specifically describes the main change: adding BPB metric infrastructure to the codebase.
  • Linked Issues check — ✅ Passed. The PR comprehensively implements the BPB requirement from #60: it adds the compute_bits_per_byte utility, integrates BPB computation into trainer evaluation, supports UTF-8/UTF-16/UTF-32 encodings, and surfaces eval_bits_per_byte in logs.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/test_metrics.py`:
- Around line 128-130: The test in tests/test_metrics.py instantiates a real
HuggingFace tokenizer (tokenizer = AutoTokenizer.from_pretrained("gpt2")) which
requires network access and can flake in CI; mark the test as non-hermetic by
adding `@pytest.mark.integration` to the test function or ensure the tokenizer is
pre-cached by committing the gpt2 tokenizer files and changing the loader to
point to the local directory (or add setup code to download/cache the tokenizer
before tests run). Update the test signature or test setup to reference the
tokenizer variable/AutoTokenizer.from_pretrained invocation accordingly.

In `@welt_training/experiments/machine-translation/run_clm.py`:
- Around line 760-771: The computed num_eval_bytes is biased by tokenizer.decode
defaults; change the decoding call used to compute num_eval_bytes so it
preserves special tokens and original spacing by calling
tokenizer.decode(example["input_ids"][1:], skip_special_tokens=False,
clean_up_tokenization_spaces=False) (or an equivalent tokenizer.batch_decode
with those flags) when iterating over eval_dataset; keep the surrounding logic
(data_args.streaming check, num_eval_tokens, compute_bits_per_byte and
metrics["eval_bits_per_byte"]) unchanged so eval_bits_per_byte uses a byte
denominator independent of tokenizer cleanup defaults.

In `@welt_training/trainer.py`:
- Around line 509-511: When computing batch_sample_count before updating
self._eval_sample_count, fall back to counting samples from the batch payload
when prefixes is None instead of zero: if prefixes is not None use
len(prefixes), otherwise derive the count from another batch element (e.g., the
first non-scalar entry such as inputs/targets/labels) and then call
self._real_batch_count(...) as before; ensure this change is applied where
_eval_sample_count is updated and that _add_custom_metrics() will then report
correct eval_samples for datasets without __len__.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d4e4b075-cfeb-4318-a409-50d98d1f3e9c

📥 Commits

Reviewing files that changed from the base of the PR and between 7561381 and 18d2e3e.

📒 Files selected for processing (7)
  • tests/test_metrics.py
  • tests/test_train.py
  • tests/test_trainer.py
  • welt/processor.py
  • welt_training/experiments/machine-translation/run_clm.py
  • welt_training/metrics.py
  • welt_training/trainer.py

Comment on lines +760 to +771
# Compute bits per byte from the evaluated subset.
# Decode tokens back to text to count the corresponding UTF-8 bytes.
if not data_args.streaming:
    num_eval_tokens = len(eval_dataset) * (block_size - 1)
    num_eval_bytes = sum(
        len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8"))
        for example in eval_dataset
    )
    if num_eval_bytes > 0:
        metrics["eval_bits_per_byte"] = compute_bits_per_byte(
            metrics["eval_loss"], num_eval_tokens, num_eval_bytes
        )

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed:

find . -name "run_clm.py" -type f
git ls-files | grep -E "(run_clm|machine.*translation)" | head -20
wc -l ./welt_training/experiments/machine-translation/run_clm.py
sed -n '750,780p' ./welt_training/experiments/machine-translation/run_clm.py
rg "compute_bits_per_byte" ./welt_training/experiments/machine-translation/run_clm.py -A 5 -B 2
rg "def compute_bits_per_byte" . -A 20
rg "skip_special_tokens|clean_up_tokenization" ./welt_training/experiments/machine-translation/run_clm.py
rg "tokenizer.*=" ./welt_training/experiments/machine-translation/run_clm.py -B 2 -A 2 | head -40

🌐 Web query:

huggingface transformers tokenizer.decode() default parameters skip_special_tokens clean_up_tokenization_spaces

💡 Result:

In 🤗 Transformers, the method signature (and defaults) are:

tokenizer.decode(
    token_ids,
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool | None = None,
    **kwargs
) -> str

So by default:

  • skip_special_tokens defaults to False (special tokens are kept).
  • clean_up_tokenization_spaces defaults to None, which falls back to tokenizer.clean_up_tokenization_spaces; that attribute itself defaults to True unless overridden when creating/loading the tokenizer.

(The same defaults apply to batch_decode(...).)


Make the byte denominator independent of tokenizer decode defaults.

num_eval_bytes is currently derived via tokenizer.decode(...) with the tokenizer's default cleanup behavior (skip_special_tokens=False, clean_up_tokenization_spaces=True). For tokenizers that remove special tokens or normalize spaces, the decoded text no longer matches the original input bytes, causing eval_bits_per_byte to be biased.

Proposed fix
-            num_eval_bytes = sum(
-                len(tokenizer.decode(example["input_ids"][1:]).encode("utf-8"))
-                for example in eval_dataset
-            )
+            num_eval_bytes = sum(
+                len(
+                    tokenizer.decode(
+                        example["input_ids"][1:],
+                        skip_special_tokens=True,
+                        clean_up_tokenization_spaces=False,
+                    ).encode("utf-8")
+                )
+                for example in eval_dataset
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/experiments/machine-translation/run_clm.py` around lines 760 -
771, The computed num_eval_bytes is biased by tokenizer.decode defaults; change
the decoding call used to compute num_eval_bytes so it preserves special tokens
and original spacing by calling tokenizer.decode(example["input_ids"][1:],
skip_special_tokens=False, clean_up_tokenization_spaces=False) (or an equivalent
tokenizer.batch_decode with those flags) when iterating over eval_dataset; keep
the surrounding logic (data_args.streaming check, num_eval_tokens,
compute_bits_per_byte and metrics["eval_bits_per_byte"]) unchanged so
eval_bits_per_byte uses a byte denominator independent of tokenizer cleanup
defaults.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

batch_loss = model.bytes_decoder.compute_loss(flat_logits, flat_labels)

if torch.isfinite(batch_loss):
    self._eval_total_nats += batch_loss.item() * batch_non_pad_bytes

Nats accumulator multiplied by spurious bytes_per_token factor

Medium Severity

_eval_total_nats accumulates batch_loss * batch_non_pad_bytes where batch_non_pad_bytes = batch_non_pad_tokens * bytes_per_token. Since batch_loss is a per-token average, the correct total nats is batch_loss * batch_non_pad_tokens, not batch_loss * batch_non_pad_bytes. The extra bytes_per_token factor means _eval_total_nats stores a value that is bytes_per_token× the actual total nats. For UTF-8 (bytes_per_token=1) the result is correct, and for UTF-16/UTF-32 the factor happens to cancel with the same inflation in _eval_total_content_bytes—but only when the model is byte-level. If a character-level model is ever used with UTF-16/UTF-32, the BPB will be wrong by a factor of bytes_per_token.
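
A hedged sketch of the token-based accumulation this comment implies; the variable names follow the quoted snippet and review text, and the surrounding trainer context is assumed:

if torch.isfinite(batch_loss):
    # batch_loss is a per-token mean, so total nats scale with the token
    # count, independent of the bytes-per-token encoding factor.
    self._eval_total_nats += batch_loss.item() * batch_non_pad_tokens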

Additional Locations (1)


Contributor

@AmitMY AmitMY left a comment


looks good! just one note regarding the streaming dataset

    self._dataset = hf_dataset

def __iter__(self):
    yield from self._dataset
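
For context, the quoted lines are the tail of the adapter; a self-contained sketch of its assumed shape, based on the PR description (wrap an HF IterableDataset so PyTorch's DataLoader treats it as iterable rather than map-style):

import torch.utils.data

class _TorchIterableAdapter(torch.utils.data.IterableDataset):
    # datasets.IterableDataset does not inherit from torch's
    # IterableDataset, so DataLoader would otherwise treat it as
    # map-style and call len()/__getitem__.
    def __init__(self, hf_dataset):
        self._dataset = hf_dataset

    def __iter__(self):
        yield from self._dataset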
Contributor


i feel like this should be in streaming.py

Contributor

AmitMY commented Mar 10, 2026

and tests fail


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (1)
welt/model.py (1)

675-800: Address pipeline lint failure: generate is too complex (12 > 10).

The generate method exceeds the complexity threshold due to the added entropy computation logic. Consider extracting the entropy-related code (lines 782-798) into a helper method like _compute_entropy_for_generation to reduce cyclomatic complexity.

♻️ Suggested refactoring approach
+    def _compute_entropy_for_generation(
+            self,
+            prefill_logits: torch.Tensor,
+            initial_num_words: torch.Tensor,
+            word_latents: list[torch.Tensor],
+            all_generated_words: list[list[str]],
+            prompt_words: list[str] | None,
+            tokenizer,
+            device: torch.device,
+    ) -> tuple[list[str], list[float], list[str], int]:
+        """Extract entropy computation from generate() to reduce complexity."""
+        prompt_entropies, prompt_byte_labels = [], []
+        if prompt_words is not None and len(prompt_words) > 1:
+            num_prompt = min(initial_num_words[0].item(), len(prompt_words))
+            prompt_latents = [prefill_logits[:, i:i+1, :] for i in range(num_prompt - 1)]
+            prompt_target_words = prompt_words[1:num_prompt]
+            prompt_entropies, prompt_byte_labels = self._compute_generation_entropy(
+                prompt_latents, prompt_target_words, tokenizer, device)
+
+        gen_entropies, gen_byte_labels = self._compute_generation_entropy(
+            word_latents, all_generated_words[0], tokenizer, device)
+
+        texts = ["".join(words) for words in all_generated_words]
+        return (texts,
+                prompt_entropies + gen_entropies,
+                prompt_byte_labels + gen_byte_labels,
+                len(prompt_entropies))

Then call this helper from generate() when return_entropy=True.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt/model.py` around lines 675 - 800, The generate method is over-complex
due to inline entropy post-processing; extract the block that computes
prompt_entropies/prompt_byte_labels and gen_entropies/gen_byte_labels (the code
that calls _compute_generation_entropy using prefill_logits, prompt_words,
word_latents, all_generated_words and tokenizer) into a new helper method named
_compute_entropy_for_generation that accepts (prefill_logits, initial_num_words,
prompt_words, word_latents, all_generated_words, tokenizer, device) and returns
(combined_entropies, combined_byte_labels, num_prompt_entropies). Replace the
extracted block in generate() with a single call to
_compute_entropy_for_generation and use its returned tuple to build the final
return; ensure behavior (including handling when prompt_words is None or short)
is preserved and that types/ordering of returned values match the original
(prompt_entropies + gen_entropies, prompt_byte_labels + gen_byte_labels,
len(prompt_entropies)).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@welt_training/demo_clm.py`:
- Around line 29-34: The tokenizer is loaded from model_path while the model is
loaded from checkpoint_path, which can cause tokenizer/model mismatch; update
the AutoTokenizer.from_pretrained call to use checkpoint_path (or otherwise
verify and document that tokenizer and checkpoint share the same vocab/config)
so tokenizer and model are loaded from the same source; locate the lines using
AutoTokenizer.from_pretrained (tokenizer) and
AutoModelForCausalLM.from_pretrained (model) and make the tokenizer use
checkpoint_path or add an explicit check that model_path and checkpoint_path are
compatible.
- Line 64: The function signature for generate(...) is too long and also the
call at the second occurrence (the generate( invocation around line 96) exceeds
120 chars; split the long function definition and any long call into multi-line
form to satisfy line-length linting—break the parameter list of generate into
multiple lines (one parameter per line or logical groups) and update the
corresponding generate(...) invocation to pass arguments each on its own line or
use keyword args on separate lines so both the def generate and its call are
under 120 characters; ensure you preserve parameter order and names
(max_new_tokens, strategy, num_beams, top_k, top_p, temperature,
repetition_penalty, model, tokenizer).

In `@welt_training/demo.py`:
- Around line 89-90: The loop currently does "for i, (xi, h, c, a) in
enumerate(zip(x, entropies, colors, alphas)):" which leaves the index i unused
and calls zip without strict, triggering Ruff errors; change the loop to drop
the unused enumerate/index and call zip with strict=True so it becomes a direct
unpack: iterate over (xi, h, c, a) from zip(x, entropies, colors, alphas,
strict=True), keeping the alphas construction (alphas = [0.4] *
prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)) intact.
- Around line 18-20: The code defines DEVICE and AUTOCAST_DTYPE but the model is
still loaded with a hard-coded torch.bfloat16; update AUTOCAST_DTYPE to choose
bfloat16 for CUDA, float16 for MPS, and float32 for CPU (i.e., torch.bfloat16 if
DEVICE == "cuda" else torch.float16 if DEVICE == "mps" else torch.float32), then
change the model-loading call (the from_pretrained/load_model or similar call
that currently passes torch.bfloat16) to use AUTOCAST_DTYPE so the loaded model
dtype matches the device-aware autocast logic.

In `@welt_training/trainer.py`:
- Around line 208-252: The outer isinstance currently accepts
torch.utils.data.IterableDataset but the code only shards
datasets.IterableDataset and CustomIterableDataset and expects
with_transform()/set_transform() later in _prepare_eval_dataset; fix by either
narrowing the isinstance check to only HF/custom iterables
(datasets.IterableDataset or CustomIterableDataset) or add proper handling for
plain torch IterableDataset: when eval_dataset is a torch IterableDataset, shard
it across ranks using split_dataset_by_node (wrap the dataset in an adapter that
exposes set_transform/with_transform or reuse _TorchIterableAdapter extended to
implement those methods), and ensure the sharded wrapper is used so later
_prepare_eval_dataset() can call with_transform() without AttributeError;
reference symbols: eval_dataset, _TorchIterableAdapter, CustomIterableDataset,
split_dataset_by_node, and _prepare_eval_dataset/with_transform.

---

Nitpick comments:
In `@welt/model.py`:
- Around line 675-800: The generate method is over-complex due to inline entropy
post-processing; extract the block that computes
prompt_entropies/prompt_byte_labels and gen_entropies/gen_byte_labels (the code
that calls _compute_generation_entropy using prefill_logits, prompt_words,
word_latents, all_generated_words and tokenizer) into a new helper method named
_compute_entropy_for_generation that accepts (prefill_logits, initial_num_words,
prompt_words, word_latents, all_generated_words, tokenizer, device) and returns
(combined_entropies, combined_byte_labels, num_prompt_entropies). Replace the
extracted block in generate() with a single call to
_compute_entropy_for_generation and use its returned tuple to build the final
return; ensure behavior (including handling when prompt_words is None or short)
is preserved and that types/ordering of returned values match the original
(prompt_entropies + gen_entropies, prompt_byte_labels + gen_byte_labels,
len(prompt_entropies)).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4953cffd-9986-41cf-8da4-50bf3c577704

📥 Commits

Reviewing files that changed from the base of the PR and between 405edeb and 22837b7.

📒 Files selected for processing (11)
  • tests/test_train.py
  • tests/test_trainer.py
  • welt/model.py
  • welt/model_utils.py
  • welt/vision/navit.py
  • welt_training/args_data.py
  • welt_training/demo.py
  • welt_training/demo_clm.py
  • welt_training/experiments/machine-translation/run_clm.py
  • welt_training/train.py
  • welt_training/trainer.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_trainer.py

Comment on lines +29 to +34
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)

⚠️ Potential issue | 🟡 Minor

Potential inconsistency between tokenizer and model paths.

The tokenizer is loaded from model_path (line 29) while the model is loaded from checkpoint_path (lines 30-34). If checkpoint_path is a later checkpoint with a different tokenizer configuration, this could cause mismatches.

Consider loading the tokenizer from checkpoint_path as well, or verify that checkpoints always share the same tokenizer as the base model path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/demo_clm.py` around lines 29 - 34, The tokenizer is loaded from
model_path while the model is loaded from checkpoint_path, which can cause
tokenizer/model mismatch; update the AutoTokenizer.from_pretrained call to use
checkpoint_path (or otherwise verify and document that tokenizer and checkpoint
share the same vocab/config) so tokenizer and model are loaded from the same
source; locate the lines using AutoTokenizer.from_pretrained (tokenizer) and
AutoModelForCausalLM.from_pretrained (model) and make the tokenizer use
checkpoint_path or add an explicit check that model_path and checkpoint_path are
compatible.


@torch.inference_mode()
@torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu")
def generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer):

⚠️ Potential issue | 🟡 Minor

Fix line length lint errors.

Lines 64 and 96 exceed the 120 character limit.

🔧 Proposed fix
 @torch.inference_mode()
 @torch.autocast(device_type=DEVICE, dtype=AUTOCAST_DTYPE, enabled=DEVICE != "cpu")
-def generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer):
+def generate(
+    prompt, max_new_tokens, strategy, num_beams, top_k, top_p,
+    temperature, repetition_penalty, model, tokenizer
+):
     if not prompt.strip():

     def on_generate(prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty):
-        return generate(
-            prompt, max_new_tokens, strategy, num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer)
+        return generate(
+            prompt, max_new_tokens, strategy, num_beams, top_k, top_p,
+            temperature, repetition_penalty, model, tokenizer
+        )

Also applies to: 96-96

🧰 Tools
🪛 GitHub Actions: Lint

[error] 64-64: ruff check (E501): Line too long (123 > 120).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/demo_clm.py` at line 64, The function signature for
generate(...) is too long and also the call at the second occurrence (the
generate( invocation around line 96) exceeds 120 chars; split the long function
definition and any long call into multi-line form to satisfy line-length
linting—break the parameter list of generate into multiple lines (one parameter
per line or logical groups) and update the corresponding generate(...)
invocation to pass arguments each on its own line or use keyword args on
separate lines so both the def generate and its call are under 120 characters;
ensure you preserve parameter order and names (max_new_tokens, strategy,
num_beams, top_k, top_p, temperature, repetition_penalty, model, tokenizer).

Comment on lines +18 to +20
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
AUTOCAST_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float16


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed:

fd -t f "demo.py" | head -20
cat -n welt_training/demo.py

🌐 Web query:

PyTorch bfloat16 support CPU compatibility

💡 Result:

PyTorch supports torch.bfloat16 on CPU for all x86_64 processors. The dtype is fully functional: tensor creation, model conversion (model.to(torch.bfloat16)), and mixed precision via torch.autocast(device_type="cpu", dtype=torch.bfloat16) all work. Native hardware acceleration requires Intel CPUs with AVX512_BF16 (3rd Gen Xeon Scalable "Cooper Lake" and later) or AMX for compute-intensive ops via oneDNN. On older CPUs PyTorch falls back to software emulation: BF16 tensors are converted to FP32 for compute and back, giving correct results but reduced performance (often slower than FP32). There is no torch.cpu.is_bfloat16_supported check (unlike CUDA). ARM CPUs behave similarly via emulation.

🌐 Web query:

PyTorch MPS Metal Performance Shaders bfloat16 support

💡 Result:

PyTorch's MPS backend supports bfloat16 on Apple Silicon (arm64) Macs running macOS 14.0 (Sonoma) or later, starting from nightly builds around version 2.6. On older M1/M2 chips it is emulated in software (float32 internally), providing functional support without hardware speedups; float16 generally performs better there. Intel/AMD (x86_64) Macs lack support due to hardware limitations. Errors like "BFloat16 is not supported on MPS" occur on unsupported hardware/OS or older PyTorch versions.


Use device-aware dtype for model loading to match AUTOCAST_DTYPE logic.

Line 38 hard-codes torch.bfloat16 despite the script explicitly defining AUTOCAST_DTYPE with device-specific values at line 19 and using it in the autocast decorator (line 109). This inconsistency means the model is loaded with a dtype that may not be optimal for the selected backend: bfloat16 on MPS requires macOS 14.0+ and newer PyTorch versions, and fp32 is more appropriate for CPU inference.

Apply the suggested fix to align model loading dtype with the device selection logic:

Suggested fix
+    load_dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float16 if DEVICE == "mps" else torch.float32
     model: WordLatentTransformerForCausalLM = (
         WordLatentTransformerForCausalLM.from_pretrained(
             checkpoint_path,
-            dtype=torch.bfloat16,
+            dtype=load_dtype,
             device_map=DEVICE,
         )
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/demo.py` around lines 18 - 20, The code defines DEVICE and
AUTOCAST_DTYPE but the model is still loaded with a hard-coded torch.bfloat16;
update AUTOCAST_DTYPE to choose bfloat16 for CUDA, float16 for MPS, and float32
for CPU (i.e., torch.bfloat16 if DEVICE == "cuda" else torch.float16 if DEVICE
== "mps" else torch.float32), then change the model-loading call (the
from_pretrained/load_model or similar call that currently passes torch.bfloat16)
to use AUTOCAST_DTYPE so the loaded model dtype matches the device-aware
autocast logic.

Comment on lines +89 to +90
alphas = [0.4] * prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)
for i, (xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):

⚠️ Potential issue | 🟡 Minor

Fix the CI-blocking Ruff errors in the bar loop.

GitHub Actions is already red here: i is unused, and zip() needs an explicit strict=.

Suggested fix
-        for i, (xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):
+        for xi, h, c, a in zip(x, entropies, colors, alphas, strict=True):
             ax.bar(xi, h, color=c, alpha=a, edgecolor="none", width=0.8)
🧰 Tools
🪛 GitHub Actions: Lint

[error] 90-90: ruff check (B007): Loop control variable i not used within loop body. Help: Rename unused i to _i.


[error] 90-90: ruff check (B905): zip() without an explicit strict= parameter. Help: Add explicit value for parameter strict=.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/demo.py` around lines 89 - 90, The loop currently does "for i,
(xi, h, c, a) in enumerate(zip(x, entropies, colors, alphas)):" which leaves the
index i unused and calls zip without strict, triggering Ruff errors; change the
loop to drop the unused enumerate/index and call zip with strict=True so it
becomes a direct unpack: iterate over (xi, h, c, a) from zip(x, entropies,
colors, alphas, strict=True), keeping the alphas construction (alphas = [0.4] *
prompt_byte_count + [1.0] * (len(entropies) - prompt_byte_count)) intact.

Comment on lines +208 to +252
For IterableDataset (both torch and HuggingFace datasets variants):
- Shards across distributed ranks via split_dataset_by_node
- Wraps HF IterableDataset in a torch-compatible adapter so PyTorch's
  DataLoader treats it as iterable (not map-style)
- Creates a DataLoader without accelerate's prepare() to avoid
  string field concatenation errors

For regular Dataset:
- Falls through to the base Trainer (handles DistributedSampler)
"""
from torch.utils.data import DataLoader, IterableDataset
import datasets
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset as TorchIterableDataset

eval_dataset = eval_dataset or self.eval_dataset

# For IterableDataset, create dataloader without accelerate's prepare
# to avoid string field concatenation errors
if isinstance(eval_dataset, IterableDataset):
    # Check both torch and HF IterableDataset (they are unrelated classes;
    # CustomIterableDataset inherits from datasets.IterableDataset only)
    if isinstance(eval_dataset, TorchIterableDataset | datasets.IterableDataset):
        # Shard across ranks for distributed evaluation
        if self.accelerator.num_processes > 1:
            from datasets.distributed import split_dataset_by_node

            from welt_training.streaming import CustomIterableDataset

            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes

            if isinstance(eval_dataset, CustomIterableDataset):
                old_transform = eval_dataset._transform
                sharded_inner = split_dataset_by_node(
                    eval_dataset._dataset, rank=rank, world_size=world_size)
                eval_dataset = CustomIterableDataset(sharded_inner)
                if old_transform is not None:
                    eval_dataset.set_transform(old_transform)
            elif isinstance(eval_dataset, datasets.IterableDataset):
                eval_dataset = split_dataset_by_node(
                    eval_dataset, rank=rank, world_size=world_size)

        # datasets.IterableDataset does NOT inherit from
        # torch.utils.data.IterableDataset, so PyTorch's DataLoader would
        # treat it as map-style and try len()/__getitem__, which fails.
        # Wrap in a thin torch-compatible adapter.
        if not isinstance(eval_dataset, TorchIterableDataset):
            eval_dataset = _TorchIterableAdapter(eval_dataset)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed:

find . -type f -name "trainer.py" | grep -E "welt_training|trainer"
cat -n welt_training/trainer.py | head -260 | tail -60
cat -n welt_training/trainer.py | sed -n '600,610p'
rg -n "_TorchIterableAdapter" welt_training/
rg -n "with_transform" welt_training/trainer.py
cat -n welt_training/trainer.py | sed -n '24,100p'
rg -n "_prepare_eval_dataset" welt_training/trainer.py -A 3 -B 1
rg -n "\.with_transform\(" welt_training/trainer.py -B 3 -A 1
rg -n "torch\.utils\.data\.IterableDataset" welt_training/ -A 5 -B 2


Plain torch.utils.data.IterableDataset is not properly handled and will fail at runtime.

The code catches arbitrary torch.utils.data.IterableDataset in the outer isinstance check (line 226), but the implementation only handles CustomIterableDataset and datasets.IterableDataset. A plain torch iterable will:

  1. Skip the sharding block (lines 236-245) entirely—so it won't be sharded in distributed evaluation
  2. Fail at line 606 when _prepare_eval_dataset() calls with_transform(), since torch.utils.data.IterableDataset has no such method (AttributeError)

The docstring claims to shard "both torch and HuggingFace datasets variants", but that's only true for HF/custom variants. Either narrow the isinstance check to only HF/custom iterables, or add proper sharding and transform handling for plain torch iterables.
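
A hedged sketch of the narrowing option; since CustomIterableDataset inherits from datasets.IterableDataset (per the quoted comment), a single check suffices, and plain torch iterables fall through to the base Trainer:

# Only HF/custom iterables support split_dataset_by_node sharding and the
# with_transform()/set_transform() calls made later in _prepare_eval_dataset().
if isinstance(eval_dataset, datasets.IterableDataset):
    ...  # shard across ranks and wrap in _TorchIterableAdapter as before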

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@welt_training/trainer.py` around lines 208 - 252, The outer isinstance
currently accepts torch.utils.data.IterableDataset but the code only shards
datasets.IterableDataset and CustomIterableDataset and expects
with_transform()/set_transform() later in _prepare_eval_dataset; fix by either
narrowing the isinstance check to only HF/custom iterables
(datasets.IterableDataset or CustomIterableDataset) or add proper handling for
plain torch IterableDataset: when eval_dataset is a torch IterableDataset, shard
it across ranks using split_dataset_by_node (wrap the dataset in an adapter that
exposes set_transform/with_transform or reuse _TorchIterableAdapter extended to
implement those methods), and ensure the sharded wrapper is used so later
_prepare_eval_dataset() can call with_transform() without AttributeError;
reference symbols: eval_dataset, _TorchIterableAdapter, CustomIterableDataset,
split_dataset_by_node, and _prepare_eval_dataset/with_transform.

Contributor

@AmitMY AmitMY left a comment


lint and test fail



Development

Successfully merging this pull request may close these issues.

[Feature] Bits per bytes metric

2 participants