@@ -406,8 +406,8 @@ enum shader_reduction_mode {
406406 SHADER_REDUCTION_MODE_COUNT,
407407};
408408
// argsort pipelines for up to 1<<10 invocations per workgroup
// (pipeline index i is sized for 1<<i elements, so 11 pipelines cover 1..1024)
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t num_topk_moe_pipelines = 10;
412412
413413static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
@@ -526,6 +526,7 @@ struct vk_device_struct {
526526 bool multi_add;
527527 bool shader_int64;
528528 bool buffer_device_address;
529+ bool vulkan_memory_model;
529530
530531 bool add_rms_fusion;
531532 uint32_t partials_binding_alignment;
@@ -539,6 +540,9 @@ struct vk_device_struct {
539540 uint32_t subgroup_max_size;
540541 bool subgroup_require_full_support;
541542
543+ // floor(log2(maxComputeWorkGroupInvocations))
544+ uint32_t max_workgroup_size_log2 {};
545+
542546 bool coopmat_support;
543547 bool coopmat_acc_f32_support {};
544548 bool coopmat_acc_f16_support {};
@@ -684,6 +688,7 @@ struct vk_device_struct {
684688 vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
685689 vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
686690 vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
691+ vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
687692 vk_pipeline pipeline_sum_rows_f32;
688693 vk_pipeline pipeline_argmax_f32;
689694 vk_pipeline pipeline_count_equal_i32;
@@ -1174,8 +1179,14 @@ struct vk_op_soft_max_push_constants {
11741179
// Push constants for the argsort shaders (both the single-workgroup "small"
// variant and the multi-dispatch "large" variant share this layout).
struct vk_op_argsort_push_constants {
    uint32_t ncols;             // actual number of columns in each row
    uint32_t ncols_padded;      // ncols rounded up to the next power of two
    uint32_t ncols_padded_log2; // log2(ncols_padded)
    uint32_t nrows;             // number of independent rows to sort
    uint32_t order;             // sort order flag, taken from dst->op_params[0]
    // Range [outer_start, outer_end) x [inner_start, inner_end) of sorting-network
    // passes to execute in one dispatch. A dispatch may span several passes when
    // no cross-workgroup synchronization is needed between them.
    uint32_t outer_start;
    uint32_t outer_end;
    uint32_t inner_start;
    uint32_t inner_end;
};
11801191
11811192struct vk_op_im2col_push_constants {
@@ -3895,7 +3906,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
38953906 }
38963907
38973908 for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3898- ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3909+ uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
3910+ if (i <= device->max_workgroup_size_log2 &&
3911+ 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
3912+ const uint32_t NCOLS_PADDED_LOG2 = i;
3913+ ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
3914+ }
3915+ const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
3916+ BLOCK_SIZE /= WG_UNROLL_FACTOR;
3917+ ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
38993918 }
39003919
39013920 ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4296,6 +4315,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
42964315
42974316 device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
42984317
4318+ device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
4319+
42994320 std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
43004321
43014322 // Try to find a non-graphics compute queue and transfer-focused queues
@@ -4435,6 +4456,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
44354456
44364457 device->shader_int64 = device_features2.features.shaderInt64;
44374458 device->buffer_device_address = vk12_features.bufferDeviceAddress;
4459+ device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
44384460
44394461 if (device->subgroup_size_control) {
44404462 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -8359,19 +8381,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
83598381 }
83608382 return nullptr;
83618383 }
8362- case GGML_OP_ARGSORT:
8363- if (ctx->num_additional_fused_ops) {
8364- uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8365- GGML_ASSERT(idx < num_topk_moe_pipelines);
8366- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8367- return ctx->device->pipeline_topk_moe[idx][mode];
8368- }
8369-
8370- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
8371- uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8372- return ctx->device->pipeline_argsort_f32[idx];
8373- }
8374- return nullptr;
83758384 case GGML_OP_SUM:
83768385 case GGML_OP_SUM_ROWS:
83778386 case GGML_OP_MEAN:
@@ -8763,8 +8772,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
87638772 elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
87648773 break;
87658774 case GGML_OP_ARGSORT:
8766- elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
8767- elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
8775+ GGML_ASSERT(0);
87688776 break;
87698777 case GGML_OP_IM2COL:
87708778 {
@@ -9891,16 +9899,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
98919899}
98929900
// Record commands to argsort each row of src0 (f32) into dst (i32 indices).
// Columns are padded to a power of two and sorted with a comparison network
// driven by (outer, inner) pass indices. When a row fits in one workgroup the
// whole sort runs in a single "small" dispatch; otherwise a sequence of
// "large" dispatches is issued with buffer barriers between passes that
// communicate across workgroups.
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    const uint32_t * op_params = (const uint32_t *)dst->op_params;

    uint32_t ncols = src0->ne[0];
    uint32_t nrows = ggml_nrows(src0);

    // Pad the column count up to a power of two; the sort operates on the padded width.
    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
    uint32_t ncolsp2 = 1 << ncols_pad_log2;

    // outer/inner pass ranges (the trailing four zeros) are filled in per-dispatch below.
    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };

    // Pick the largest workgroup size <= ncolsp2
    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);

    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
    // (The small pipeline may be null when shared memory was insufficient at creation.)
    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;

    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];

    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer subbuf1 = dst_buf;

    // Reserve space for ivec2 per element, with columns padded to a power of two.
    // Only the "large" path needs this scratch buffer, to carry per-element state
    // across dispatches.
    if (!use_small) {
        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);

        if (ctx->prealloc_size_x < x_sz) {
            ctx->prealloc_size_x = x_sz;
            ggml_vk_preallocate_buffers(ctx, subctx);
        }
        if (ctx->prealloc_x_need_sync) {
            // A previous op wrote prealloc_x; barrier before reusing it.
            ggml_vk_sync_buffers(ctx, subctx);
        }
        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
    }

    std::array<uint32_t, 3> elements;

    elements[0] = ncolsp2;
    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
    elements[2] = 1;

    // First dispatch initializes tmp_idx and does the first N passes where
    // there is only communication between threads in the same workgroup.
    {
        vk_op_argsort_push_constants pc2 = pc;
        pc2.outer_start = 0;
        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
        pc2.inner_start = 0;
        pc2.inner_end = 100; // sentinel > any real inner index: "run all inner passes"
        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
    }
    if (!use_small) {
        ggml_vk_sync_buffers(ctx, subctx);
        // Loop over outer/inner passes, synchronizing between each pass.
        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
                vk_op_argsort_push_constants pc2 = pc;
                pc2.outer_start = outer;
                pc2.outer_end = outer + 1;
                pc2.inner_start = inner;
                pc2.inner_end = inner + 1;
                // When the inner idx is large enough, there's only communication
                // within a workgroup. So the remaining inner iterations can all
                // run in the same dispatch.
                if (outer - inner < pipeline_idx) {
                    pc2.inner_end = 100; // sentinel: run through the final inner pass
                    inner = outer;       // this dispatch finishes the inner loop
                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
                } else {
                    // Smaller workgroup empirically seems to perform better
                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
                }
                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
                ggml_vk_sync_buffers(ctx, subctx);
            }
        }
        // Scratch buffer was written; force a barrier before its next reuse.
        ctx->prealloc_x_need_sync = true;
    }
}
99059986
99069987static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13721,7 +13802,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1372113802 case GGML_OP_LOG:
1372213803 return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
        case GGML_OP_ARGSORT:
            {
                // Only contiguous src/dst are handled (NOTE(review): the shaders
                // presumably index rows densely — confirm against the SPIR-V side).
                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                    return false;
                }
                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                auto device = ggml_vk_get_device(ctx->device);
                // pipeline_argsort_large_f32 requires vulkan memory model.
                if (device->vulkan_memory_model) {
                    return true;
                } else {
                    // Without the memory model only the single-workgroup shader is
                    // usable, limiting the row width to one workgroup's invocations.
                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);
                }
            }
1372513818 case GGML_OP_UPSCALE:
1372613819 case GGML_OP_ACC:
1372713820 case GGML_OP_CONCAT:
0 commit comments