feat(world_model): Add T5, CLIP, Vae, DiT pypto3.0 kernels #520
xzhxzhxzh123 wants to merge 6 commits into
…ontrol 1.3B
- t5_encoder.py: RMSNorm → Self-Attention (T5 relative pos bias) → Residual → RMSNorm → Gated GELU-tanh FFN → Residual × L → final RMSNorm
- clip_encoder.py: PatchConv im2col → CLS+Pos → PreLN → [LN → MHA → Proj → Res → LN → FFN(GELU) → Res] × L
- config.py: shared golden/real case constants for all world model sub-networks
- Both use run_jit + TensorSpec + argparse pattern for uniform test harness
📝 Walkthrough
This PR establishes a video generation pipeline framework by introducing centralized model hyperparameters and two PyPTO JIT-compiled neural network encoders: a CLIP image encoder for visual input and a T5 text encoder for prompt conditioning. Each encoder includes a hardware-optimized kernel, golden PyTorch reference, and automated test infrastructure.
Changes
Video Pipeline Encoders and Configuration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request introduces the PyPTO-based implementations of the CLIP image encoder and T5 text encoder, along with their model configurations. The review feedback highlights several critical improvements: replacing a numerically unsafe multiplication by zero with pl.full in the CLIP encoder's LayerNorm to prevent NaN propagation; replacing a hardcoded absolute path in the T5 encoder with a relative path to ensure portability; unrolling the T5 encoder's layer loop to avoid memory overhead from dynamic tensor allocation; and lazily importing torch in both encoder files to adhere to project conventions and reduce import overhead.
for r in pl.range(CLIP_TOKENS, T):
    zero_row = pl.mul(pl.slice(biased, [1, CLIP_DIM], [r, 0]), 0.0)
    biased = pl.assemble(biased, zero_row, [r, 0])
In _layernorm, zeroing out the padding rows is done using pl.mul(pl.slice(biased, [1, CLIP_DIM], [r, 0]), 0.0). This is numerically unsafe. If the uninitialized memory in the padding rows of biased contains NaN or Inf values, multiplying them by 0.0 will result in NaN (per IEEE 754 standard), failing to zero them out and potentially corrupting downstream attention computation. Instead, use pl.full to generate a clean zero row and assemble it, similar to the pattern used in the rest of the kernel.
-for r in pl.range(CLIP_TOKENS, T):
-    zero_row = pl.mul(pl.slice(biased, [1, CLIP_DIM], [r, 0]), 0.0)
-    biased = pl.assemble(biased, zero_row, [r, 0])
+for r in pl.range(CLIP_TOKENS, T):
+    zero_row = pl.full([1, CLIP_DIM], dtype=pl.FP32, value=0.0)
+    biased = pl.assemble(biased, zero_row, [r, 0])
)

# ── Golden reference (shared classes with infer_fun_control_1_3b_text.py) ──
sys.path.insert(0, '/data/x00952168/pypto3.0/cann-recipes-embodied-ai/world_model/agibot-arm-world-model/infer_with_torch')
The absolute path /data/x00952168/pypto3.0/... is hardcoded into sys.path. This makes the code non-portable and will cause import failures for other users or in CI/CD environments where this specific directory structure does not exist. Please use relative paths to locate the required module.
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "world_model" / "agibot-arm-world-model" / "infer_with_torch"))

for layer_idx in pl.range(T5_LAYERS):
    next_hidden = pl.create_tensor([T5_SEQ, T5_DIM], dtype=pl.FP32)
    cur = t5_encoder_layer(
        cur, norm1_w, q_w, k_w, v_w, o_w, pos_bias, norm2_w,
        wi_0, wi_1, wo, next_hidden, layer_idx
    )
The layer loop in t5_encoder uses pl.range(T5_LAYERS) which is a runtime loop. Inside this loop, next_hidden is dynamically created via pl.create_tensor, and the inlined t5_encoder_layer also creates multiple intermediate tensors. In PyPTO, allocating tensors inside a runtime loop can cause significant memory overhead, GM buffer dependency issues, or compilation failures. To align with the optimized design pattern used in clip_encoder.py (which explicitly unrolls the loop to avoid GM buffer dependency issues with pl.range), please use pl.unroll(T5_LAYERS) instead of pl.range(T5_LAYERS).
-for layer_idx in pl.range(T5_LAYERS):
-    next_hidden = pl.create_tensor([T5_SEQ, T5_DIM], dtype=pl.FP32)
-    cur = t5_encoder_layer(
-        cur, norm1_w, q_w, k_w, v_w, o_w, pos_bias, norm2_w,
-        wi_0, wi_1, wo, next_hidden, layer_idx
-    )
+for layer_idx in pl.unroll(T5_LAYERS):
+    next_hidden = pl.create_tensor([T5_SEQ, T5_DIM], dtype=pl.FP32)
+    cur = t5_encoder_layer(
+        cur, norm1_w, q_w, k_w, v_w, o_w, pos_bias, norm2_w,
+        wi_0, wi_1, wo, next_hidden, layer_idx
+    )
import torch
import torch.nn.functional as F
Importing torch and torch.nn.functional at the top level of the kernel file violates the project convention of lazy importing torch to avoid import overhead when only builder functions or JIT kernels are imported. Since torch is only used in the golden reference and test harness functions, these imports should be moved inside those functions or under the if __name__ == "__main__": block.
-import torch
-import torch.nn.functional as F
+# torch and F are imported lazily inside golden/test functions to avoid import overhead
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
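A minimal sketch of the lazy-import pattern described above. It uses the standard-library `statistics` module as a stand-in for `torch` so that it runs anywhere; both function names are hypothetical and not taken from the PR.

```python
def build_kernel_spec(dim):
    """Builder-style function: needs no heavy dependency at import time."""
    return {"dim": dim, "dtype": "fp32"}

def golden_mean(values):
    """Golden-reference-style function: the heavy import happens only here."""
    import statistics  # stand-in for `import torch`; loaded on first call
    return statistics.fmean(values)

# Importing this module and calling the builder never touches the heavy module.
spec = build_kernel_spec(4)

# Only code paths that need the reference pay for the import.
reference = golden_mean([1.0, 2.0, 3.0])
print(spec, reference)
```

Python caches imported modules, so repeated calls to the golden function pay the import cost only once.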
import sys

import pypto.language as pl
import torch
Similar to clip_encoder.py, torch is imported at the top level of t5_encoder.py but is only used in the test harness and golden reference. To adhere to the project convention of lazy importing torch to avoid import overhead, please remove the top-level import and import torch locally where needed.
-import torch
+# torch is imported lazily inside golden/test functions to avoid import overhead
References
- Lazy imports for 'torch' and 'pypto.runtime' are a project convention to avoid import overhead when only builder functions are used.
Actionable comments posted: 4
🧹 Nitpick comments (1)
models/world_model/config.py (1)
172-174: ⚡ Quick win
Derive the latent dimensions instead of hardcoding them.
These are documented as derived values, but they will silently drift if H, W, CHUNK_SIZE, or the VAE factors change.
♻️ Suggested change
-LAT_F = 2 # (CHUNK_SIZE - 1) // VAE_TEMPORAL_FACTOR + 1
-LAT_H = 8 # H // VAE_SPATIAL_FACTOR
-LAT_W = 8 # W // VAE_SPATIAL_FACTOR
+LAT_F = (CHUNK_SIZE - 1) // VAE_TEMPORAL_FACTOR + 1
+LAT_H = H // VAE_SPATIAL_FACTOR
+LAT_W = W // VAE_SPATIAL_FACTOR
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/world_model/config.py` around lines 172 - 174, LAT_F, LAT_H and LAT_W are hardcoded but documented as derived; change their assignments to compute from CHUNK_SIZE, H, W, VAE_TEMPORAL_FACTOR and VAE_SPATIAL_FACTOR instead of constants so they stay correct when dimensions or factors change. Locate LAT_F, LAT_H, LAT_W in models.world_model.config and replace the literals with expressions: LAT_F = (CHUNK_SIZE - 1) // VAE_TEMPORAL_FACTOR + 1, LAT_H = H // VAE_SPATIAL_FACTOR, LAT_W = W // VAE_SPATIAL_FACTOR (use the existing variable names exactly) and ensure any dependent code still imports these computed values.
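A small sketch of why deriving the values matters. The input numbers below are hypothetical, chosen only so that the derived values match the literals 2, 8, 8 from the diff; the real constants live in models/world_model/config.py, and `latent_dims` is an illustrative helper that is not part of the PR.

```python
def latent_dims(chunk_size, h, w, temporal_factor, spatial_factor):
    """Derive latent frame/height/width using the formulas from the config comments."""
    lat_f = (chunk_size - 1) // temporal_factor + 1
    lat_h = h // spatial_factor
    lat_w = w // spatial_factor
    return lat_f, lat_h, lat_w

# Hypothetical inputs that reproduce the hardcoded values (2, 8, 8).
print(latent_dims(chunk_size=5, h=64, w=64, temporal_factor=4, spatial_factor=8))

# Change the inputs and the derived values follow; hardcoded literals would not.
print(latent_dims(chunk_size=9, h=128, w=64, temporal_factor=4, spatial_factor=8))
```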
📒 Files selected for processing (3)
models/world_model/clip_encoder.py
models/world_model/config.py
models/world_model/t5_encoder.py
from config import (
    CLIP_DIM, CLIP_HEADS, CLIP_HEAD_DIM, CLIP_LAYERS, CLIP_PATCH, CLIP_IMG,
    CLIP_TOKENS, CLIP_FFN, CLIP_NORM_EPS,
)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -n 'from config import' models/world_model/clip_encoder.py models/world_model/t5_encoder.py
fd -a '^config\.py$'
fd -a '^__init__\.py$' models
Repository: hw-native-sys/pypto-lib
Length of output: 325
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Confirm package layout
echo "== __init__.py files under models =="
fd -a '^__init__\.py$' models || true
# Confirm config.py locations
echo "== config.py files under models =="
fd -a '^config\.py$' models/world_model models || true
# Show the relevant import blocks
echo "== clip_encoder.py import context =="
nl -ba models/world_model/clip_encoder.py | sed -n '35,60p'
echo "== t5_encoder.py import context =="
nl -ba models/world_model/t5_encoder.py | sed -n '1,60p'
Repository: hw-native-sys/pypto-lib
Length of output: 428
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== __init__.py under models =="
fd -a '^__init__\.py$' models || true
echo "== __init__.py under models/world_model =="
fd -a '^__init__\.py$' models/world_model || true
echo "== config.py candidates (under models) =="
fd -a '^config\.py$' models || true
echo "== clip_encoder.py around import =="
cat -n models/world_model/clip_encoder.py | sed -n '35,70p'
echo "== t5_encoder.py around import =="
cat -n models/world_model/t5_encoder.py | sed -n '1,70p'
Repository: hw-native-sys/pypto-lib
Length of output: 5406
Use deterministic, package-aware config imports for models/world_model/clip_encoder.py and models/world_model/t5_encoder.py.
Both files use from config import ... (clip_encoder.py:44-47, t5_encoder.py:30-32), which is fragile because the repo contains multiple config.py files (including models/deepseek/v4/config.py and models/qwen3/14b/config.py) and models/ / models/world_model/ currently have no __init__.py (so they aren’t a package today). This import can therefore resolve the wrong module depending on sys.path/execution context.
Replace with a deterministic import strategy:
- either make models/ and models/world_model/ proper packages (add __init__.py) and use from .config import ... / from models.world_model.config import ..., or
- add an explicit, path-based fallback so config.py always resolves to models/world_model/config.py.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/world_model/clip_encoder.py` around lines 44 - 47, The imports in
models/world_model/clip_encoder.py and models/world_model/t5_encoder.py
currently use a top-level from config import ... which can resolve the wrong
config module; fix by making models/ and models/world_model/ proper packages
(add __init__.py) and change the imports in clip_encoder.py (symbols: CLIP_DIM,
CLIP_HEADS, CLIP_HEAD_DIM, CLIP_LAYERS, CLIP_PATCH, CLIP_IMG, CLIP_TOKENS,
CLIP_FFN, CLIP_NORM_EPS) and t5_encoder.py to use package-aware imports (e.g.,
from .config import ... or from models.world_model.config import ...) so the
correct models/world_model/config.py is always resolved, or alternatively
implement an explicit path-based resolver that ensures
models/world_model/config.py is imported before falling back.
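The path-based alternative mentioned above could look roughly like the following. This is a sketch using the standard `importlib.util` API; the directory names, the `MARKER` constant, and the `load_module_from_path` helper are hypothetical and exist only for this demo.

```python
import importlib.util
import tempfile
from pathlib import Path

def load_module_from_path(module_name, file_path):
    """Load a module from an explicit file path, bypassing sys.path lookup."""
    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# Demo: two config.py files that a bare `from config import ...` could confuse.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for sub in ("world_model", "qwen3"):
        (root / sub).mkdir()
        (root / sub / "config.py").write_text(f"MARKER = {sub!r}\n")

    # Each load names its file explicitly, so resolution never depends on sys.path.
    wm_config = load_module_from_path("world_model_config", root / "world_model" / "config.py")
    other_config = load_module_from_path("qwen3_config", root / "qwen3" / "config.py")

print(wm_config.MARKER, other_config.MARKER)
```

Turning the directories into packages and using `from .config import ...` is usually the simpler fix; the explicit loader is mainly useful when the files must also run as standalone scripts.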
for r in pl.range(CLIP_TOKENS, T):
    zero_row = pl.full([1, 3 * CLIP_DIM], dtype=pl.BF16, value=0.0)
    qkv_gm = pl.assemble(qkv_gm, zero_row, [r, 0])

# ── Multi-head self-attention (per-head, with scale). ──
with pl.at(level=pl.Level.CORE_GROUP, name_hint="mha"):
    for h in pl.range(CLIP_HEADS):
        h_col = h * CLIP_HEAD_DIM
        q_h = pl.slice(qkv_gm, [T, CLIP_HEAD_DIM], [0, h_col])
        k_h = pl.slice(qkv_gm, [T, CLIP_HEAD_DIM], [0, CLIP_DIM + h_col])
        v_h = pl.slice(qkv_gm, [T, CLIP_HEAD_DIM], [0, 2 * CLIP_DIM + h_col])

        scores = pl.mul(
            pl.matmul(q_h, k_h, b_trans=True, out_dtype=pl.FP32), ATTN_SCALE,
        )
        shifted = pl.row_expand_sub(scores, pl.row_max(scores))
        exp_s = pl.exp(shifted)
        sm = pl.row_expand_div(exp_s, pl.row_sum(exp_s))
        ctx = pl.matmul(pl.cast(sm, target_type=pl.BF16), v_h, out_dtype=pl.FP32)
Mask padded attention columns before softmax.
Zeroing the padded Q/K/V rows is not enough here. The softmax still runs over all T keys, so every padded position contributes to the denominator for valid queries and dilutes the real-token attention weights. The golden path only attends over CLIP_TOKENS, so this kernel cannot be numerically equivalent until the padded score columns are excluded or forced to a large negative value before exp.
Also applies to: 251-253
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/world_model/clip_encoder.py` around lines 229 - 247, The softmax
currently includes padded key columns so padded positions still contribute to
the denominator; before computing exp() you must mask out columns r >=
CLIP_TOKENS by forcing their scores to a large negative value (e.g. -1e9) so
they get zero weight. Concretely, in the MHA block around variables scores,
shifted, exp_s, sm (operating on qkv_gm, T, CLIP_TOKENS), build a boolean column
mask for valid keys (length T with True for indices < CLIP_TOKENS), expand it to
match scores' shape and apply: scores = pl.where(mask, scores, scores +
LARGE_NEG) or scores = pl.select(mask, scores, LARGE_NEG) (using PL ops) before
computing shifted/exp_s/sm; apply the same masking in the corresponding later
block (lines ~251-253) as well.
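The dilution effect is easy to reproduce in plain Python. The sketch below is not PyPTO code: `LARGE_NEG` uses the -1e9 value suggested in the comment, the scores are made up, and the helper names are hypothetical.

```python
import math

LARGE_NEG = -1e9  # mask value suggested in the review comment

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def masked_softmax(row, num_valid):
    """Force padded key columns (index >= num_valid) to LARGE_NEG before softmax."""
    masked = [x if i < num_valid else LARGE_NEG for i, x in enumerate(row)]
    return softmax(masked)

# One query row: 3 real keys plus 2 padded keys whose scores are 0.0
# because the padded K rows were zeroed.
scores = [2.0, 1.0, 0.5, 0.0, 0.0]

unmasked = softmax(scores)                     # padded columns still get weight
masked = masked_softmax(scores, num_valid=3)   # padded columns get zero weight
reference = softmax(scores[:3])                # what a golden path over valid keys computes

print("weight on real keys, unmasked:", sum(unmasked[:3]))
print("weight on real keys, masked:  ", sum(masked[:3]))
```

A padded score of 0.0 still contributes exp(0 - max) to the denominator, so zeroing Q/K/V rows alone cannot reproduce the reference weights.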
sys.path.insert(0, '/data/x00952168/pypto3.0/cann-recipes-embodied-ai/world_model/agibot-arm-world-model/infer_with_torch')
from test_golden_fun_control_full import (  # noqa: E402
    T5LayerNorm, T5RelativeEmbedding, T5SelfAttention,
)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -n "sys\.path\.insert|test_golden_fun_control_full" models/world_model/t5_encoder.py
fd -a '^test_golden_fun_control_full\.py$'
Repository: hw-native-sys/pypto-lib
Length of output: 395
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Locate the golden dependency anywhere in the repo
echo "== fd results for test_golden_fun_control_full.py =="
fd -a --hidden '^test_golden_fun_control_full\.py$' . || true
# 2) Inspect t5_encoder.py around the sys.path insert/import
echo
echo "== Context around lines 25-60 in models/world_model/t5_encoder.py =="
sed -n '20,60p' models/world_model/t5_encoder.py
# 3) Find golden_t5_encoder and related harness code
echo
echo "== golden_t5_encoder() definition and any layer loop =="
rg -n "def golden_t5_encoder|golden_t5_encoder|T5_LAYERS|layer|layers" models/world_model/t5_encoder.py
echo
echo "== Show golden_t5_encoder() body (best-effort bounded range) =="
# attempt to print a bounded window around the function; fall back if line numbers unknown
python3 - <<'PY'
import re, itertools
path="models/world_model/t5_encoder.py"
with open(path,'r',encoding='utf-8') as f:
    lines=f.readlines()
# locate golden_t5_encoder
start=None
for i,l in enumerate(lines):
    if re.search(r'\bdef\s+golden_t5_encoder\s*\(', l):
        start=i
        break
if start is None:
    print("golden_t5_encoder not found")
    raise SystemExit(0)
lo=max(0,start-5); hi=min(len(lines), start+220)
for j in range(lo,hi):
    print(f"{j+1}:{lines[j].rstrip()}")
PY
# 4) Confirm how many T5 blocks/layers the golden harness uses (search for T5SelfAttention usage and indexing)
echo
echo "== T5SelfAttention usage sites =="
rg -n "T5SelfAttention" models/world_model/t5_encoder.py
Repository: hw-native-sys/pypto-lib
Length of output: 11556
Remove the hardcoded /data/... golden import and fix the single-layer golden reference.
- models/world_model/t5_encoder.py inserts a machine-local absolute path into sys.path and imports test_golden_fun_control_full at module import time; test_golden_fun_control_full.py isn't present in this repo, so importing the module depends on an external checkout existing at that exact path.
- golden_t5_encoder() extracts only layer-0 weights (tensors["..."][0, ...]) and runs a single T5SelfAttention block, but t5_encoder applies for layer_idx in pl.range(T5_LAYERS), so the golden comparison becomes incorrect when T5_LAYERS > 1.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/world_model/t5_encoder.py` around lines 35 - 38, Remove the hardcoded
sys.path insertion and the top-level import of test_golden_fun_control_full;
instead import or reference T5LayerNorm, T5RelativeEmbedding, T5SelfAttention
inside the golden_t5_encoder function (or provide local equivalents) so the
module no longer depends on an absolute machine-local path at import time. In
golden_t5_encoder(), stop only taking tensors["..."][0,...] for layer weights;
iterate over all layer indices up to T5_LAYERS (use the same loop used by
t5_encoder: for layer_idx in pl.range(T5_LAYERS)) and slice the tensors
per-layer (e.g., tensors["..."][layer_idx,...]) and run a T5SelfAttention call
per layer (or apply the same sequence of
T5LayerNorm/T5RelativeEmbedding/T5SelfAttention for each layer) so the golden
comparator matches multi-layer behavior when T5_LAYERS > 1.
pos_bias_4d = tensors["pos_bias"][0].unsqueeze(0)  # [1, H, S, S]

layer_weights = {
    'norm1_w': tensors["norm1_w"][0, 0, :],
    'norm2_w': tensors["norm2_w"][0, 0, :],
    'attn': {
        'q': tensors["q_w"][0].float(),
        'k': tensors["k_w"][0].float(),
        'v': tensors["v_w"][0].float(),
        'o': tensors["o_w"][0].float(),
    },
    'ffn': {
        'wi_0': tensors["wi_0"][0].float(),
        'wi_1': tensors["wi_1"][0].float(),
        'wo': tensors["wo"][0].float(),
    },
}

block = T5SelfAttention(
    layer_weights, T5_DIM, T5_DIM, T5_FFN, T5_HEADS, T5_NUM_BUCKETS,
    shared_pos=True, dropout=0.1,
)
x = block(x, mask=None, pos_bias=pos_bias_4d)
x = T5LayerNorm(tensors["final_norm_w"].squeeze(0), T5_DIM)(x)

tensors["out"][:] = x.squeeze(0).bfloat16()
The golden harness only validates a single encoder block.
golden_t5_encoder() always pulls layer 0 weights and applies one T5SelfAttention block, so it only matches the kernel while T5_LAYERS == 1. models/world_model/config.py already documents a 24-layer real-case config, so run_jit() stops validating the actual encoder stack as soon as that mode is enabled.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/world_model/t5_encoder.py` around lines 419 - 444, golden_t5_encoder()
currently always uses layer 0 weights and a single T5SelfAttention, so it only
validates when T5_LAYERS == 1; update it to loop over all encoder layers: for
each layer index i, extract per-layer tensors (e.g., tensors["q_w"][i],
tensors["k_w"][i], tensors["v_w"][i], tensors["o_w"][i], tensors["wi_0"][i],
tensors["wi_1"][i], tensors["wo"][i], tensors["norm1_w"][i],
tensors["norm2_w"][i]) build a T5SelfAttention/T5 FFN block with those
layer_weights and apply it sequentially to x (passing pos_bias appropriately),
and only after the loop apply the final layer norm using
tensors["final_norm_w"]; finally write the resulting x to tensors["out"] as
before so the harness validates multi-layer encoder stacks.
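A toy illustration of why a layer-0-only golden reference passes with one layer and fails with more. The "layer" here is a made-up residual scaling, not the real T5 block, and all function names are hypothetical.

```python
def apply_layer(x, weight):
    """Toy residual block: x + weight * x (stand-in for one encoder layer)."""
    return [v + weight * v for v in x]

def golden_single_layer(x, layer_weights):
    """Mirrors the reported issue: always uses the layer-0 weights."""
    return apply_layer(x, layer_weights[0])

def golden_all_layers(x, layer_weights):
    """Applies every layer in order, as a layer loop over the full stack would."""
    for weight in layer_weights:
        x = apply_layer(x, weight)
    return x

x0 = [1.0, 2.0]

# With a single layer the two references agree, so the bug stays hidden.
one_layer = [0.5]
print(golden_single_layer(x0, one_layer), golden_all_layers(x0, one_layer))

# With two layers they diverge.
two_layers = [0.5, 1.0]
print(golden_single_layer(x0, two_layers), golden_all_layers(x0, two_layers))
```

In the real harness the loop body would slice each per-layer tensor by the layer index and apply the final norm once after the loop, as the comment describes.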
…ontrol 1.3B delivery
- Add dit_forward.py: fused DiT with timestep embedding, AdaLN, RoPE, self/cross-attn, FFN, head
- Update t5_encoder.py: refactor with @pl.jit.inline _rmsnorm, argparse CLI
- Update clip_encoder.py: switch to QuickGELU (matching golden), add geometry assertions
- Update config.py: add DiT/LAT constants
…ation bug
- Add vae_encoder.py: VAE encoder for Fun-Control 1.3B with multi-kernel design
- Add vae_decoder.py: VAE decoder for Fun-Control 1.3B with upsample3d support
- Fix dit_forward.py timestep modulation vector slicing bug (line 268-271)
- Remove unused l_gate_attn/l_gate_ffn parameters from DiT kernel
Add T5 and CLIP encoder pypto3.0 kernels.