Add rope/mx_block_quant vmi/mi code#491
Open
Zhendong404 wants to merge 1 commit into
Open
Conversation
Comment on lines
+143
to
+215
| } else { | ||
| // ======================================================================== | ||
| // Interleave 模式(奇偶交织布局)— MI 版本 | ||
| // 公式: y = x*cos + rot(x)*sin, rot(x) = [-x_odd, x_even] | ||
| // | ||
| // MI: 需要两种 mask——pair mask(32宽) 用于奇偶半区的乘加, | ||
| // full mask(64宽) 用于最终写回。 | ||
| // 解交织/交织使用 vdintlv/vintlv 硬件指令——开发者需理解 lane 映射。 | ||
| // ======================================================================== | ||
|
|
||
| %mask16_pair = pto.pge_b16 "PAT_VL32" : !pto.mask<b16> | ||
| %mask16_full = pto.pge_b16 "PAT_VL64" : !pto.mask<b16> | ||
| scf.for %s = %c0 to %s_count step %c1 { | ||
| %x_s_off = arith.muli %s, %x_s_step : index | ||
| %cs_off = arith.muli %s, %cs_s_step : index | ||
| %y_s_off = arith.muli %s, %y_s_step : index | ||
|
|
||
| %cos = pto.vlds %cos_ub[%cs_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16> | ||
| %sin = pto.vlds %sin_ub[%cs_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16> | ||
| // building the rotated partner [-x1, x0, -x3, x2, ...] and then | ||
| // evaluating y = x * cos + rotated(x) * sin. | ||
| // This PTO form keeps the same RoPE pairing but writes it explicitly | ||
| // in terms of even/odd streams. | ||
| // MI: vdintlv 硬件解交织——将 [e0,o0, e1,o1, ...] 拆分为 | ||
| // even=[e0,e1,...] 和 odd=[o0,o1,...] 两个独立寄存器。 | ||
| // 输出两个 vreg<128xf16>(各只有半宽有效),需要知道 lane 映射关系。 | ||
| // VMI 对应: channel_split(语义化奇偶分解,直接产出 vreg<64>)。 | ||
|
|
||
| %cos_even, %cos_odd = pto.vdintlv %cos, %cos : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16> | ||
| %sin_even, %sin_odd = pto.vdintlv %sin, %sin : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16> | ||
|
|
||
| scf.for %n = %c0 to %n_count step %c1 { | ||
| %x_n_off = arith.muli %n, %x_n_step : index | ||
| %y_n_off = arith.muli %n, %y_n_step : index | ||
| %x_off = arith.addi %x_s_off, %x_n_off : index | ||
| %y_off = arith.addi %y_s_off, %y_n_off : index | ||
|
|
||
| %x = pto.vlds %x_ub[%x_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16> | ||
| // MI: vdintlv 解交织 x,拆出 even/odd 两个半宽寄存器。 | ||
| // VMI 对应: channel_split。 | ||
|
|
||
| %x_even, %x_odd = pto.vdintlv %x, %x : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16> | ||
|
|
||
| // Explicit even/odd RoPE equations for the interleaved layout. | ||
| // ======================================================================== | ||
| // MI: 成对 RoPE 计算——8条算术指令,每条绑 mask<b16>。 | ||
| // 4条 vmul + 2条 vsub + 2条 vadd,mask 参数反复出现。 | ||
| // VMI 对应: 6条 vmi.mulf/subf/addf(无 mask 参数)。 | ||
| // ======================================================================== | ||
|
|
||
| %x_even_cos = pto.vmul %x_even, %cos_even, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
| %x_odd_sin = pto.vmul %x_odd, %sin_even, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
| %y_even = pto.vsub %x_even_cos, %x_odd_sin, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
|
|
||
| %x_odd_cos = pto.vmul %x_odd, %cos_odd, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
| %x_even_sin = pto.vmul %x_even, %sin_odd, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
| %y_odd = pto.vadd %x_odd_cos, %x_even_sin, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16> | ||
|
|
||
| // MI: vintlv 硬件交织——将独立的 even/odd 寄存器合并回 | ||
| // 交织格式 [e0,o0, e1,o1, ...]。产生两个输出(低半/高半), | ||
| // 需要分别 store 或只 store 低半(取决于数据量)。 | ||
| // VMI 对应: channel_merge(产出单个 vreg<128>,逻辑更简洁)。 | ||
|
|
||
| %y_pack, %y_pack_hi = pto.vintlv %y_even, %y_odd : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16> | ||
| // MI: vsts 写回,绑定 mask<b16> full mask(64宽)。 | ||
| // vintlv 产出的 y_pack_hi 在当前数据量(64元素)下不需要 store。 | ||
| // VMI 对应: vmi.masked_store(单个 128xf16 向量一次写回)。 | ||
|
|
||
| pto.vsts %y_pack, %y_ub[%y_off], %mask16_full : !pto.vreg<128xf16>, !pto.ptr<f16, ub>, !pto.mask<b16> | ||
| } | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
This interleave branch has a different implementation logic than my intrinsics code. The math is equivalent, but the cycles might differ.. The version below mirrors the original intrinsics more closely:
} else {
// === INTERLEAVE mode (GPT-J layout) — v2(code here) vs rope_f16.mi.pto ===
//
// Same interleaved RoPE semantics; two equivalent spellings of complex rotation:
//
// rope_f16.mi.pto — Cartesian expansion form (even/odd deinterleaved compute):
// Treat (x_even, x_odd) as (Re, Im) per 2D block; expand (a+ib)(c+id):
// y_even = x_even*cos_even - x_odd*sin_even
// y_odd = x_odd*cos_odd + x_even*sin_odd
// vdintlv cos, sin, AND x → 8 arith ops → vintlv → store
// Fixed pge_b16 masks; no blk loop; no vbr.
//
// v2(code here) — complex-multiply form (rope_cce_compute.h ComputeF16 INTERLEAVE):
// y = x*cos + (i*x)*sin, (i*x) = [-x1,x0,-x3,x2,...] in interleaved layout
// cos/sin dense load (no vdintlv on tables)
// Build i*x on x only: vdintlv(x) → vmul(odd,-1) → vintlv
// 3 arith ops → single vsts; blk loop (64 f16/block); plt_b16; vbr(-1)
%neg_one = pto.vbr %neg_one_f16 : f16 -> !pto.vreg<128xf16>
scf.for %s = %c0 to %s_count step %c1 {
%x_s_off = arith.muli %s, %x_s_step : index
%cs_off = arith.muli %s, %cs_s_step : index
%y_s_off = arith.muli %s, %y_s_step : index
scf.for %blk = %c0 to %d_blocks step %c1 {
// rope_f16.mi.pto: no blk loop — one cos/sin load at cs_off per s.
%off = arith.muli %blk, %c64 : index
%off_i32 = arith.index_cast %off : index to i32
%remaining_i32 = arith.subi %c64_i32, %off_i32 : i32
%lt_blk = arith.cmpi slt, %remaining_i32, %c64_i32 : i32
%cnt_i32 = arith.select %lt_blk, %remaining_i32, %c64_i32 : i32
%pair_cnt_plus = arith.addi %cnt_i32, %c1_i32 : i32
%pair_cnt_i32 = arith.divui %pair_cnt_plus, %c2_i32 : i32
%mask16, %mask16_next = pto.plt_b16 %cnt_i32 : i32 -> !pto.mask<b16>, i32
%mask_pair, %mask_pair_next = pto.plt_b16 %pair_cnt_i32 : i32 -> !pto.mask<b16>, i32
%cs_elem_off = arith.addi %cs_off, %off : index
// rope_f16.mi.pto deinterleaves cos/sin here; v2 keeps tables dense.
%cos = pto.vlds %cos_ub[%cs_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
%sin = pto.vlds %sin_ub[%cs_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
scf.for %n = %c0 to %n_count step %c1 {
%x_n_off = arith.muli %n, %x_n_step : index
%y_n_off = arith.muli %n, %y_n_step : index
%x_off = arith.addi %x_s_off, %x_n_off : index
%y_off = arith.addi %y_s_off, %y_n_off : index
%x_elem_off = arith.addi %x_off, %off : index
%y_elem_off = arith.addi %y_off, %off : index
%x = pto.vlds %x_ub[%x_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
// Build (i*x) for complex-multiply form — v1 builds i*x implicitly
// via separate y_even/y_odd updates on deinterleaved streams.
%x_even, %x_odd = pto.vdintlv %x, %x : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>
%x_neg_odd = pto.vmul %x_odd, %neg_one, %mask_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%x_rot, %x_rot_hi = pto.vintlv %x_neg_odd, %x_even : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>
// complex-multiply: y = x*cos + (i*x)*sin (3 ops)
// vs Cartesian expansion: separate y_even/y_odd updates (8 ops)
%x_cos = pto.vmul %x, %cos, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%rot_sin = pto.vmul %x_rot, %sin, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%y = pto.vadd %x_cos, %rot_sin, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
pto.vsts %y, %y_ub[%y_elem_off], %mask16 : !pto.vreg<128xf16>, !pto.ptr<f16, ub>, !pto.mask<b16>
}
}
}
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.