feat: Mamba-3 MLX port + pipeline/benchmark correctness fixes#14
skyoo2003 wants to merge 6 commits into
- Mamba-3 SISO+MIMO block (arXiv:2603.15569) in pure MLX; replaces AxonSSM in sandwich architecture.
- Pipeline merge bug: save trained full params instead of only LoRA (PPL ~150K -> ~12K).
- SWA prefill/decode cache bug: create_caches on prefill; split prefill/decode paths in SWA.
- Qwen2.5-3B embedding port (std-rescaled, d_source>=512 only).
- Benchmark harness: relax MC extractor, add stop_strings for GSM8K, guard math.exp overflow.
- Pipeline UX: per-50-step SFT/ORPO progress, per-10% benchmark progress via print(flush=True).
- Config/hparams: LR 1e-3->1e-4, LoRA scale 20->8, vocab_size promotion, large() preset.
- ORPO dataset rename: ultrafeedback_binarized_cleaned -> ultrafeedback_binarized.
- .gitignore: pipeline_output/, reference_impl/.
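The `math.exp` overflow guard listed above is not shown in this conversation. A minimal sketch of one way such a guard could work follows; the function name `safe_perplexity` and the `max_exponent` threshold are hypothetical and are not taken from `perplexity.py`.

```python
import math


def safe_perplexity(total_nll: float, n_tokens: int, max_exponent: float = 700.0) -> float:
    """Return exp(mean negative log-likelihood) without raising OverflowError.

    Hypothetical helper: math.exp overflows a float64 a little above 709,
    so any mean NLL above max_exponent is reported as infinity instead.
    """
    if n_tokens <= 0:
        raise ValueError("n_tokens must be positive")
    mean_nll = total_nll / n_tokens
    if math.isnan(mean_nll):
        return float("nan")
    if mean_nll > max_exponent:
        return float("inf")
    return math.exp(mean_nll)
```

For a model that guesses uniformly over a vocabulary of size V, the mean NLL is log(V) and the perplexity is V itself, which is consistent with the ~150K figure quoted for the broken merge path.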
Reviewer's Guide
Introduces a new pure-MLX Mamba-3 sequence block and wires it into BitAxon blocks, adds a Qwen2.5-3B embedding port for better initialization, fixes several pipeline and sliding-window attention cache correctness bugs, improves the evaluation harness and perplexity handling, tweaks training hyperparameters and CLI UX (including large config and benchmark flags), and adds a staging-only Hugging Face upload path plus scripts for baselines and full pipeline runs.
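The evaluation-harness change is described only in prose (a word-boundary `[A-D]` match for multiple choice, with the GSM8K `####` format taking priority). A rough sketch under those assumptions is shown below; the function names and regular expressions are illustrative and are not the project's actual `extract_answer`.

```python
import re

# Hypothetical patterns; the real extractor may differ.
_CHOICE_RE = re.compile(r"\b([A-D])\b")
_NUMBER = r"-?\d[\d,]*(?:\.\d+)?"
_GSM8K_RE = re.compile(r"####\s*(" + _NUMBER + r")")
_NUMBER_RE = re.compile(_NUMBER)


def extract_choice(text):
    """Return the first standalone capital letter A-D, or None."""
    match = _CHOICE_RE.search(text)
    return match.group(1) if match else None


def extract_gsm8k(text):
    """Prefer the '#### <number>' marker; otherwise fall back to the last number."""
    match = _GSM8K_RE.search(text)
    if match:
        return match.group(1).replace(",", "")
    numbers = _NUMBER_RE.findall(text)
    return numbers[-1].replace(",", "") if numbers else None
```

A word-boundary match is deliberately permissive: it accepts `(C)` or `C.`, but it would also pick up a standalone article "A" at the start of a sentence, so a stricter prompt format or a stop string may still be needed.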
Hey - I've found 2 issues
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location path="src/bit_axon/cli/upload.py" line_range="260" />
<code_context>
+    # Copy LICENSE / NOTICE from the repo root so the HF repo carries the
+    # same legal files the project ships with, not just the license field
+    # declared in the model card frontmatter.
+    repo_root = Path(__file__).resolve().parents[3]
+    for fname in ("LICENSE", "NOTICE"):
+        src = repo_root / fname
</code_context>
<issue_to_address>
**issue (bug_risk):** Repository root discovery likely off by one directory level
Given `__file__` is `.../src/bit_axon/cli/upload.py`, `parents[3]` resolves to `src/`, not the project root. That means LICENSE/NOTICE are searched under `src/` and will never be copied from the actual repo root, defeating the intention of this block. Please update to use `parents[4]` (or another explicit project-root resolution) so the legal files are correctly discovered and copied.
</issue_to_address>
### Comment 2
<location path="src/bit_axon/cli/upload.py" line_range="269-273" />
<code_context>
     bench_dict: dict[str, float] | None = None
     if benchmark_results is not None:
         bench_dict = {}
         for pair in benchmark_results.split(","):
+            if "=" not in pair:
+                continue
             name, acc = pair.strip().split("=")
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Silently skipping malformed benchmark entries can make CLI misconfiguration harder to notice
Malformed entries in `--benchmark-results` (missing `=`) are currently dropped, which can hide typos like `mmlu-0.45` instead of `mmlu=0.45`. Please either emit a clear warning for skipped entries or fail fast on the first malformed pair so users know their input wasn’t parsed as intended.
```suggestion
        bench_dict = {}
        for pair in benchmark_results.split(","):
            if "=" not in pair:
                raise ValueError(
                    f"Malformed benchmark result entry: {pair!r}. "
                    "Expected format 'name=score', e.g. 'mmlu=0.45'."
                )
            name, acc = pair.strip().split("=")
```
</issue_to_address>
parents[3] resolved to src/ instead of project root; use parents[4]. Silent skip of malformed --benchmark-results replaced with ValueError.
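The fail-fast parsing of `--benchmark-results` can be sketched as a standalone helper. This is an illustration only: `parse_benchmark_results` is a hypothetical name, and tolerating empty segments (for example a trailing comma) is an added assumption rather than behavior confirmed by the PR.

```python
from __future__ import annotations


def parse_benchmark_results(raw: str | None) -> dict[str, float] | None:
    """Parse 'name=score,name=score' into a dict, raising on malformed pairs."""
    if raw is None:
        return None
    results: dict[str, float] = {}
    for pair in raw.split(","):
        pair = pair.strip()
        if not pair:
            continue  # assumption: ignore empty segments such as a trailing comma
        name, sep, score = pair.partition("=")
        if not sep or not name.strip() or not score.strip():
            raise ValueError(
                f"Malformed benchmark result entry: {pair!r}. "
                "Expected format 'name=score', e.g. 'mmlu=0.45'."
            )
        try:
            results[name.strip()] = float(score)
        except ValueError as exc:
            raise ValueError(f"Non-numeric score in entry {pair!r}") from exc
    return results
```

Using `partition("=")` instead of `split("=")` avoids an unpacking error when an entry contains more than one `=`; such an entry then fails at the `float()` conversion with a clearer message.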
@sourcery-ai review
Hey - I've found 1 issue, and left some high level feedback:
- The `_scalar_decay_scan` implementation in `mamba3.py` falls back to a Python `for` loop when `chunk_size >= T`, which could become a performance bottleneck if users set a large `chunk_size`; consider always using a vectorized chunked path (even for a single chunk) to avoid per-token Python overhead.
- In `stage_upload_dir`, resolving the repo root via `Path(__file__).resolve().parents[4]` is fairly brittle to directory layout changes; using a more explicit root discovery mechanism (e.g. anchoring on a known package/module or walking up until finding `.git`/`pyproject.toml`) would make this more robust.
- The benchmark `stop_strings` handling in `_generate_sync` decodes the full generated sequence on every token, which is O(n²); given your own comment about this, you might want to implement the suggested optimization (buffering and decoding every K tokens) before this grows into a noticeable bottleneck for longer completions.
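The buffered stop-string check from the last point can be sketched roughly as follows. The names `generate_until_stop`, `token_stream`, `decode`, and `check_every` are hypothetical stand-ins and do not reflect the real `_generate_sync` signature. This version still decodes the whole sequence at each check, so it reduces the number of decode calls by a factor of `check_every` rather than changing how much text each call processes.

```python
from typing import Callable, Iterable, List, Optional, Sequence


def _first_stop(text: str, stop_strings: Sequence[str]) -> Optional[int]:
    """Return the earliest index at which any stop string occurs, or None."""
    hits = [text.find(s) for s in stop_strings]
    hits = [h for h in hits if h != -1]
    return min(hits) if hits else None


def generate_until_stop(
    token_stream: Iterable[int],
    decode: Callable[[Sequence[int]], str],
    stop_strings: Sequence[str],
    check_every: int = 16,
) -> str:
    """Collect tokens, decoding only every `check_every` tokens to look for a stop string."""
    tokens: List[int] = []
    for count, token in enumerate(token_stream, start=1):
        tokens.append(token)
        if stop_strings and count % check_every == 0:
            text = decode(tokens)
            cut = _first_stop(text, stop_strings)
            if cut is not None:
                return text[:cut]
    # Final check covers the tail that did not land on a check boundary.
    text = decode(tokens)
    cut = _first_stop(text, stop_strings)
    return text if cut is None else text[:cut]
```

The trade-off is that generation may run up to `check_every - 1` tokens past the stop string before it is noticed; the extra text is trimmed from the returned string.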
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The `_scalar_decay_scan` implementation in `mamba3.py` falls back to a Python `for` loop when `chunk_size >= T`, which could become a performance bottleneck if users set a large `chunk_size`; consider always using a vectorized chunked path (even for a single chunk) to avoid per-token Python overhead.
- In `stage_upload_dir`, resolving the repo root via `Path(__file__).resolve().parents[4]` is fairly brittle to directory layout changes; using a more explicit root discovery mechanism (e.g. anchoring on a known package/module or walking up until finding `.git`/`pyproject.toml`) would make this more robust.
- The benchmark `stop_strings` handling in `_generate_sync` decodes the full generated sequence on every token, which is O(n²); given your own comment about this, you might want to implement the suggested optimization (buffering and decoding every K tokens) before this grows into a noticeable bottleneck for longer completions.
## Individual Comments
### Comment 1
<location path="src/bit_axon/cli/upload.py" line_range="208-217" />
<code_context>
-def upload_cmd(
+def stage_upload_dir(
     model_path: str,
     repo_id: str,
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Staging upload folders ignores missing tokenizer.json, which can silently produce unusable artifacts.
`stage_upload_dir` now catches all exceptions from `_download_tokenizer_files` but still reports success, and `stage-upload` doesn’t re‑validate `tokenizer.json`. Unlike `upload_cmd` (which validates later), this means `stage-upload` can produce a directory that appears ready but is missing a tokenizer. Please either (a) have `stage_upload_dir` fail or emit a clear warning when `tokenizer.json` is absent, or (b) add a `stage-upload` option to treat missing tokenizer files as an error, so local/no-push workflows don’t quietly create broken repos.
</issue_to_address>
…tokenizer guard
- mamba3: remove Python for-loop fallback, always use vectorized chunked scan
- upload: replace brittle parents[4] with _find_repo_root() walking for pyproject.toml/.git
- upload: fail fast when tokenizer.json is missing after staging
- generate: buffer stop_strings decode check to every 16 tokens (O(n) instead of O(n²))
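The commit names `_find_repo_root()` but its body is not shown in this conversation. A minimal sketch of the marker-walking approach follows; the signature, the `None` return on failure, and the marker tuple are assumptions.

```python
from __future__ import annotations

from pathlib import Path

# Assumed markers, taken from the commit message above.
_ROOT_MARKERS = ("pyproject.toml", ".git")


def find_repo_root(start: Path, markers: tuple[str, ...] = _ROOT_MARKERS) -> Path | None:
    """Walk upward from `start` and return the nearest directory containing a marker."""
    start = start.resolve()
    first = start if start.is_dir() else start.parent
    for candidate in (first, *first.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return None  # caller decides whether a missing root is an error
```

Called as `find_repo_root(Path(__file__))`, this does not depend on how many directories sit between the module and the project root, which is the brittleness the review pointed out in the fixed `parents[N]` index.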
MLX Metal device crashes with SIGABRT after ~130 tests in a single process on CI runners. Run test_generate.py in a fresh process to avoid Metal resource exhaustion from prior test allocations.
… tests
- test_enumerate_target_keys_count: 66→78 to match Mamba3's 13 params/SSM (adds B_bias, C_bias, B_norm, C_norm, mimo_x, mimo_z, mimo_o)
- CI: also isolate test_compile.py (full model forward) in separate pytest process to prevent MLX Metal NaN flakiness
…n CI
- test_generate.py: replace BitAxonModel with FakeModel that returns deterministic mx.array logits without Metal GPU computation
- test_compile.py: skip TestModelCompileNoRegression and TestCompileBitExact on CI (MLX Metal forward passes flaky on runners)
- ci.yml: restore single pytest invocation with CI=true env var
This PR has had no activity for 7 days and has been marked as stale. If it is still being worked on, please push a commit or add a comment. Otherwise, it will be closed in 14 days.
This PR was closed because it has been inactive for 14 days since being marked as stale. Feel free to reopen if still relevant.
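Summary
- `src/bit_axon/layers/mamba3.py`: Pure MLX implementation of the Exponential-Trapezoidal 3-term recurrence from arXiv:2603.15569. The complex SSM is expressed as real tensors plus cumulative 2×2 rotations via a RoPE trick (no complex64 needed). Single SISO/MIMO codepath. The chunked parallel scan reuses the Mamba-2 segsum (Appendix A.1). Prefill and per-token decode agree numerically (diff ~1e-8). `AxonSSMBlock`/`AxonSSMMoEBlock` updated to use Mamba3.
- `src/bit_axon/porting/qwen_embedding.py`: Truncates Qwen embeddings to `d_source` with std-rescale to match downstream variance. Skips truncation when `d_source_model < 512` (over-aggressive truncation caused NaN).
- Pipeline merge bug: the `save_adapter_only + load_and_merge` path merged only LoRA weights on top of a random base, losing all trained embedding/norm parameters, so merged PPL was ~exp(log(vocab)) ~150K. Now uses `merge_adapters(model) + save_merged_model(model,...)` to preserve the full parameter set. PPL measurement moved before NF4 quantization.
- SWA prefill/decode cache bug: `BitAxonModel.__call__` now calls `_create_caches()` during prefill; `SlidingWindowAttention` prefill/decode paths are separated (prefill seeds the cache only and attends with full K/V; trimming to window_size was turning entire causal mask rows to -inf, causing a softmax NaN regression).
- Benchmark harness: `extract_answer` relaxed to word-boundary `[A-D]`, GSM8K `####` format prioritized, `GenerateConfig.stop_strings` added, `math.exp` overflow guard in `perplexity.py`. Qwen2.5-3B baseline remeasured: GSM8K 1% → 71.5%, ARC-c 14.3% → 45.3%, ARC-e 16.7% → 62.7%, MMLU 14.2% → 46%.
- Pipeline UX: per-50-step SFT/ORPO progress and per-10% benchmark progress (`print(flush=True)` to bypass rich buffering).
- ORPO dataset renamed to `ultrafeedback_binarized` (`ultrafeedback_binarized_cleaned` 404), `BitAxonConfig.large()` preset + 13 new `mamba3_*` fields, auto-promote `vocab_size < tok.vocab_size` (prevents OOB NaN), `stage-upload` CLI + LICENSE/NOTICE copy + param count M/B rendering, `.gitignore` additions for `pipeline_output/` and `reference_impl/`.
Test plan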
src/bit_axon/layers/mamba3.py— Pure MLX implementation of the Exponential-Trapezoidal 3-term recurrence from arXiv:2603.15569. Complex SSM expressed as real tensors + cumulative 2×2 rotations via a RoPE trick (no complex64 needed). Single SISO/MIMO codepath. Chunked parallel scan reuses Mamba-2 segsum (Appendix A.1). Prefill vs per-token decode numerical parity (diff ~1e-8).AxonSSMBlock/AxonSSMMoEBlockupdated to use Mamba3.src/bit_axon/porting/qwen_embedding.py— Truncates Qwen embeddings tod_sourcewith std-rescale to match downstream variance. Skips truncation whend_source_model < 512(over-aggressive truncation caused NaN).save_adapter_only + load_and_mergepath merged only LoRA weights on top of a random base, losing all trained embedding/norm parameters → merged PPL ~exp(log(vocab)) ~150K. Now usesmerge_adapters(model) + save_merged_model(model,...)to preserve the full parameter set. PPL measurement moved before NF4 quantization.BitAxonModel.__call__now calls_create_caches()during prefill;SlidingWindowAttentionprefill/decode paths separated (prefill seeds cache only, attends with full K/V — trimming to window_size was turning entire causal mask rows to -inf, causing softmax NaN regression).extract_answerrelaxed to word-boundary[A-D], GSM8K####format prioritized,GenerateConfig.stop_stringsadded,math.expoverflow guard inperplexity.py. Qwen2.5-3B baseline remeasured: GSM8K 1% → 71.5%, ARC-c 14.3% → 45.3%, ARC-e 16.7% → 62.7%, MMLU 14.2% → 46%.print(flush=True)to bypass rich buffering).ultrafeedback_binarized_cleaned404),BitAxonConfig.large()preset + 13 newmamba3_*fields, auto-promotevocab_size < tok.vocab_size(prevents OOB NaN),stage-uploadCLI + LICENSE/NOTICE copy + param count M/B rendering,.gitignoreadditions forpipeline_output/andreference_impl/.Test plan