From 64eb5bbb2cff5114046a1e24bbe5c8c5a2f980d1 Mon Sep 17 00:00:00 2001 From: drewjin Date: Mon, 16 Mar 2026 12:53:31 +0000 Subject: [PATCH 1/5] feat: Add validation mechanism for attention in Diffulex --- .gitignore | 2 + .../python/chunked_prefill_cutedsl.py | 0 .../config/validation_bench_sdar_bufsz1.yml | 36 +++ .../config/validation_bench_sdar_bufsz2.yml | 36 +++ .../config/validation_bench_sdar_bufsz4.yml | 36 +++ .../engine/dummy_attn_with_validation.py | 120 +++++++++ test/python/engine/test_diffulex_bench.py | 240 ------------------ test/python/engine/test_monkey_patch.py | 19 ++ test/python/engine/test_validation_bench.py | 77 ++++++ 9 files changed, 326 insertions(+), 240 deletions(-) create mode 100644 diffulex_kernel/python/chunked_prefill_cutedsl.py create mode 100644 test/python/engine/config/validation_bench_sdar_bufsz1.yml create mode 100644 test/python/engine/config/validation_bench_sdar_bufsz2.yml create mode 100644 test/python/engine/config/validation_bench_sdar_bufsz4.yml create mode 100644 test/python/engine/dummy_attn_with_validation.py delete mode 100644 test/python/engine/test_diffulex_bench.py create mode 100644 test/python/engine/test_monkey_patch.py create mode 100644 test/python/engine/test_validation_bench.py diff --git a/.gitignore b/.gitignore index 2713547..57dda48 100755 --- a/.gitignore +++ b/.gitignore @@ -60,3 +60,5 @@ benchmark_results_tmp/ # Cursor IDE files .cursor/ drewjin-diffulex +cutlass +output/ \ No newline at end of file diff --git a/diffulex_kernel/python/chunked_prefill_cutedsl.py b/diffulex_kernel/python/chunked_prefill_cutedsl.py new file mode 100644 index 0000000..e69de29 diff --git a/test/python/engine/config/validation_bench_sdar_bufsz1.yml b/test/python/engine/config/validation_bench_sdar_bufsz1.yml new file mode 100644 index 0000000..ad5c338 --- /dev/null +++ b/test/python/engine/config/validation_bench_sdar_bufsz1.yml @@ -0,0 +1,36 @@ +engine: + model_path: "/data1/ckpts/JetLM/SDAR-1.7B-Chat-b32" + tokenizer_path: null + model_name: "sdar" + decoding_strategy: "multi_bd" + mask_token_id: 151669 + + use_lora: false + lora_path: "" + + tensor_parallel_size: 1 + data_parallel_size: 1 + gpu_memory_utilization: 0.3 + max_model_len: 1024 + max_num_batched_tokens: 4096 + max_num_reqs: 1 + + enforce_eager: true + kv_cache_layout: "unified" + + decoding_thresholds: + add_block_threshold: 0.1 + semi_complete_threshold: 0.9 + decoding_threshold: 0.95 + block_size: 32 + buffer_size: 4 + +eval: + dataset_name: "gsm8k" + dataset_split: "test" + dataset_limit: null + temperature: 0.0 + max_tokens: 256 + add_bos_token: true + output_dir: "benchmark_results" + save_results: false diff --git a/test/python/engine/config/validation_bench_sdar_bufsz2.yml b/test/python/engine/config/validation_bench_sdar_bufsz2.yml new file mode 100644 index 0000000..a3eb3a7 --- /dev/null +++ b/test/python/engine/config/validation_bench_sdar_bufsz2.yml @@ -0,0 +1,36 @@ +engine: + model_path: "/data1/ckpts/JetLM/SDAR-1.7B-Chat-b32" + tokenizer_path: null + model_name: "sdar" + decoding_strategy: "multi_bd" + mask_token_id: 151669 + + use_lora: false + lora_path: "" + + tensor_parallel_size: 1 + data_parallel_size: 1 + gpu_memory_utilization: 0.3 + max_model_len: 1024 + max_num_batched_tokens: 4096 + max_num_reqs: 1 + + enforce_eager: true + kv_cache_layout: "unified" + + decoding_thresholds: + add_block_threshold: 0.1 + semi_complete_threshold: 0.9 + decoding_threshold: 0.95 + block_size: 32 + buffer_size: 2 + +eval: + dataset_name: "gsm8k" + dataset_split: "test" + dataset_limit: 5 + temperature: 0.0 + max_tokens: 256 + add_bos_token: true + output_dir: "benchmark_results" + save_results: false diff --git a/test/python/engine/config/validation_bench_sdar_bufsz4.yml b/test/python/engine/config/validation_bench_sdar_bufsz4.yml new file mode 100644 index 0000000..ad5c338 --- /dev/null +++ b/test/python/engine/config/validation_bench_sdar_bufsz4.yml @@ -0,0 +1,36 @@ +engine: + model_path: "/data1/ckpts/JetLM/SDAR-1.7B-Chat-b32" + tokenizer_path: null + model_name: "sdar" + decoding_strategy: "multi_bd" + mask_token_id: 151669 + + use_lora: false + lora_path: "" + + tensor_parallel_size: 1 + data_parallel_size: 1 + gpu_memory_utilization: 0.3 + max_model_len: 1024 + max_num_batched_tokens: 4096 + max_num_reqs: 1 + + enforce_eager: true + kv_cache_layout: "unified" + + decoding_thresholds: + add_block_threshold: 0.1 + semi_complete_threshold: 0.9 + decoding_threshold: 0.95 + block_size: 32 + buffer_size: 4 + +eval: + dataset_name: "gsm8k" + dataset_split: "test" + dataset_limit: null + temperature: 0.0 + max_tokens: 256 + add_bos_token: true + output_dir: "benchmark_results" + save_results: false diff --git a/test/python/engine/dummy_attn_with_validation.py b/test/python/engine/dummy_attn_with_validation.py new file mode 100644 index 0000000..244c3a1 --- /dev/null +++ b/test/python/engine/dummy_attn_with_validation.py @@ -0,0 +1,120 @@ +import torch +import torch.nn.functional as F +from einops import rearrange + +from diffulex.attention.attn_impl import Attention as OriginalAttention + + +class AttentionWithValidation(OriginalAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.validation_enabled = True + self.atol = 1e-2 + self.rtol = 1e-2 + self.error_log = [] + + def forward(self, q, k, v, mask=None): + if not self.validation_enabled: + return super().forward(q, k, v, mask) + + # Get metadata + attn_metadata = self.fetch_attn_metadata() + + # Run original kernel + output = super().forward(q, k, v, mask) + + # Run reference implementation + try: + ref_output = self._compute_reference(q, k, v, attn_metadata) + self._validate_output(output, ref_output, attn_metadata) + except Exception as e: + self.error_log.append(f"Validation failed: {e}") + + return output + + def _compute_reference(self, q, k, v, metadata): + q_reshaped = rearrange(q, "s (nh hd) -> s nh hd", **self.q_shape) + k_reshaped = rearrange(k, "s (nkvh hd) -> s nkvh hd", **self.kv_shape) + v_reshaped = rearrange(v, "s (nkvh hd) -> s nkvh hd", **self.kv_shape) + + scale = self.scale + k_cache, v_cache = self.k_cache, self.v_cache + page_tables = metadata.page_tables + context_lens = metadata.context_lens + cu_seqlens_q = metadata.cu_seqlens_q + valid_slices = getattr(metadata, 'valid_slices', None) + page_size = metadata.page_size + + num_seqs = len(cu_seqlens_q) - 1 + output = torch.zeros_like(q_reshaped) + + for seq_id in range(num_seqs): + q_start = int(cu_seqlens_q[seq_id].item()) + if valid_slices is not None: + valid_end = int(valid_slices[seq_id].item()) + valid_q_len = valid_end - q_start + else: + q_end = int(cu_seqlens_q[seq_id + 1].item()) + valid_q_len = q_end - q_start + + ctx_len = int(context_lens[seq_id].item()) + + if valid_q_len <= 0: + continue + + q_seq = q_reshaped[q_start:q_start + valid_q_len] + + # Reconstruct cache KV + k_parts, v_parts = [], [] + if k_cache.numel() > 0 and ctx_len > 0: + for rel_page_id in range(page_tables.shape[1]): + abs_page_id = int(page_tables[seq_id, rel_page_id].item()) + if abs_page_id < 0: + continue + page_start = rel_page_id * page_size + if page_start >= ctx_len: + break + n = min(page_start + page_size, ctx_len) - page_start + k_parts.append(k_cache[abs_page_id, :n]) + v_parts.append(v_cache[abs_page_id, :n]) + + k_new = k_reshaped[q_start:q_start + valid_q_len] + v_new = v_reshaped[q_start:q_start + valid_q_len] + + if k_parts: + k_full = torch.cat(k_parts + [k_new], dim=0) + v_full = torch.cat(v_parts + [v_new], dim=0) + else: + k_full = k_new + v_full = v_new + + q_sdpa = rearrange(q_seq, "s h d -> 1 h s d") + k_sdpa = rearrange(k_full, "s h d -> 1 h s d") + v_sdpa = rearrange(v_full, "s h d -> 1 h s d") + + attn_out = F.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, dropout_p=0.0, is_causal=False, scale=scale, enable_gqa=True + ) + output[q_start:q_start + valid_q_len] = rearrange(attn_out, "1 h s d -> s h d") + + return rearrange(output, "s nh hd -> s (nh hd)").contiguous() + + def _validate_output(self, output, ref_output, metadata): + try: + torch.testing.assert_close(output, ref_output, atol=self.atol, rtol=self.rtol) + except AssertionError as e: + diff = (output - ref_output).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + error_msg = f"Validation failed - max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}, buffer_size: {metadata.buffer_size}" + self.error_log.append(error_msg) + print(f"[ATTN VALIDATION ERROR] {error_msg}") + raise + + +def install_validation_hook(): + """Monkey patch to replace Attention with validation version""" + import diffulex.attention.attn_impl + diffulex.attention.attn_impl.Attention = AttentionWithValidation + print("[VALIDATION] Attention class replaced with validation wrapper") + diff --git a/test/python/engine/test_diffulex_bench.py b/test/python/engine/test_diffulex_bench.py deleted file mode 100644 index d85f5af..0000000 --- a/test/python/engine/test_diffulex_bench.py +++ /dev/null @@ -1,240 +0,0 @@ -import json -import os -import shutil -from datetime import datetime -from pathlib import Path - -import pytest -import torch -import yaml - -from diffulex import SamplingParams - -# Run each test in forked process to avoid torch.distributed double-init / leak between tests. -# Skip in CI by default (no GPU/checkpoints); run locally with: pytest --forked -pytestmark = [ - pytest.mark.forked, - pytest.mark.diffulex_dry_run, - pytest.mark.skipif( - os.environ.get("CI") == "true", - reason="Skip diffulex dry-run in CI (GPU + checkpoints required)", - ), -] - - -_CONFIG_PATH = Path(__file__).resolve().parent / "config" / "test_diffulex_dry_run.yaml" -OUTPUT_BASE = Path(__file__).resolve().parent / "output" / "test_diffulex_dry_run" -MAX_RUNS_RETAINED = 10 - -with open(_CONFIG_PATH) as f: - CONFIG = yaml.safe_load(f) - -CKPT = Path(CONFIG["ckpt"]) -GSM8K_NUM_SAMPLES = CONFIG["gsm8k_num_samples"] -CHECKPOINT_RELS = {k: tuple(v) for k, v in CONFIG["checkpoint_rels"].items()} -STRATEGY_CONFIG = {k: tuple(v) for k, v in CONFIG["strategy_config"].items()} -DECODING_THRESHOLDS = CONFIG["decoding_thresholds"] -ENGINE_KWARGS = CONFIG["engine"] -SAMPLING_PARAMS = CONFIG["sampling"] -FEW_SHOT_BASE = CONFIG["few_shot_base"] -FEW_SHOT_INSTRUCT = CONFIG["few_shot_instruct"] - - -def _ensure_output_run_dir() -> Path: - """Create timestamped run dir, prune old runs, return path.""" - OUTPUT_BASE.mkdir(parents=True, exist_ok=True) - run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - run_dir = OUTPUT_BASE / run_name - run_dir.mkdir(parents=True, exist_ok=True) - dirs = sorted(OUTPUT_BASE.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True) - for d in dirs[MAX_RUNS_RETAINED:]: - if d.is_dir(): - shutil.rmtree(d) - return run_dir - - -def _save_outputs(output_path: Path, test_name: str, strategy: str, outputs: list) -> None: - payload = {"test": test_name, "strategy": strategy, "outputs": outputs} - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - - -def _ckpt(rel: str | None) -> str | None: - """Resolve relative path under CKPT.""" - return str(CKPT / rel) if rel else None - - -def _get_model_config(name: str): - """(model_path, lora_path, model_name, decoding_strategy, buffer_size, few_shot, few_shot_type).""" - m_rel, l_rel = CHECKPOINT_RELS[name] - cfg = STRATEGY_CONFIG[name] - dec, buf = cfg[0], cfg[1] - few_shot_type = cfg[2] if len(cfg) > 2 else "instruct" - few_shot = FEW_SHOT_BASE if few_shot_type == "base" else FEW_SHOT_INSTRUCT - return _ckpt(m_rel), _ckpt(l_rel), name, dec, buf, few_shot, few_shot_type - - -def build_prompts(questions, prefix="", few_shot=None, few_shot_type="instruct"): - """Build prompts from questions, prefix, few_shot text, and format (base|instruct).""" - few_shot = few_shot or FEW_SHOT_INSTRUCT - if few_shot_type == "base": - suffix = lambda q: f"\n\nQ: {q}\nA: " - else: - suffix = lambda q: f"<|im_start|>user\nQuestion: {q}\nAnswer:<|im_end|>\n<|im_start|>assistant\n" - return [prefix + few_shot + suffix(q) for q in questions] - - -def _run_diffulex_test( - model, - model_name, - decoding_strategy, - prompts, - sampling_params, - use_lora=False, - lora_path=None, - buffer_size=4, - save_output_path: Path | None = None, - test_name: str = "", - **kwargs, -): - """Shared runner for Diffulex dry-run tests.""" - from diffulex import Diffulex - - common_kwargs = dict(ENGINE_KWARGS, buffer_size=buffer_size, decoding_thresholds=DECODING_THRESHOLDS) - common_kwargs.update(kwargs) - - llm_kwargs = dict( - model_name=model_name, - decoding_strategy=decoding_strategy, - **common_kwargs, - ) - if use_lora and lora_path: - llm_kwargs["use_lora"] = True - llm_kwargs["lora_path"] = lora_path - - llm = Diffulex(model, **llm_kwargs) - outputs = llm.generate(prompts, sampling_params) - assert len(outputs) == len(prompts) - if save_output_path is not None: - _save_outputs(save_output_path, test_name, model_name, outputs) - return outputs - - -def get_gsm8k_prompts(strategy_name: str): - """Build GSM8K prompts for the given strategy (uses its few_shot_type and tokenizer).""" - datasets = pytest.importorskip("datasets") - transformers = pytest.importorskip("transformers") - dataset = datasets.load_dataset("gsm8k", "main")["test"]["question"][:GSM8K_NUM_SAMPLES] - model_path = _ckpt(CHECKPOINT_RELS[strategy_name][0]) - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - prefix = tokenizer.bos_token or "" - _, _, _, _, _, few_shot, few_shot_type = _get_model_config(strategy_name) - return build_prompts(dataset, prefix, few_shot, few_shot_type) - - -@pytest.fixture -def sampling_params(): - return SamplingParams(**SAMPLING_PARAMS) - - -@pytest.fixture(scope="session") -def dry_run_output_dir(): - return _ensure_output_run_dir() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_d2f_llada(sampling_params, dry_run_output_dir, request): - name = "llada" - model, lora, _, dec, buf, _, _ = _get_model_config(name) - prompts = get_gsm8k_prompts(name) - _run_diffulex_test( - model, - model_name=name, - decoding_strategy=dec, - prompts=prompts, - sampling_params=sampling_params, - use_lora=True, - lora_path=lora, - buffer_size=buf, - save_output_path=dry_run_output_dir / f"{request.node.name}.json", - test_name=request.node.name, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_d2f_dream(sampling_params, dry_run_output_dir, request): - name = "dream" - model, lora, _, dec, buf, _, _ = _get_model_config(name) - prompts = get_gsm8k_prompts(name) - _run_diffulex_test( - model, - model_name=name, - decoding_strategy=dec, - prompts=prompts, - sampling_params=sampling_params, - use_lora=True, - lora_path=lora, - buffer_size=buf, - save_output_path=dry_run_output_dir / f"{request.node.name}.json", - test_name=request.node.name, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_sdar(sampling_params, dry_run_output_dir, request): - name = "sdar" - model, lora, _, dec, buf, _, _ = _get_model_config(name) - prompts = get_gsm8k_prompts(name) - _run_diffulex_test( - model, - model_name=name, - decoding_strategy=dec, - prompts=prompts, - sampling_params=sampling_params, - buffer_size=buf, - save_output_path=dry_run_output_dir / f"{request.node.name}.json", - test_name=request.node.name, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -def test_fastdllmv2(sampling_params, dry_run_output_dir, request): - name = "fast_dllm_v2" - model, lora, _, dec, buf, _, _ = _get_model_config(name) - prompts = get_gsm8k_prompts(name) - _run_diffulex_test( - model, - model_name=name, - decoding_strategy=dec, - prompts=prompts, - sampling_params=sampling_params, - buffer_size=buf, - save_output_path=dry_run_output_dir / f"{request.node.name}.json", - test_name=request.node.name, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -@pytest.mark.parametrize("name", list(CHECKPOINT_RELS)) -def test_diffulex_strategies_parametrized(name, sampling_params, dry_run_output_dir, request): - """Parametrized coverage over all strategies.""" - model, lora, model_name, dec, buf, _, _ = _get_model_config(name) - prompts = get_gsm8k_prompts(name) - _run_diffulex_test( - model, - model_name=model_name, - decoding_strategy=dec, - prompts=prompts, - sampling_params=sampling_params, - use_lora=(lora is not None), - lora_path=lora, - buffer_size=buf, - save_output_path=dry_run_output_dir / f"{request.node.name}.json", - test_name=request.node.name, - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/python/engine/test_monkey_patch.py b/test/python/engine/test_monkey_patch.py new file mode 100644 index 0000000..a3f2b7f --- /dev/null +++ b/test/python/engine/test_monkey_patch.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +"""Quick test to verify monkey patch works""" + +from test.python.engine.dummy_attn_with_validation import install_validation_hook + +print("Before patch:") +from diffulex.attention.attn_impl import Attention +print(f" Attention class: {Attention}") + +install_validation_hook() + +print("\nAfter patch:") +# Force reimport +import sys +if 'diffulex.attention.attn_impl' in sys.modules: + del sys.modules['diffulex.attention.attn_impl'] +from diffulex.attention.attn_impl import Attention +print(f" Attention class: {Attention}") +print(f" Has validation: {hasattr(Attention, 'validation_enabled')}") diff --git a/test/python/engine/test_validation_bench.py b/test/python/engine/test_validation_bench.py new file mode 100644 index 0000000..df71068 --- /dev/null +++ b/test/python/engine/test_validation_bench.py @@ -0,0 +1,77 @@ +""" +Validation test runner for diffulex_bench with accuracy checking. +Runs benchmark in subprocess with monkey-patched Attention for validation. +""" + +import json +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +import pytest +import torch + +pytestmark = [ + pytest.mark.forked, + pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Skip validation bench in CI (GPU required)", + ), +] + +OUTPUT_BASE = Path(__file__).resolve().parent / "output" / "validation_bench" + +CONFIGS = [ + "validation_bench_sdar_bufsz1.yml", + "validation_bench_sdar_bufsz2.yml", + "validation_bench_sdar_bufsz4.yml", +] + + +def _ensure_output_dir() -> Path: + OUTPUT_BASE.mkdir(parents=True, exist_ok=True) + run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_dir = OUTPUT_BASE / run_name + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("config_name", CONFIGS) +def test_validation_bench_gsm8k(config_name): + """Run GSM8K benchmark with validation enabled.""" + output_dir = _ensure_output_dir() + + wrapper_script = """ +import sys +sys.path.insert(0, 'test/python/engine') +from dummy_attn_with_validation import install_validation_hook +install_validation_hook() + +from diffulex_bench.main import main +main() +""" + + cmd = [ + sys.executable, "-c", wrapper_script, + "--config", f"test/python/engine/config/{config_name}", + "--output-dir", str(output_dir / config_name.replace('.yml', '')), + "--save-results", + ] + + result = subprocess.run(cmd, capture_output=True, text=True, cwd=Path(__file__).resolve().parent.parent.parent.parent) + + print(f"\n{'='*80}") + print(f"Config: {config_name}") + print(f"{'='*80}") + print(f"STDOUT:\n{result.stdout}") + if result.stderr: + print(f"STDERR:\n{result.stderr}") + + assert result.returncode == 0, f"Benchmark failed with code {result.returncode}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From dc6634e5460296a6e48a0e294bcc14897225e2ac Mon Sep 17 00:00:00 2001 From: drewjin Date: Tue, 17 Mar 2026 05:36:00 +0000 Subject: [PATCH 2/5] fix: fix infinite iteration error caused by missing eos token --- .../mixin/multi_block/engine/model_runner.py | 18 +++++++++--------- diffulex/mixin/multi_block/engine/request.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/diffulex/mixin/multi_block/engine/model_runner.py b/diffulex/mixin/multi_block/engine/model_runner.py index 509db8a..629e6e7 100644 --- a/diffulex/mixin/multi_block/engine/model_runner.py +++ b/diffulex/mixin/multi_block/engine/model_runner.py @@ -189,18 +189,18 @@ def run_model_multi_block( if key != "outputs": value.zero_() - num_seqs = len(attn_metadata.context_lens) + num_reqs = attn_metadata.num_reqs graph_vars["input_ids"][:num_tokens] = input_ids graph_vars["positions"][:num_tokens] = positions graph_vars["slot_mapping"][:num_tokens] = attn_metadata.slot_mapping - graph_vars["context_lens"][:num_seqs] = attn_metadata.context_lens - graph_vars["cu_seqlens_q"][: num_seqs + 1] = attn_metadata.cu_seqlens_q - graph_vars["cu_seqlens_k"][: num_seqs + 1] = attn_metadata.cu_seqlens_k - graph_vars["valid_slices"][:num_seqs] = attn_metadata.valid_slices - graph_vars["status_table"][:num_seqs] = attn_metadata.status_table - graph_vars["prefix_lens"][:num_seqs] = attn_metadata.prefix_lens - graph_vars["padded_prefix_lens"][:num_seqs] = attn_metadata.padded_prefix_lens - graph_vars["page_tables"][:num_seqs, : attn_metadata.page_tables.size(1)] = attn_metadata.page_tables + graph_vars["context_lens"][:num_reqs] = attn_metadata.context_lens + graph_vars["cu_seqlens_q"][: num_reqs + 1] = attn_metadata.cu_seqlens_q + graph_vars["cu_seqlens_k"][: num_reqs + 1] = attn_metadata.cu_seqlens_k + graph_vars["valid_slices"][:num_reqs] = attn_metadata.valid_slices + graph_vars["status_table"][:num_reqs] = attn_metadata.status_table + graph_vars["prefix_lens"][:num_reqs] = attn_metadata.prefix_lens + graph_vars["padded_prefix_lens"][:num_reqs] = attn_metadata.padded_prefix_lens + graph_vars["page_tables"][:num_reqs, : attn_metadata.page_tables.size(1)] = attn_metadata.page_tables # Update attn_metadata to use graph_vars tensors attn_metadata.slot_mapping = graph_vars["slot_mapping"] diff --git a/diffulex/mixin/multi_block/engine/request.py b/diffulex/mixin/multi_block/engine/request.py index 67f4bb8..b089bb3 100644 --- a/diffulex/mixin/multi_block/engine/request.py +++ b/diffulex/mixin/multi_block/engine/request.py @@ -90,7 +90,7 @@ def eos_token_generated(self) -> bool: last_in_cache_block = self.dllm_block_buffer.first_running_block.prev_block return eos_detect_fn(seq) or ( last_in_cache_block.is_last_in_context and eos_detect_fn(last_in_cache_block.token_ids) - ) + ) or eos_detect_fn(self.token_ids) @property def num_prefix_blocks(self) -> int: @@ -335,6 +335,7 @@ def deactivate(self): def step(self): self.lazy_activate() + # Condition to activate the next block, when buffer contains active blocks activate_cond = self.dllm_block_buffer.should_add_block and not self.dllm_block_buffer.is_overflow From d775c0269501bf9a6a0a11854901417357661fc8 Mon Sep 17 00:00:00 2001 From: drewjin Date: Tue, 17 Mar 2026 05:45:40 +0000 Subject: [PATCH 3/5] refactor: Simplify validation logic and enhance attention mask handling in AttentionWithValidation --- .../engine/dummy_attn_with_validation.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/test/python/engine/dummy_attn_with_validation.py b/test/python/engine/dummy_attn_with_validation.py index 244c3a1..4e91241 100644 --- a/test/python/engine/dummy_attn_with_validation.py +++ b/test/python/engine/dummy_attn_with_validation.py @@ -24,11 +24,8 @@ def forward(self, q, k, v, mask=None): output = super().forward(q, k, v, mask) # Run reference implementation - try: - ref_output = self._compute_reference(q, k, v, attn_metadata) - self._validate_output(output, ref_output, attn_metadata) - except Exception as e: - self.error_log.append(f"Validation failed: {e}") + ref_output = self._compute_reference(q, k, v, attn_metadata) + self._validate_output(output, ref_output, attn_metadata) return output @@ -44,6 +41,7 @@ def _compute_reference(self, q, k, v, metadata): cu_seqlens_q = metadata.cu_seqlens_q valid_slices = getattr(metadata, 'valid_slices', None) page_size = metadata.page_size + block_size = metadata.block_size num_seqs = len(cu_seqlens_q) - 1 output = torch.zeros_like(q_reshaped) @@ -88,12 +86,28 @@ def _compute_reference(self, q, k, v, metadata): k_full = k_new v_full = v_new + # Build block-causal mask (aligned with kernel line 179-181) + mask = None + if block_size > 0: + qi = torch.arange(valid_q_len, device=q.device) + kj = torch.arange(valid_q_len, device=q.device) + # Kernel: ((offs_q_block // DLLM_BLOCK_SIZE + 1) * DLLM_BLOCK_SIZE)[:, None] > offs_kv_block[None, :] + block_ends = ((qi // block_size) + 1) * block_size + new_kv_mask = block_ends[:, None] > kj[None, :] + + if ctx_len > 0: + cache_mask = torch.ones(valid_q_len, ctx_len, dtype=torch.bool, device=q.device) + mask = torch.cat([cache_mask, new_kv_mask], dim=1) + else: + mask = new_kv_mask + mask = mask.unsqueeze(0).unsqueeze(0) + q_sdpa = rearrange(q_seq, "s h d -> 1 h s d") k_sdpa = rearrange(k_full, "s h d -> 1 h s d") v_sdpa = rearrange(v_full, "s h d -> 1 h s d") attn_out = F.scaled_dot_product_attention( - q_sdpa, k_sdpa, v_sdpa, dropout_p=0.0, is_causal=False, scale=scale, enable_gqa=True + q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False, scale=scale, enable_gqa=True ) output[q_start:q_start + valid_q_len] = rearrange(attn_out, "1 h s d -> s h d") From bfc8223fcc6ca28b935673cdae99f3749f701c75 Mon Sep 17 00:00:00 2001 From: drewjin Date: Tue, 17 Mar 2026 07:02:33 +0000 Subject: [PATCH 4/5] feat: Add mask visualization helper and integrate it into chunked prefill reference for debugging --- .../python/chunked_prefill_triton.py | 1 + ...ash_attn_chunked_prefill_unified_kernel.py | 107 ++++++++++++++++-- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/diffulex_kernel/python/chunked_prefill_triton.py b/diffulex_kernel/python/chunked_prefill_triton.py index 7fa4596..3f1c60e 100644 --- a/diffulex_kernel/python/chunked_prefill_triton.py +++ b/diffulex_kernel/python/chunked_prefill_triton.py @@ -6,6 +6,7 @@ from diffulex_kernel.python.auto_tuner import build_chunked_prefill_configs +# NOTE: While doing test, comment auto-tuner to avoid slowing down the test. @triton.autotune( configs=[ triton.Config(c, num_warps=c.pop("num_warps"), num_stages=c.pop("num_stages")) diff --git a/test/python/kernel/test_dllm_flash_attn_chunked_prefill_unified_kernel.py b/test/python/kernel/test_dllm_flash_attn_chunked_prefill_unified_kernel.py index 56fb7c5..2c350e8 100644 --- a/test/python/kernel/test_dllm_flash_attn_chunked_prefill_unified_kernel.py +++ b/test/python/kernel/test_dllm_flash_attn_chunked_prefill_unified_kernel.py @@ -9,6 +9,68 @@ ) +# --------------------------------------------------------------------------- +# Mask visualization helper +# --------------------------------------------------------------------------- + + +def _visualize_mask(mask, seq_id, ctx_len, valid_q_len, block_size, label): + """Visualize attention mask with clear structure.""" + print(f"\n{'='*80}") + print(f"Seq {seq_id} | {label} | ctx_len={ctx_len}, valid_q_len={valid_q_len}, block_size={block_size}") + print(f"{'='*80}") + + mask_np = mask.cpu().numpy() + total_kv = mask_np.shape[1] + + # Print header + print(f"Mask shape: Q={mask_np.shape[0]} x KV={total_kv} (cache={ctx_len}, new={valid_q_len})") + + # Compact visualization for large masks + if valid_q_len > 64 or total_kv > 64: + print("\n[Compact view - showing block boundaries]") + step = max(1, block_size // 4) + sample_q = list(range(0, valid_q_len, step)) + sample_kv = list(range(0, total_kv, step)) + + print(f"\n KV→", end="") + for kv_idx in sample_kv: + if kv_idx < ctx_len: + print(f" C{kv_idx:3d}", end="") + else: + print(f" N{kv_idx-ctx_len:3d}", end="") + print() + + for q_idx in sample_q: + block_id = q_idx // block_size + print(f"Q{q_idx:3d}(B{block_id})", end="") + for kv_idx in sample_kv: + print(f" {'█' if mask_np[q_idx, kv_idx] else '·'} ", end="") + print() + else: + # Full visualization for small masks + print(f"\n KV→", end="") + for kv_idx in range(total_kv): + if kv_idx < ctx_len: + print(f"C{kv_idx%10}", end="") + else: + print(f"N{(kv_idx-ctx_len)%10}", end="") + print() + + for q_idx in range(valid_q_len): + block_id = q_idx // block_size + print(f"Q{q_idx:2d}(B{block_id})", end="") + for kv_idx in range(total_kv): + print("█" if mask_np[q_idx, kv_idx] else "·", end="") + print() + + # Statistics + visible_per_q = mask_np.sum(axis=1) + print(f"\nStats: min_visible={visible_per_q.min():.0f}, max_visible={visible_per_q.max():.0f}, " + f"avg_visible={visible_per_q.mean():.1f}") + print(f"{'='*80}\n") + + # --------------------------------------------------------------------------- # Kernel invocation (direct call, bypasses the incomplete Python wrapper) # --------------------------------------------------------------------------- @@ -31,7 +93,7 @@ def call_chunked_prefill_kernel( dllm_block_size, is_block_causal, is_prefix_full=False, - BLOCK_M=128, + BLOCK_M=64, BLOCK_N=64, ): o = torch.zeros_like(q) @@ -118,6 +180,7 @@ def naive_chunked_prefill_ref( dllm_block_size, is_block_causal, is_prefix_full=False, + visualize_mask=False, ): """ Per-request reference: @@ -187,6 +250,10 @@ def naive_chunked_prefill_ref( ) mask = torch.cat([cache_mask, new_kv_mask], dim=1) mask = mask.unsqueeze(0).unsqueeze(0) + + if visualize_mask: + _visualize_mask(mask[0, 0], seq_id, ctx_len, valid_q_len, dllm_block_size, + f"prefix_full_status{status}_P{P}_Pp{P_prime}") elif is_block_causal: qi = torch.arange(valid_q_len, device=q.device) kj = torch.arange(valid_q_len, device=q.device) @@ -201,6 +268,9 @@ def naive_chunked_prefill_ref( mask = torch.cat([cache_mask, new_kv_mask], dim=1) mask = mask.unsqueeze(0).unsqueeze(0) + if visualize_mask: + _visualize_mask(mask[0, 0], seq_id, ctx_len, valid_q_len, dllm_block_size, "block_causal") + q_sdpa = rearrange(q_seq, "s h d -> 1 h s d") k_sdpa = rearrange(k_full, "s h d -> 1 h s d") v_sdpa = rearrange(v_full, "s h d -> 1 h s d") @@ -330,6 +400,7 @@ def _run_test( seed=42, atol=1e-2, rtol=1e-2, + visualize_mask=False, ): torch.manual_seed(seed) num_seqs = len(q_lens) @@ -390,17 +461,18 @@ def _run_test( dllm_block_size, is_block_causal=is_block_causal, is_prefix_full=is_prefix_full, + visualize_mask=visualize_mask, ) - for i, vql in enumerate(valid_q_lens): + for i in range(num_seqs): q_start = int(cu[i].item()) - if vql > 0: - torch.testing.assert_close( - out[q_start : q_start + vql], - ref[q_start : q_start + vql], - atol=atol, - rtol=rtol, - ) + q_end = int(cu[i + 1].item()) + torch.testing.assert_close( + out[q_start:q_end], + ref[q_start:q_end], + atol=atol, + rtol=rtol, + ) # ========================= Case 1: Pure Prefill ========================== @@ -856,5 +928,22 @@ def test_prefix_full_extended_strategies( ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_visualize_mask_example(): + """Example: visualize mask for debugging.""" + _run_test( + q_lens=[128, 96], + valid_q_lens=[64, 32], + ctx_lens=[64, 128], + num_heads=32, + num_kv_heads=8, + head_dim=128, + page_size=32, + dllm_block_size=32, + is_block_causal=True, + visualize_mask=True, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From a2d93b005b1304f72eaafce1bf6b095111b5d64b Mon Sep 17 00:00:00 2001 From: drewjin Date: Tue, 17 Mar 2026 12:12:16 +0000 Subject: [PATCH 5/5] feat: Add standalone debug script for chunked prefill kernel with comparison output --- test/python/kernel/debug_chunked_prefill.py | 66 +++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 test/python/kernel/debug_chunked_prefill.py diff --git a/test/python/kernel/debug_chunked_prefill.py b/test/python/kernel/debug_chunked_prefill.py new file mode 100644 index 0000000..1f17230 --- /dev/null +++ b/test/python/kernel/debug_chunked_prefill.py @@ -0,0 +1,66 @@ +"""Standalone debug script for chunked prefill kernel - no pytest dependency.""" + +import torch +from test.python.kernel.test_dllm_flash_attn_chunked_prefill_unified_kernel import ( + build_test_data, + call_chunked_prefill_kernel, + naive_chunked_prefill_ref, +) + + +def debug_single_case(): + """Mixed batch: prefill + decode with various contexts.""" + torch.manual_seed(42) + + # Config: mixed batch (prefill + decode) + q_lens = [192, 128, 128, 128] # seq0: prefill, seq1: prefill, seq2-3: decode + valid_q_lens = [192, 128, 96, 64] # seq0-1: full, seq2-3: partial + ctx_lens = [0, 64, 256, 128] # seq0: no cache, seq1-3: with cache + num_heads = 32 + num_kv_heads = 8 + head_dim = 128 + page_size = 32 + dllm_block_size = 32 + + # Build data + data = build_test_data( + q_lens, + valid_q_lens, + ctx_lens, + num_heads, + num_kv_heads, + head_dim, + page_size, + dllm_block_size=dllm_block_size, + ) + q, k, v, k_cache, v_cache, pt, st, cl, cu, vs, pl, ppl = data + scale = 1.0 / head_dim**0.5 + + # Kernel output + out = call_chunked_prefill_kernel( + q, k, v, k_cache, v_cache, pt, st, cl, cu, vs, pl, ppl, + scale, dllm_block_size, is_block_causal=True + ) + + # Reference output (with mask visualization) + ref = naive_chunked_prefill_ref( + q, k, v, k_cache, v_cache, pt, st, cl, cu, vs, + [0], [0], scale, page_size, dllm_block_size, + is_block_causal=True, visualize_mask=True + ) + + # Compare + diff = (out - ref).abs() + print(f"\n{'='*80}") + print("COMPARISON RESULTS") + print(f"{'='*80}") + print(f"Output shape: {out.shape}") + print(f"Max diff: {diff.max().item():.6f}") + print(f"Mean diff: {diff.mean().item():.6f}") + print(f"{'='*80}\n") + + return out, ref + + +if __name__ == "__main__": + debug_single_case()