Skip to content

Commit fd05c51

Browse files
authored
vulkan: fix im2col overflowing maxworkgroupcount (#18180)
1 parent b365c3f commit fd05c51

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
12611261
int32_t s0; int32_t s1;
12621262
int32_t p0; int32_t p1;
12631263
int32_t d0; int32_t d1;
1264+
uint32_t batch_IC;
12641265
};
12651266

12661267
struct vk_op_im2col_3d_push_constants {
@@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
59025903
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
59035904
}
59045905
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
5906+
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
5907+
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
5908+
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
59055909
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
59065910
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
59075911
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
90909094
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
90919095

90929096
elements = { OW * KW * KH, OH, batch * IC };
9097+
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9098+
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
90939099
} break;
90949100
case GGML_OP_IM2COL_3D:
90959101
{
@@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
1060510611
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
1060610612

1060710613
const uint32_t pelements = OW * KW * KH;
10614+
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
1060810615

1060910616
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
1061010617
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
1061710624
IC, IW, IH, OW, OH, KW, KH,
1061810625
pelements,
1061910626
IC * KH * KW,
10620-
s0, s1, p0, p1, d0, d1,
10627+
s0, s1, p0, p1, d0, d1, batch * IC
1062110628
});
1062210629
}
1062310630

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
1919
int s0; int s1;
2020
int p0; int p1;
2121
int d0; int d1;
22+
uint batch_IC;
2223
} p;
2324

2425
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
3435
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
3536
#endif
3637

37-
void main() {
38+
void im2col(const uint y, const uint z) {
3839
const uint gidx = gl_GlobalInvocationID.x;
3940

40-
const uint oh = gl_GlobalInvocationID.y;
41-
const uint batch = gl_GlobalInvocationID.z / p.IC;
42-
const uint ic = gl_GlobalInvocationID.z % p.IC;
41+
const uint oh = y;
42+
const uint batch = z / p.IC;
43+
const uint ic = z % p.IC;
4344

4445
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
4546
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@@ -101,3 +102,15 @@ void main() {
101102
#endif
102103
}
103104
}
105+
106+
void main() {
107+
uint y = gl_GlobalInvocationID.y;
108+
while (y < p.OH) {
109+
uint z = gl_GlobalInvocationID.z;
110+
while (z < p.batch_IC) {
111+
im2col(y, z);
112+
z += gl_NumWorkGroups.z;
113+
}
114+
y += gl_NumWorkGroups.y;
115+
}
116+
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6930,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
69306930
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
69316931
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
69326932
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
6933+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
69336934

69346935
// im2col 3D
69356936
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));

0 commit comments

Comments
 (0)