Skip to content

[Feature] PTODSL: 补齐 SIMT RMSNorm 所需的 4-lane 向量切片访存与实例化 allreduce helper #441

Description

@Zhendong404

背景

TileLang_RMSNorm_to_PTO_Design.md3.3 向量访问对接方案3.4 Reduce 对接方案 中,RMSNorm 的 SIMT body 对 PTODSL 提出了两类明确需求:

  • 保留 8 x float4 粒度的连续 4-lane 访存结构,并允许后续按标量消费 lane-local scratch。
  • 能按 AscendAllReduce<Reducer, threads, scale, thread_offset> 的实例参数,选择/生成可复用的 reduce helper。

设计文档链接:

现状梳理

当前 PTODSL 已经具备一部分基础:

  • @pto.simt / with pto.simt(): 结构化 SIMT surface。
  • pto.get_tid_x/y/z、标量 load/storepipe_barrier/mem_bar 等基础能力。
  • 有面向 SIMD 的 tile-slice sugar,例如 pto.vlds(tile[row, col:]) / pto.vsts(vec, tile[row, col:], mask),但它们表达的是“整条 vector-width slice”。

但对照 RMSNorm 设计文档,当前 PTODSL 还缺两块关键能力:

  1. 缺少能稳定表达 tx * 4 + lane 这类 4-lane 连续访存关系的 surface / lowering 约定。
  2. 缺少可实例化复用的 SIMT collective/reduce helper 组织方式,尤其是 allreduce(sum/max/min, threads/scale/offset) 这类模式。

文档里的目标形态

1. 4-lane 连续访存 + lane-local scratch

设计文档里 RMSNorm 的关键模式是:

  • 第一层保留 8 x float4 的 load/store 粒度。
  • 中间保留 x_frag[32] 这样的 lane-local scratch。
  • 第二层继续按标量 x_frag[i] 消费。
  • 后续交给 SROA/LLVM 把本地数组拆成 extractelement

文档明确建议:

  • 如果以后接 PTODSL,可以在 PTODSL 中用类似 Python slice 的语法描述 4-lane 连续访存。
  • lowering 时把这类结构化切片翻成 LLVM Dialect 的向量 load/store。

2. 实例化 allreduce helper

文档对 tl::AscendAllReduce<Reducer, threads, scale, thread_offset>::run(...) 的建议是:

  • collective 语义尽量沉淀到 PTODSL。
  • codegen 阶段按实例参数选择/导入具体 helper。
  • helper 内部可基于 redux_* + syncthreads + scratch load/store 组织实现。

当前 RMSNorm 直接需求的最小实例就是:

  • reducer = sum
  • dtype = f32
  • threads = 128
  • scale = 1
  • thread_offset = 0

建议补齐的 PTODSL 能力

A. 结构化 4-lane 向量访存 surface

建议支持一种能保持“连续 4 元素窗口”语义的 PTODSL 写法,用于 SIMT helper 内的 lane-local scratch 装载/写回。形式不一定要和设计文档完全一致,但至少要满足:

  • 作者态可以表达 base + tx * 4 起始的连续 4 元素访存。
  • lowering 后能稳定落到 LLVM Dialect 的 <4 x f32> load/store,而不是一开始就被摊平成 4 次标量 load/store。
  • 允许后续按标量读取 scratch 中的元素,给 SROA/LLVM 留出拆成 extractelement 的空间。

这里不强求必须暴露成通用 public API,也可以先做成 PTODSL 内部 helper / lowering 约定,但希望后续能被 TileLang->PTO 路径复用。

B. lane-local local scratch 表达

RMSNorm 还需要一个能承载 x_frag[32] / sum_sq[1] 这类临时值的 PTODSL 表达方式。建议明确:

  • 是否提供 PTODSL 层的 local scratch surface。
  • 或者由 PTODSL lowering 内部直接生成对应的 LLVM alloca / slot 组织。
  • 无论采用哪种方式,都需要保证“先以 <4 x f32> 为单位写入,再按标量消费”的优化链路可行。

C. collective/reduce helper 机制

建议在 PTODSL 侧引入一套可实例化、可缓存的 reduce helper 机制,至少覆盖:

  • reducer kind: sum / max / min
  • dtype: 先满足 f32
  • compile-time params: threads / scale / thread_offset

期望行为:

  • 同一实例在 module 内只生成一次 helper。
  • codegen 调用点按实例参数选择 helper,而不是在 PTO 后端再做模板推导。
  • helper body 可表达两级规约、scratch 写回、syncthreads、广播读回等结构。

验收建议

如果按 RMSNorm 当前需求收敛,建议至少满足下面几项:

  1. 能用 PTODSL/其 lowering 路径表达 128-thread SIMT body 中的 8 x float4 连续访存,并在 IR 中保住 <4 x f32> load/store。
  2. 能保住 x_frag[32] 这类 lane-local scratch,再由后续 LLVM pass 拆成标量消费链。
  3. 能提供或内部生成一个等价于 AscendAllReduce<SumOp, 128, 1, 0>::run 的 helper,并在调用点复用。
  4. 最好补一个最小 RMSNorm 或等价 micro case 回归,覆盖:
    • 4-lane load/store
    • lane-local scratch 标量消费
    • sum allreduce + sync + broadcast

备注

这个 issue 先聚焦 PTODSL 侧的需求抽象,不要求一次把 TileLang->PTO 全链路都打通;但希望 PTODSL 的 surface / lowering 设计能和上述 RMSNorm 路径对齐,避免后面再为 TileLang 单独造一套旁路逻辑。

Metadata

Metadata

Assignees

No one assigned

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions