Skip to content

Commit a037a76

Browse files
jeffbolznvtinglou
authored andcommitted
vulkan: Implement "fast divide" (mul+shift) for unary ops like copy (ggml-org#10642)
1 parent 953665e commit a037a76

File tree

3 files changed

+66
-8
lines changed

3 files changed

+66
-8
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,45 @@ struct vk_op_unary_push_constants {
353353
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
354354
uint32_t d_offset;
355355
float param1; float param2;
356+
uint32_t ne0_012mp; uint32_t ne0_012L;
357+
uint32_t ne0_01mp; uint32_t ne0_01L;
358+
uint32_t ne0_0mp; uint32_t ne0_0L;
359+
uint32_t ne1_012mp; uint32_t ne1_012L;
360+
uint32_t ne1_01mp; uint32_t ne1_01L;
361+
uint32_t ne1_0mp; uint32_t ne1_0L;
356362
};
363+
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
364+
365+
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
366+
// Precompute mp (m' in the paper) and L such that division
367+
// can be computed using a multiply (high 32b of 64b result)
368+
// and a shift:
369+
//
370+
// n/d = (mulhi(n, mp) + n) >> L;
371+
void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
372+
{
373+
// compute L = ceil(log2(d));
374+
L = 0;
375+
while (L < 32 && (uint32_t{1} << L) < d) {
376+
L++;
377+
}
378+
379+
mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
380+
}
381+
382+
template <typename T> void init_pushconst_fastdiv(T &p) {
383+
static_assert(!std::is_const<T>::value, "unexpected type");
384+
}
385+
386+
template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
387+
// Compute magic values to divide by these six numbers.
388+
init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
389+
init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
390+
init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
391+
init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
392+
init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
393+
init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
394+
}
357395

358396
struct vk_op_binary_push_constants {
359397
uint32_t ne;
@@ -2914,13 +2952,14 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
29142952
elements = { ne, 1, 1 };
29152953
}
29162954

2917-
const vk_op_unary_push_constants pc = {
2955+
vk_op_unary_push_constants pc = {
29182956
(uint32_t)ne,
29192957
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
29202958
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
29212959
0,
29222960
0.0f, 0.0f,
29232961
};
2962+
init_pushconst_fastdiv(pc);
29242963
ggml_vk_sync_buffers(subctx);
29252964
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
29262965
}
@@ -4125,7 +4164,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
41254164
}
41264165

41274166
template<typename PC>
4128-
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
4167+
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
41294168
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
41304169
if (src1 != nullptr) {
41314170
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4165,6 +4204,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
41654204
const uint64_t ned3 = dst->ne[3];
41664205
const uint64_t ned = ned0 * ned1;
41674206

4207+
init_pushconst_fastdiv(pc);
4208+
41684209
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
41694210

41704211
if (pipeline == nullptr) {

ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ layout (push_constant) uniform parameter
88
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
99
uint d_offset;
1010
float param1; float param2;
11+
12+
uint ne0_012mp; uint ne0_012L;
13+
uint ne0_01mp; uint ne0_01L;
14+
uint ne0_0mp; uint ne0_0L;
15+
uint ne1_012mp; uint ne1_012L;
16+
uint ne1_01mp; uint ne1_01L;
17+
uint ne1_0mp; uint ne1_0L;
1118
} p;
1219

1320
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -17,22 +24,30 @@ uint get_idx() {
1724
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
1825
}
1926

27+
// see init_fastdiv_values in ggml-vulkan.cpp
28+
uint fastdiv(uint n, uint mp, uint L) {
29+
uint msbs, lsbs;
30+
// msbs = mulhi(n, mp)
31+
umulExtended(n, mp, msbs, lsbs);
32+
return (msbs + n) >> L;
33+
}
34+
2035
uint src0_idx(uint idx) {
21-
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
36+
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
2237
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
23-
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
38+
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
2439
const uint i02_offset = i02*p.ne01*p.ne00;
25-
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
40+
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
2641
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
2742
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
2843
}
2944

3045
uint dst_idx(uint idx) {
31-
const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
46+
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
3247
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
33-
const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
48+
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
3449
const uint i12_offset = i12*p.ne11*p.ne10;
35-
const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
50+
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
3651
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
3752
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
3853
}

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3862,6 +3862,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
38623862
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
38633863

38643864
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
3865+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
3866+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
38653867

38663868
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
38673869
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));

0 commit comments

Comments
 (0)