Skip to content

Add rope/mx_block_quant vmi/mi code#491

Open
Zhendong404 wants to merge 1 commit into
feature-vmifrom
vmi-examples
Open

Add rope/mx_block_quant vmi/mi code#491
Zhendong404 wants to merge 1 commit into
feature-vmifrom
vmi-examples

Conversation

@Zhendong404

Copy link
Copy Markdown
Collaborator

No description provided.

Comment on lines +143 to +215
} else {
// ========================================================================
// Interleave 模式(奇偶交织布局)— MI 版本
// 公式: y = x*cos + rot(x)*sin, rot(x) = [-x_odd, x_even]
//
// MI: 需要两种 mask——pair mask(32宽) 用于奇偶半区的乘加,
// full mask(64宽) 用于最终写回。
// 解交织/交织使用 vdintlv/vintlv 硬件指令——开发者需理解 lane 映射。
// ========================================================================

%mask16_pair = pto.pge_b16 "PAT_VL32" : !pto.mask<b16>
%mask16_full = pto.pge_b16 "PAT_VL64" : !pto.mask<b16>
scf.for %s = %c0 to %s_count step %c1 {
%x_s_off = arith.muli %s, %x_s_step : index
%cs_off = arith.muli %s, %cs_s_step : index
%y_s_off = arith.muli %s, %y_s_step : index

%cos = pto.vlds %cos_ub[%cs_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
%sin = pto.vlds %sin_ub[%cs_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
// Interleaved RoPE can be expressed as building the rotated partner
// [-x1, x0, -x3, x2, ...] and then evaluating y = x * cos + rotated(x) * sin.
// This PTO form keeps the same RoPE pairing but writes it out explicitly
// in terms of separate even/odd streams.
// MI: vdintlv 硬件解交织——将 [e0,o0, e1,o1, ...] 拆分为
// even=[e0,e1,...] 和 odd=[o0,o1,...] 两个独立寄存器。
// 输出两个 vreg<128xf16>(各只有半宽有效),需要知道 lane 映射关系。
// VMI 对应: channel_split(语义化奇偶分解,直接产出 vreg<64>)。

%cos_even, %cos_odd = pto.vdintlv %cos, %cos : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>
%sin_even, %sin_odd = pto.vdintlv %sin, %sin : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>

scf.for %n = %c0 to %n_count step %c1 {
%x_n_off = arith.muli %n, %x_n_step : index
%y_n_off = arith.muli %n, %y_n_step : index
%x_off = arith.addi %x_s_off, %x_n_off : index
%y_off = arith.addi %y_s_off, %y_n_off : index

%x = pto.vlds %x_ub[%x_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
// MI: vdintlv 解交织 x,拆出 even/odd 两个半宽寄存器。
// VMI 对应: channel_split。

%x_even, %x_odd = pto.vdintlv %x, %x : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>

// Explicit even/odd RoPE equations for the interleaved layout.
// ========================================================================
// MI: paired RoPE computation — 6 arithmetic instructions, each taking a mask<b16> operand.
// 4 vmul + 1 vsub + 1 vadd; the mask operand is repeated on every instruction.
// VMI 对应: 6条 vmi.mulf/subf/addf(无 mask 参数)。
// ========================================================================

%x_even_cos = pto.vmul %x_even, %cos_even, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%x_odd_sin = pto.vmul %x_odd, %sin_even, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%y_even = pto.vsub %x_even_cos, %x_odd_sin, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>

%x_odd_cos = pto.vmul %x_odd, %cos_odd, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%x_even_sin = pto.vmul %x_even, %sin_odd, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
%y_odd = pto.vadd %x_odd_cos, %x_even_sin, %mask16_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>

// MI: vintlv 硬件交织——将独立的 even/odd 寄存器合并回
// 交织格式 [e0,o0, e1,o1, ...]。产生两个输出(低半/高半),
// 需要分别 store 或只 store 低半(取决于数据量)。
// VMI 对应: channel_merge(产出单个 vreg<128>,逻辑更简洁)。

%y_pack, %y_pack_hi = pto.vintlv %y_even, %y_odd : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>
// MI: vsts 写回,绑定 mask<b16> full mask(64宽)。
// vintlv 产出的 y_pack_hi 在当前数据量(64元素)下不需要 store。
// VMI 对应: vmi.masked_store(单个 128xf16 向量一次写回)。

pto.vsts %y_pack, %y_ub[%y_off], %mask16_full : !pto.vreg<128xf16>, !pto.ptr<f16, ub>, !pto.mask<b16>
}
}
}
}

@learning-chip learning-chip Jul 2, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This interleave branch uses a different implementation logic than my intrinsics code. The math is equivalent, but the cycle counts might differ. The version below mirrors the original intrinsics more closely:

        } else {
          // === INTERLEAVE mode (GPT-J layout) — v2(code here) vs rope_f16.mi.pto ===
          //
          // Same interleaved RoPE semantics; two equivalent spellings of complex rotation:
          //
          // rope_f16.mi.pto — Cartesian expansion form (even/odd deinterleaved compute):
          //   Treat (x_even, x_odd) as (Re, Im) per 2D block; expand (a+ib)(c+id):
          //     y_even = x_even*cos_even - x_odd*sin_even
          //     y_odd  = x_odd*cos_odd  + x_even*sin_odd
          //   vdintlv cos, sin, AND x → 6 arith ops (4 vmul + 1 vsub + 1 vadd) → vintlv → store
          //   Fixed pge_b16 masks; no blk loop; no vbr.
          //
          // v2(code here) — complex-multiply form (rope_cce_compute.h ComputeF16 INTERLEAVE):
          //   y = x*cos + (i*x)*sin,  (i*x) = [-x1,x0,-x3,x2,...] in interleaved layout
          //   cos/sin dense load (no vdintlv on tables)
          //   Build i*x on x only: vdintlv(x) → vmul(odd,-1) → vintlv
          //   3 arith ops → single vsts; blk loop (64 f16/block); plt_b16; vbr(-1)
          %neg_one = pto.vbr %neg_one_f16 : f16 -> !pto.vreg<128xf16>

          scf.for %s = %c0 to %s_count step %c1 {
            %x_s_off = arith.muli %s, %x_s_step : index
            %cs_off = arith.muli %s, %cs_s_step : index
            %y_s_off = arith.muli %s, %y_s_step : index

            scf.for %blk = %c0 to %d_blocks step %c1 {
              // rope_f16.mi.pto: no blk loop — one cos/sin load at cs_off per s.
              %off = arith.muli %blk, %c64 : index
              %off_i32 = arith.index_cast %off : index to i32
              %remaining_i32 = arith.subi %c64_i32, %off_i32 : i32
              %lt_blk = arith.cmpi slt, %remaining_i32, %c64_i32 : i32
              %cnt_i32 = arith.select %lt_blk, %remaining_i32, %c64_i32 : i32
              %pair_cnt_plus = arith.addi %cnt_i32, %c1_i32 : i32
              %pair_cnt_i32 = arith.divui %pair_cnt_plus, %c2_i32 : i32

              %mask16, %mask16_next = pto.plt_b16 %cnt_i32 : i32 -> !pto.mask<b16>, i32
              %mask_pair, %mask_pair_next = pto.plt_b16 %pair_cnt_i32 : i32 -> !pto.mask<b16>, i32

              %cs_elem_off = arith.addi %cs_off, %off : index
              // rope_f16.mi.pto deinterleaves cos/sin here; v2 keeps tables dense.
              %cos = pto.vlds %cos_ub[%cs_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>
              %sin = pto.vlds %sin_ub[%cs_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>

              scf.for %n = %c0 to %n_count step %c1 {
                %x_n_off = arith.muli %n, %x_n_step : index
                %y_n_off = arith.muli %n, %y_n_step : index
                %x_off = arith.addi %x_s_off, %x_n_off : index
                %y_off = arith.addi %y_s_off, %y_n_off : index
                %x_elem_off = arith.addi %x_off, %off : index
                %y_elem_off = arith.addi %y_off, %off : index

                %x = pto.vlds %x_ub[%x_elem_off] : !pto.ptr<f16, ub> -> !pto.vreg<128xf16>

                // Build (i*x) explicitly for the complex-multiply form; rope_f16.mi.pto
                // instead forms it implicitly via separate y_even/y_odd updates on
                // the deinterleaved streams.
                %x_even, %x_odd = pto.vdintlv %x, %x : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>
                %x_neg_odd = pto.vmul %x_odd, %neg_one, %mask_pair : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
                %x_rot, %x_rot_hi = pto.vintlv %x_neg_odd, %x_even : !pto.vreg<128xf16>, !pto.vreg<128xf16> -> !pto.vreg<128xf16>, !pto.vreg<128xf16>

                // complex-multiply: y = x*cos + (i*x)*sin  (3 ops)
                // vs Cartesian expansion: separate y_even/y_odd updates (6 ops)
                %x_cos = pto.vmul %x, %cos, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
                %rot_sin = pto.vmul %x_rot, %sin, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>
                %y = pto.vadd %x_cos, %rot_sin, %mask16 : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask<b16> -> !pto.vreg<128xf16>

                pto.vsts %y, %y_ub[%y_elem_off], %mask16 : !pto.vreg<128xf16>, !pto.ptr<f16, ub>, !pto.mask<b16>
              }
            }
          }
        }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants