Skip to content
Merged

Glm45 #336

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);

if (auto res = builder.try_find_regex(open_regex)) {
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";

Expand All @@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
builder.add_content(builder.consume_rest());
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
Expand All @@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_spaces();
}
}
builder.add_content(builder.consume_rest());
}
} else {
builder.add_content(builder.consume_rest());
}

builder.add_content(builder.consume_rest());
}

static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
Expand Down
3 changes: 3 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
res = "exaone4"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-8B
res = "qwen2"

if res is None:
logger.warning("\n")
Expand Down
327 changes: 233 additions & 94 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

88 changes: 53 additions & 35 deletions ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#version 450

#extension GL_EXT_control_flow_attributes : enable

#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif

#include "types.comp"

// Make spec constant
#define SHMEM_PAD 0

// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
Expand Down Expand Up @@ -56,6 +55,12 @@ layout(push_constant) uniform parameter {
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;

// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
}

p;
Expand All @@ -68,6 +73,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;

uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
Expand Down Expand Up @@ -131,6 +137,14 @@ uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;

// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}

void main() {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
Expand All @@ -151,9 +165,9 @@ void main() {
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = cached_CRS_remainder / p.KW;
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Expand All @@ -162,16 +176,16 @@ void main() {
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif

Expand All @@ -188,13 +202,13 @@ void main() {
Ash[B_ly * Ash_stride + B_lx] = val;
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = NPQ_remainder / p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;

uint32_t CRS_idx_b;
Expand All @@ -209,16 +223,16 @@ void main() {
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif

Expand All @@ -233,32 +247,36 @@ void main() {
Bsh[B_ly * Bsh_stride + B_lx] = val;
}
barrier();
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
barrier();
}
/* Save C* */
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}
Expand Down
18 changes: 11 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ layout (push_constant) uniform parameter
uint ne12;
uint b_offset;
uint d_offset;
uint nb03;
uint nb13;
uint nb23;
} p;

shared FLOAT_TYPE tmp[BLOCK_SIZE];
Expand All @@ -34,14 +37,15 @@ void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint i3 = gl_WorkGroupID.x;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;

const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;

const uint idst = channel*nrows_dst + row_dst;
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;

FLOAT_TYPE temp = 0.0f;

Expand All @@ -58,8 +62,8 @@ void main() {

const uint row_y = col_x;

const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;

const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
Expand All @@ -74,8 +78,8 @@ void main() {

const uint row_y = col_x;

const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;

const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
Expand All @@ -91,8 +95,8 @@ void main() {

const uint row_y = col_x;

const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;

const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);

Expand Down
7 changes: 5 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,11 @@ void process_shaders() {

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});

string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});

string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
Expand Down
25 changes: 25 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
Expand Down Expand Up @@ -678,6 +679,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4_moe",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
Expand Down Expand Up @@ -2124,6 +2126,29 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GLM4_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_K_NORM, # not always present
MODEL_TENSOR.ATTN_Q_NORM, # not always present
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B, # AKA "e_score_correction_bias" in transformers
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
Expand Down
Loading
Loading