diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 89539bd761..ab6bbdb2d9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); @@ -1461,10 +1462,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l int op_num = -1; switch (op->op) { - case GGML_OP_ADD: op_num = 0; break; - case GGML_OP_SUB: op_num = 1; break; - case GGML_OP_MUL: op_num = 2; break; - case GGML_OP_DIV: op_num = 3; break; + case GGML_OP_ADD: + case GGML_OP_ADD1: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 17d51b11b6..d5e08b3a25 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; @@ -1067,6 +1068,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_CONCAT: return true; case GGML_OP_ADD: + case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index eb2253e029..32b61ead49 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -124,6 +124,7 @@ #define OP_UNARY_NUM_CEIL 118 #define OP_UNARY_NUM_ROUND 119 #define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3cda21be43..7df66a7d91 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -268,6 +268,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_concat(ctx, idx); } break; case GGML_OP_ADD: + case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: @@ -787,6 +788,14 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { args.max = ggml_get_op_params_f32(op, 1); } + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + // reuse: slope = alpha_n, scale = alpha_p, bias = beta, val = eps + args.slope = ggml_get_op_params_f32(op, 1); + args.scale = ggml_get_op_params_f32(op, 2); + args.bias = ggml_get_op_params_f32(op, 3); + args.val = ggml_get_op_params_f32(op, 4); + } + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); if (pipeline.c4) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2074211594..2c74be552d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1110,6 +1110,14 @@ kernel void kernel_unary_impl( if (FC_OP == OP_UNARY_NUM_TRUNC) { dst_ptr[i0] = (T) trunc(x); } + + if (FC_OP == OP_UNARY_NUM_XIELU) { + // slope = alpha_n, scale = alpha_p, bias = beta, val = eps + const TC pos = args.scale * x * x + args.bias * x; + const TC min_x_eps = fmin(x, (TC) args.val); + const TC neg = (exp(min_x_eps) - 1 - x) * args.slope + args.bias * x; + dst_ptr[i0] = (T) select(neg, pos, x > (TC) 0); + } } #undef FC_OP