@@ -689,6 +689,7 @@ struct vk_device_struct {
689689 vk_pipeline pipeline_gelu_quick[2];
690690 vk_pipeline pipeline_silu[2];
691691 vk_pipeline pipeline_relu[2];
692+ vk_pipeline pipeline_xielu[2];
692693 vk_pipeline pipeline_neg[2];
693694 vk_pipeline pipeline_tanh[2];
694695 vk_pipeline pipeline_sigmoid[2];
@@ -990,6 +991,8 @@ struct vk_op_push_constants {
990991 uint32_t KY;
991992 float param1;
992993 float param2;
994+ float param3;
995+ float param4;
993996};
994997
995998struct vk_op_glu_push_constants {
@@ -3973,6 +3976,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
39733976 CREATE_UNARY(gelu_quick)
39743977 CREATE_UNARY(silu)
39753978 CREATE_UNARY(relu)
3979+ CREATE_UNARY(xielu)
39763980 CREATE_UNARY(neg)
39773981 CREATE_UNARY(tanh)
39783982 CREATE_UNARY(sigmoid)
@@ -8549,6 +8553,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85498553 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
85508554 case GGML_UNARY_OP_RELU:
85518555 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
8556+ case GGML_UNARY_OP_XIELU:
8557+ return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
85528558 case GGML_UNARY_OP_NEG:
85538559 return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
85548560 case GGML_UNARY_OP_TANH:
@@ -9695,14 +9701,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
96959701
96969702 ggml_vk_op_f32_opt_step_adamw(
96979703 ctx, subctx, dst,
9698- { (uint32_t)n, 0, 0.0f, 0.0f }
9704+ { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
96999705 );
97009706}
97019707
97029708static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
97039709 const size_t n = ggml_nelements(dst->src[0]);
97049710
9705- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
9711+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
97069712}
97079713
97089714static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9788,6 +9794,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
97889794 1,
97899795 ggml_get_op_params_f32(dst, 0),
97909796 ggml_get_op_params_f32(dst, 2),
9797+ 0.0f, 0.0f,
97919798 };
97929799
97939800 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9809,6 +9816,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
98099816 1,
98109817 ggml_get_op_params_f32(dst, 0),
98119818 0.0f,
9819+ 0.0f, 0.0f,
98129820 };
98139821
98149822 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9924,13 +9932,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
99249932}
99259933
99269934static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9927- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
9935+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
99289936}
99299937
99309938static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
99319939 float * op_params = (float *)dst->op_params;
99329940
9933- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
9941+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
99349942}
99359943
99369944static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9941,7 +9949,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
99419949 const float eps = float_op_params[1];
99429950 const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
99439951
9944- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
9952+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
99459953}
99469954
99479955static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -10110,16 +10118,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
1011010118
1011110119static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1011210120 float * op_params = (float *)dst->op_params;
10113- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10121+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
1011410122}
1011510123
1011610124static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1011710125 float * op_params = (float *)dst->op_params;
10118- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10126+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
1011910127}
1012010128
1012110129static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10122- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10130+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10131+ }
10132+
10133+ static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10134+ float * op_params = (float *)dst->op_params;
10135+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
10136+ {
10137+ (uint32_t)ggml_nelements(src0), 0,
10138+ op_params[1], op_params[2], op_params[3], op_params[4]
10139+ }
10140+ );
1012310141}
1012410142
1012510143static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10244,7 +10262,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1024410262
1024510263static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1024610264 float * op_params = (float *)dst->op_params;
10247- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
10265+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
1024810266}
1024910267
1025010268static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
@@ -10541,11 +10559,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co
1054110559}
1054210560
1054310561static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10544- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
10562+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
1054510563}
1054610564
1054710565static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10548- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10566+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
1054910567}
1055010568
1055110569static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10804,7 +10822,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
1080410822
1080510823static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
1080610824 const float * op_params = (const float *)dst->op_params;
10807- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
10825+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
1080810826}
1080910827
1081010828#ifdef GGML_VULKAN_RUN_TESTS
@@ -12050,6 +12068,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1205012068 case GGML_UNARY_OP_TRUNC:
1205112069 ggml_vk_unary(ctx, compute_ctx, src0, node);
1205212070 break;
12071+ case GGML_UNARY_OP_XIELU:
12072+ ggml_vk_xielu(ctx, compute_ctx, src0, node);
12073+ break;
1205312074 default:
1205412075 return false;
1205512076 }
@@ -13843,6 +13864,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1384313864 case GGML_UNARY_OP_GELU_QUICK:
1384413865 case GGML_UNARY_OP_SILU:
1384513866 case GGML_UNARY_OP_RELU:
13867+ case GGML_UNARY_OP_XIELU:
1384613868 case GGML_UNARY_OP_NEG:
1384713869 case GGML_UNARY_OP_TANH:
1384813870 case GGML_UNARY_OP_SIGMOID:
@@ -14748,7 +14770,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1474814770 } else if (tensor->op == GGML_OP_LOG) {
1474914771 tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1475014772 } else if (tensor->op == GGML_OP_TRI) {
14751- tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
14773+ tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type) ggml_get_op_params_i32(tensor, 0));
1475214774 } else if (tensor->op == GGML_OP_DIAG) {
1475314775 tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
1475414776 } else if (tensor->op == GGML_OP_CLAMP) {
@@ -14836,6 +14858,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1483614858 case GGML_UNARY_OP_RELU:
1483714859 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
1483814860 break;
14861+ case GGML_UNARY_OP_XIELU:
14862+ tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
14863+ ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
14864+ ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
14865+ ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
14866+ ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
14867+ break;
1483914868 case GGML_UNARY_OP_NEG:
1484014869 tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
1484114870 break;
0 commit comments