Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docker/rocm.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,14 @@ RUN pip uninstall -y aiter
# block switching to commits that predate that rule (e.g. the current default
# AITER_COMMIT_DEFAULT). The working tree was just produced by a fresh
# `git clone` above, so there are no real user changes to preserve.
COPY scripts/ci/amd/patch_aiter_gluon_pa_mqa_logits.py /tmp/patch_aiter_gluon_pa_mqa_logits.py

RUN git clone ${AITER_REPO} \
&& cd aiter \
&& git checkout -f ${AITER_COMMIT} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt
&& pip install -r requirements.txt \
&& python3 /tmp/patch_aiter_gluon_pa_mqa_logits.py /sgl-workspace/aiter

RUN cd aiter \
&& echo "[AITER] GPU_ARCH=${GPU_ARCH}" \
Expand Down
5 changes: 5 additions & 0 deletions docs_new/cookbook/autoregressive/GLM/GLM-5.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ sglang serve \
The following ROCm command is an additional option for AMD GPUs and does not replace the NVIDIA instructions above.

```shell Command
SGLANG_ROCM_FUSED_DECODE_MLA=0 \
ROCM_QUICK_REDUCE_QUANTIZATION=INT4 \
SAFETENSORS_FAST_GPU=1 \
sglang serve \
--model-path zai-org/GLM-5 \
--tp 8 \
Expand All @@ -133,6 +136,8 @@ sglang serve \
--port 30000
```

GLM-5 uses DSA (not MLA): set `SGLANG_ROCM_FUSED_DECODE_MLA=0` on ROCm. `ROCM_QUICK_REDUCE_QUANTIZATION=INT4` and `SAFETENSORS_FAST_GPU=1` speed weight load on large checkpoints.

### 4.2 Basic Usage

For basic API usage and request examples, please refer to:
Expand Down
6 changes: 5 additions & 1 deletion docs_new/src/snippets/autoregressive/glm-5-deployment.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ export const GLM5Deployment = () => {
const tpValue = hwConfig.tp;
const memFraction = hwConfig.mem;

let cmd = 'sglang serve \\\n';
let cmd = '';
if (isAMD) {
cmd += 'SGLANG_ROCM_FUSED_DECODE_MLA=0 ROCM_QUICK_REDUCE_QUANTIZATION=INT4 SAFETENSORS_FAST_GPU=1 ';
}
cmd += 'sglang serve \\\n';
cmd += ` --model-path ${modelName}`;
cmd += ` \\\n --tp ${tpValue}`;

Expand Down
8 changes: 8 additions & 0 deletions scripts/ci/amd/amd_ci_install_dependency.sh
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@ if [[ "${NEED_REBUILD}" == "true" ]]; then
echo "[CI-AITER-CHECK] === AITER REBUILD COMPLETE ==="
fi

# Apply gluon hotpatch for pre-installed AITER in CI images (idempotent; no-op on fixed aiter).
if docker exec ci_sglang test -d /sgl-workspace/aiter 2>/dev/null; then
echo "[CI-AITER-CHECK] Applying gluon pa_mqa_logits instr_shape hotpatch (idempotent)..."
docker exec ci_sglang python3 \
/sglang-checkout/scripts/ci/amd/patch_aiter_gluon_pa_mqa_logits.py \
/sgl-workspace/aiter || true
fi

echo "[CI-AITER-CHECK] === AITER VERSION CHECK END ==="


Expand Down
96 changes: 96 additions & 0 deletions scripts/ci/amd/patch_aiter_gluon_pa_mqa_logits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Idempotently patch aiter gluon pa_mqa_logits for Triton >= 3.5 MFMA instr_shape.

The base _gluon_deepgemm_fp8_paged_mqa_logits variant hardcoded 2D instr_shape
[16, 16] on older aiter commits. Triton 3.5+ requires 3D [16, 16, 32] when
_Use_2d_instr_shape_mfma_layout is false (GLM-5 NSA / deepgemm path).

Upstream fix: ROCm/aiter a1bdcec (#2575). This hotpatch remains for ROCm images
that still ship an older vendored aiter until images are rebuilt.

Usage:
python3 patch_aiter_gluon_pa_mqa_logits.py [AITER_REPO_ROOT]
Default AITER_REPO_ROOT: /sgl-workspace/aiter
"""

from __future__ import annotations

import os
import sys

_SENTINEL = "[PATCHED] 3D instr_shape for base gluon variant"

_OLD = """\
mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
version=CDNA_VERSION,
instr_shape=[16, 16],
transposed=False,
warps_per_cta=[1, NumWarps],
)
mfma_layout_a: gl.constexpr = gl.DotOperandLayout(
operand_index=0, parent=mfma_layout, k_width=16
)
mfma_layout_b: gl.constexpr = gl.DotOperandLayout(
operand_index=1, parent=mfma_layout, k_width=16
)"""

_NEW = """\
# [PATCHED] 3D instr_shape for base gluon variant
if _Use_2d_instr_shape_mfma_layout:
mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
version=CDNA_VERSION,
instr_shape=[16, 16],
transposed=False,
warps_per_cta=[1, NumWarps],
)
else:
mfma_layout: gl.constexpr = gl.amd.AMDMFMALayout(
version=CDNA_VERSION,
instr_shape=[16, 16, 32],
transposed=False,
warps_per_cta=[1, NumWarps],
)
mfma_layout_a: gl.constexpr = gl.DotOperandLayout(
operand_index=0, parent=mfma_layout, k_width=16
)
mfma_layout_b: gl.constexpr = gl.DotOperandLayout(
operand_index=1, parent=mfma_layout, k_width=16
)"""


def patch_gluon_pa_mqa_logits(aiter_root: str) -> bool:
target = os.path.join(
aiter_root, "aiter", "ops", "triton", "gluon", "pa_mqa_logits.py"
)
if not os.path.isfile(target):
print(f"[aiter-hotpatch] {target} not found, skipping")
return False

src = open(target, encoding="utf-8").read()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is recommended to use a with statement when opening files to ensure that file descriptors are closed properly and promptly, rather than relying on garbage collection.

Suggested change
src = open(target, encoding="utf-8").read()
with open(target, encoding="utf-8") as f:
src = f.read()

if _SENTINEL in src:
print("[aiter-hotpatch] gluon pa_mqa_logits 3D instr_shape already applied")
return False

if _OLD not in src:
print(
"[aiter-hotpatch] WARN: gluon pa_mqa_logits pattern not found "
"(aiter may already include ROCm/aiter#2575)"
)
return False

new_src = src.replace(_OLD, _NEW, 1)
with open(target, "w", encoding="utf-8") as f:
f.write(new_src)
print("[aiter-hotpatch] Patched gluon pa_mqa_logits 3D instr_shape (base variant)")
return True


def main() -> int:
aiter_root = sys.argv[1] if len(sys.argv) > 1 else "/sgl-workspace/aiter"
patch_gluon_pa_mqa_logits(aiter_root)
return 0


if __name__ == "__main__":
sys.exit(main())
Loading