Skip to content

Commit cb64222

Browse files
authored
vulkan: support GGML_UNARY_OP_XIELU (#18062)
1 parent 6eb7081 commit cb64222

File tree

4 files changed

+81
-13
lines changed

4 files changed

+81
-13
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,7 @@ struct vk_device_struct {
689689
vk_pipeline pipeline_gelu_quick[2];
690690
vk_pipeline pipeline_silu[2];
691691
vk_pipeline pipeline_relu[2];
692+
vk_pipeline pipeline_xielu[2];
692693
vk_pipeline pipeline_neg[2];
693694
vk_pipeline pipeline_tanh[2];
694695
vk_pipeline pipeline_sigmoid[2];
@@ -990,6 +991,8 @@ struct vk_op_push_constants {
990991
uint32_t KY;
991992
float param1;
992993
float param2;
994+
float param3;
995+
float param4;
993996
};
994997

995998
struct vk_op_glu_push_constants {
@@ -3973,6 +3976,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
39733976
CREATE_UNARY(gelu_quick)
39743977
CREATE_UNARY(silu)
39753978
CREATE_UNARY(relu)
3979+
CREATE_UNARY(xielu)
39763980
CREATE_UNARY(neg)
39773981
CREATE_UNARY(tanh)
39783982
CREATE_UNARY(sigmoid)
@@ -8549,6 +8553,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85498553
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
85508554
case GGML_UNARY_OP_RELU:
85518555
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
8556+
case GGML_UNARY_OP_XIELU:
8557+
return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
85528558
case GGML_UNARY_OP_NEG:
85538559
return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
85548560
case GGML_UNARY_OP_TANH:
@@ -9695,14 +9701,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
96959701

96969702
ggml_vk_op_f32_opt_step_adamw(
96979703
ctx, subctx, dst,
9698-
{ (uint32_t)n, 0, 0.0f, 0.0f }
9704+
{ (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
96999705
);
97009706
}
97019707

97029708
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
97039709
const size_t n = ggml_nelements(dst->src[0]);
97049710

9705-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
9711+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
97069712
}
97079713

97089714
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9788,6 +9794,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
97889794
1,
97899795
ggml_get_op_params_f32(dst, 0),
97909796
ggml_get_op_params_f32(dst, 2),
9797+
0.0f, 0.0f,
97919798
};
97929799

97939800
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9809,6 +9816,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
98099816
1,
98109817
ggml_get_op_params_f32(dst, 0),
98119818
0.0f,
9819+
0.0f, 0.0f,
98129820
};
98139821

98149822
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9924,13 +9932,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
99249932
}
99259933

99269934
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9927-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
9935+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
99289936
}
99299937

99309938
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
99319939
float * op_params = (float *)dst->op_params;
99329940

9933-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
9941+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
99349942
}
99359943

99369944
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9941,7 +9949,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
99419949
const float eps = float_op_params[1];
99429950
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
99439951

9944-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
9952+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
99459953
}
99469954

99479955
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -10110,16 +10118,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
1011010118

1011110119
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1011210120
float * op_params = (float *)dst->op_params;
10113-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10121+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
1011410122
}
1011510123

1011610124
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1011710125
float * op_params = (float *)dst->op_params;
10118-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10126+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
1011910127
}
1012010128

1012110129
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10122-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10130+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10131+
}
10132+
10133+
static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10134+
float * op_params = (float *)dst->op_params;
10135+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
10136+
{
10137+
(uint32_t)ggml_nelements(src0), 0,
10138+
op_params[1], op_params[2], op_params[3], op_params[4]
10139+
}
10140+
);
1012310141
}
1012410142

1012510143
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10244,7 +10262,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1024410262

1024510263
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1024610264
float * op_params = (float *)dst->op_params;
10247-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
10265+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
1024810266
}
1024910267

1025010268
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
@@ -10541,11 +10559,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co
1054110559
}
1054210560

1054310561
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10544-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
10562+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
1054510563
}
1054610564

1054710565
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10548-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10566+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
1054910567
}
1055010568

1055110569
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10804,7 +10822,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
1080410822

1080510823
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1080610824
const float * op_params = (const float *)dst->op_params;
10807-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
10825+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
1080810826
}
1080910827

1081010828
#ifdef GGML_VULKAN_RUN_TESTS
@@ -12050,6 +12068,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1205012068
case GGML_UNARY_OP_TRUNC:
1205112069
ggml_vk_unary(ctx, compute_ctx, src0, node);
1205212070
break;
12071+
case GGML_UNARY_OP_XIELU:
12072+
ggml_vk_xielu(ctx, compute_ctx, src0, node);
12073+
break;
1205312074
default:
1205412075
return false;
1205512076
}
@@ -13843,6 +13864,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1384313864
case GGML_UNARY_OP_GELU_QUICK:
1384413865
case GGML_UNARY_OP_SILU:
1384513866
case GGML_UNARY_OP_RELU:
13867+
case GGML_UNARY_OP_XIELU:
1384613868
case GGML_UNARY_OP_NEG:
1384713869
case GGML_UNARY_OP_TANH:
1384813870
case GGML_UNARY_OP_SIGMOID:
@@ -14748,7 +14770,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1474814770
} else if (tensor->op == GGML_OP_LOG) {
1474914771
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1475014772
} else if (tensor->op == GGML_OP_TRI) {
14751-
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
14773+
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
1475214774
} else if (tensor->op == GGML_OP_DIAG) {
1475314775
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
1475414776
} else if (tensor->op == GGML_OP_CLAMP) {
@@ -14836,6 +14858,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1483614858
case GGML_UNARY_OP_RELU:
1483714859
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
1483814860
break;
14861+
case GGML_UNARY_OP_XIELU:
14862+
tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
14863+
ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
14864+
ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
14865+
ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
14866+
ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
14867+
break;
1483914868
case GGML_UNARY_OP_NEG:
1484014869
tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
1484114870
break;

ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ layout (push_constant) uniform parameter
66
uint KY;
77
float param1;
88
float param2;
9+
float param3;
10+
float param4;
911
} p;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,8 @@ void process_shaders() {
853853
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
854854
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
855855
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
856+
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
857+
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
856858

857859
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
858860
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#version 450
2+
3+
#include "generic_head.glsl"
4+
#include "types.glsl"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
20+
float x = float(data_a[i]);
21+
22+
float alpha_n = p.param1;
23+
float alpha_p = p.param2;
24+
float beta = p.param3;
25+
float eps = p.param4;
26+
27+
if (x > 0.0f) {
28+
x = alpha_p * x * x + beta * x;
29+
} else {
30+
const float min_x_eps = min(x, eps);
31+
x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;
32+
}
33+
34+
data_d[i] = D_TYPE(x);
35+
}

0 commit comments

Comments
 (0)