Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
Expand Down Expand Up @@ -1461,10 +1462,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
int op_num = -1;

switch (op->op) {
case GGML_OP_ADD: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
case GGML_OP_ADD:
case GGML_OP_ADD1: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};

Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_XIELU:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
Expand All @@ -1067,6 +1068,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_CONCAT:
return true;
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
#define OP_UNARY_NUM_CEIL 118
#define OP_UNARY_NUM_ROUND 119
#define OP_UNARY_NUM_TRUNC 120
#define OP_UNARY_NUM_XIELU 121

#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_concat(ctx, idx);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
Expand Down Expand Up @@ -787,6 +788,14 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
args.max = ggml_get_op_params_f32(op, 1);
}

if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
// reuse: slope = alpha_n, scale = alpha_p, bias = beta, val = eps
args.slope = ggml_get_op_params_f32(op, 1);
args.scale = ggml_get_op_params_f32(op, 2);
args.bias = ggml_get_op_params_f32(op, 3);
args.val = ggml_get_op_params_f32(op, 4);
}

auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);

if (pipeline.c4) {
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,14 @@ kernel void kernel_unary_impl(
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}

if (FC_OP == OP_UNARY_NUM_XIELU) {
// slope = alpha_n, scale = alpha_p, bias = beta, val = eps
const TC pos = args.scale * x * x + args.bias * x;
const TC min_x_eps = fmin(x, (TC) args.val);
const TC neg = (exp(min_x_eps) - 1 - x) * args.slope + args.bias * x;
dst_ptr[i0] = (T) select(neg, pos, x > (TC) 0);
}
}

#undef FC_OP
Expand Down
Loading