@@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
12611261 int32_t s0; int32_t s1;
12621262 int32_t p0; int32_t p1;
12631263 int32_t d0; int32_t d1;
1264+ uint32_t batch_IC;
12641265};
12651266
12661267struct vk_op_im2col_3d_push_constants {
@@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
59025903 std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
59035904 }
59045905 std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
5906+ GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
5907+ wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
5908+ wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
59055909 GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
59065910 GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
59075911 GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
90909094 const uint32_t batch = src1->ne[is_2D ? 3 : 2];
90919095
90929096 elements = { OW * KW * KH, OH, batch * IC };
9097+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9098+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
90939099 } break;
90949100 case GGML_OP_IM2COL_3D:
90959101 {
@@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
1060510611 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
1060610612
1060710613 const uint32_t pelements = OW * KW * KH;
10614+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
1060810615
1060910616 const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
1061010617 const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
1061710624 IC, IW, IH, OW, OH, KW, KH,
1061810625 pelements,
1061910626 IC * KH * KW,
10620- s0, s1, p0, p1, d0, d1,
10627+ s0, s1, p0, p1, d0, d1, batch * IC
1062110628 });
1062210629}
1062310630
0 commit comments