@@ -353,7 +353,45 @@ struct vk_op_unary_push_constants {
353353 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
354354 uint32_t d_offset;
355355 float param1; float param2;
356+ uint32_t ne0_012mp; uint32_t ne0_012L;
357+ uint32_t ne0_01mp; uint32_t ne0_01L;
358+ uint32_t ne0_0mp; uint32_t ne0_0L;
359+ uint32_t ne1_012mp; uint32_t ne1_012L;
360+ uint32_t ne1_01mp; uint32_t ne1_01L;
361+ uint32_t ne1_0mp; uint32_t ne1_0L;
356362};
363+ static_assert (sizeof (vk_op_unary_push_constants) <= 128 , " sizeof(vk_op_unary_push_constants) must be <= 128" );
364+
// Precompute the magic multiplier `mp` (m' in the paper) and the shift `L`
// that allow an unsigned 32-bit division by the run-time constant `d` to be
// replaced by a multiply (keeping the high 32 bits of the 64-bit product)
// and a shift:
//
//     n / d == (mulhi(n, mp) + n) >> L
//
// The identity holds when the sum is evaluated without 32-bit wrap-around.
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
//
//   d  - divisor. A value of 0 is treated like 1 so that the host never
//        performs an integer division by zero (undefined behaviour).
//   mp - out: floor(2^32 * (2^L - d) / d) + 1.
//   L  - out: ceil(log2(d)), always in the range [0, 32].
void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
{
    // There is no meaningful quotient for a zero divisor; fall back to the
    // values for d == 1 (mp == 1, L == 0) instead of dividing by zero below.
    if (d == 0) {
        d = 1;
    }

    // compute L = ceil(log2(d)); the `L < 32` guard keeps the 32-bit shift
    // well defined when d > 2^31 (in which case L ends up as 32).
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }

    // 64-bit arithmetic: 2^32 * (2^L - d) stays below 2^63 for every 32-bit d.
    mp = static_cast<uint32_t>((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
}
381+
// Generic fastdiv hook for push-constant structs: does nothing at run time.
// Types that carry precomputed division constants provide an explicit
// specialization that fills them in.
template <typename T>
void init_pushconst_fastdiv(T &p) {
    // Reject const arguments at compile time: the hook exists so that
    // specializations can write into the struct they are given.
    static_assert(!std::is_const<T>::value, " unexpected type");
    (void)p; // nothing to precompute for this type
}
385+
386+ template <> void init_pushconst_fastdiv (vk_op_unary_push_constants &p) {
387+ // Compute magic values to divide by these six numbers.
388+ init_fastdiv_values (p.ne02 *p.ne01 *p.ne00 , p.ne0_012mp , p.ne0_012L );
389+ init_fastdiv_values (p.ne01 *p.ne00 , p.ne0_01mp , p.ne0_01L );
390+ init_fastdiv_values (p.ne00 , p.ne0_0mp , p.ne0_0L );
391+ init_fastdiv_values (p.ne12 *p.ne11 *p.ne10 , p.ne1_012mp , p.ne1_012L );
392+ init_fastdiv_values (p.ne11 *p.ne10 , p.ne1_01mp , p.ne1_01L );
393+ init_fastdiv_values (p.ne10 , p.ne1_0mp , p.ne1_0L );
394+ }
357395
358396struct vk_op_binary_push_constants {
359397 uint32_t ne;
@@ -2914,13 +2952,14 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
29142952 elements = { ne, 1 , 1 };
29152953 }
29162954
2917- const vk_op_unary_push_constants pc = {
2955+ vk_op_unary_push_constants pc = {
29182956 (uint32_t )ne,
29192957 (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ], (uint32_t )tensor->nb [0 ] / tensor_type_size, (uint32_t )tensor->nb [1 ] / tensor_type_size, (uint32_t )tensor->nb [2 ] / tensor_type_size, (uint32_t )tensor->nb [3 ] / tensor_type_size,
29202958 (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ], 1 , (uint32_t )tensor->ne [0 ] , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ]) , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ] * tensor->ne [2 ]),
29212959 0 ,
29222960 0 .0f , 0 .0f ,
29232961 };
2962+ init_pushconst_fastdiv (pc);
29242963 ggml_vk_sync_buffers (subctx);
29252964 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { in, out }, sizeof (vk_op_unary_push_constants), &pc, elements);
29262965}
@@ -4125,7 +4164,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
41254164}
41264165
41274166template <typename PC>
4128- static void ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false ) {
4167+ static void ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false ) {
41294168 VK_LOG_DEBUG (" ggml_vk_op_f32((" << src0 << " , name=" << src0->name << " , type=" << src0->type << " , ne0=" << src0->ne [0 ] << " , ne1=" << src0->ne [1 ] << " , ne2=" << src0->ne [2 ] << " , ne3=" << src0->ne [3 ] << " , nb0=" << src0->nb [0 ] << " , nb1=" << src0->nb [1 ] << " , nb2=" << src0->nb [2 ] << " , nb3=" << src0->nb [3 ];
41304169 if (src1 != nullptr ) {
41314170 std::cerr << " ), (" << src1 << " , name=" << src1->name << " , type=" << src1->type << " , ne0=" << src1->ne [0 ] << " , ne1=" << src1->ne [1 ] << " , ne2=" << src1->ne [2 ] << " , ne3=" << src1->ne [3 ] << " , nb0=" << src1->nb [0 ] << " , nb1=" << src1->nb [1 ] << " , nb2=" << src1->nb [2 ] << " , nb3=" << src1->nb [3 ];
@@ -4165,6 +4204,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
41654204 const uint64_t ned3 = dst->ne [3 ];
41664205 const uint64_t ned = ned0 * ned1;
41674206
4207+ init_pushconst_fastdiv (pc);
4208+
41684209 vk_pipeline pipeline = ggml_vk_op_get_pipeline (ctx, src0, src1, src2, dst, op);
41694210
41704211 if (pipeline == nullptr ) {
0 commit comments