
Commit 1fa4551

vulkan: support larger argsort (ggml-org#17313)
* vulkan: support larger argsort

  This is an extension of the original bitonic sorting shader that puts the temporary values in global memory, and when more than 1024 threads are needed it runs multiple workgroups and synchronizes through a pipeline barrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared memory version of the shader, which is still used when ncols <= 1024.

* Reduce the number of shader variants. Use smaller workgroups when doing a single pass, for a modest perf boost

* reduce loop overhead

* run multiple cols per invocation, to reduce barrier overhead
1 parent 2eba631 commit 1fa4551
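
For readers unfamiliar with the pass structure being split here, the following is a minimal CPU sketch (not part of the patch; names like data and idx are illustrative, and padding, descending order, and bounds handling are simplified) of the bitonic argsort the shader implements. The comments mark which passes can stay inside one workgroup and which ones now need a separate dispatch and pipeline barrier.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<float> data = {5.f, 1.f, 4.f, 2.f, 8.f, 7.f, 3.f, 6.f}; // ncols padded to a power of two
    const uint32_t n = data.size();
    std::vector<uint32_t> idx(n);
    for (uint32_t i = 0; i < n; ++i) idx[i] = i;

    // Outer passes k = 2,4,...,n; inner passes j = k/2,...,1.
    // On the GPU, passes with stride j smaller than the workgroup size only
    // exchange data inside a workgroup (a barrier() suffices); larger j needs
    // its own dispatch and a pipeline barrier, which is what the "large"
    // shader variant in this commit adds.
    for (uint32_t k = 2; k <= n; k *= 2) {
        for (uint32_t j = k / 2; j > 0; j /= 2) {
            for (uint32_t col = 0; col < n; ++col) {
                const uint32_t ixj = col ^ j;
                if (ixj <= col) continue;
                const bool ascending_block = (col & k) == 0;
                if ((data[idx[col]] > data[idx[ixj]]) == ascending_block) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }

    for (uint32_t i : idx) std::printf("%u ", i); // indices of values in ascending order
    std::printf("\n");
}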

File tree

5 files changed, +257 -48 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 118 additions & 25 deletions
@@ -406,8 +406,8 @@ enum shader_reduction_mode {
     SHADER_REDUCTION_MODE_COUNT,
 };
 
+// argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
-static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
 static constexpr uint32_t num_topk_moe_pipelines = 10;
 
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
@@ -526,6 +526,7 @@ struct vk_device_struct {
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
+    bool vulkan_memory_model;
 
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
@@ -539,6 +540,9 @@
     uint32_t subgroup_max_size;
     bool subgroup_require_full_support;
 
+    // floor(log2(maxComputeWorkGroupInvocations))
+    uint32_t max_workgroup_size_log2 {};
+
     bool coopmat_support;
     bool coopmat_acc_f32_support {};
     bool coopmat_acc_f16_support {};
@@ -684,6 +688,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
@@ -1174,8 +1179,14 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t ncols_padded;
+    uint32_t ncols_padded_log2;
     uint32_t nrows;
-    int32_t order;
+    uint32_t order;
+    uint32_t outer_start;
+    uint32_t outer_end;
+    uint32_t inner_start;
+    uint32_t inner_end;
 };
 
 struct vk_op_im2col_push_constants {
@@ -3895,7 +3906,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
+        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
+        if (i <= device->max_workgroup_size_log2 &&
+            2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+            const uint32_t NCOLS_PADDED_LOG2 = i;
+            ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+        }
+        const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
+        BLOCK_SIZE /= WG_UNROLL_FACTOR;
+        ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
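
For a sense of scale, here is a small standalone sketch of how the gate above decides whether the shared-memory variant can be built for each of the 11 pipeline sizes. It is not from the patch; the device limits (1024 max invocations, 32 KiB shared memory) and the names max_wg_log2 and max_shared_mem are assumptions for illustration. One ivec2 (8 bytes) is kept per element, which is what the 2 * sizeof(int) * BLOCK_SIZE check covers.

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t max_wg_log2    = 10;        // assume maxComputeWorkGroupInvocations == 1024
    const uint32_t max_shared_mem = 32 * 1024; // assume maxComputeSharedMemorySize == 32 KiB
    for (uint32_t i = 0; i < 11; ++i) {        // num_argsort_pipelines == 11
        const uint32_t block_size = 1u << (i < max_wg_log2 ? i : max_wg_log2);
        const uint32_t shmem      = 2 * sizeof(int32_t) * block_size; // one ivec2 per element
        const bool small_ok = (i <= max_wg_log2) && (shmem <= max_shared_mem);
        std::printf("i=%2u BLOCK_SIZE=%4u shared=%5u bytes -> %s\n",
                    i, block_size, shmem,
                    small_ok ? "small shader" : "large shader only");
    }
}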
@@ -4296,6 +4315,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
 
+    device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
+
     std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
     // Try to find a non-graphics compute queue and transfer-focused queues
@@ -4435,6 +4456,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->shader_int64 = device_features2.features.shaderInt64;
     device->buffer_device_address = vk12_features.bufferDeviceAddress;
+    device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
 
     if (device->subgroup_size_control) {
         device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -8359,19 +8381,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        }
        return nullptr;
    }
-    case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
-            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
-            GGML_ASSERT(idx < num_topk_moe_pipelines);
-            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
-            return ctx->device->pipeline_topk_moe[idx][mode];
-        }
-
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
-            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
-            return ctx->device->pipeline_argsort_f32[idx];
-        }
-        return nullptr;
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
@@ -8763,8 +8772,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
        elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
        break;
    case GGML_OP_ARGSORT:
-        elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
-        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+        GGML_ASSERT(0);
        break;
    case GGML_OP_IM2COL:
        {
@@ -9891,16 +9899,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    int32_t * op_params = (int32_t *)dst->op_params;
+    const uint32_t * op_params = (const uint32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
     uint32_t nrows = ggml_nrows(src0);
 
-    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
-        ncols,
-        nrows,
-        op_params[0],
-    });
+    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
+    uint32_t ncolsp2 = 1 << ncols_pad_log2;
+
+    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
+
+    // Pick the largest workgroup size <= ncolsp2
+    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
+
+    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
+    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
+                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
+
+    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
+                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
+
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer subbuf1 = dst_buf;
+
+    // Reserve space for ivec2 per element, with rows padded to a power of two
+    if (!use_small) {
+        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
+
+        if (ctx->prealloc_size_x < x_sz) {
+            ctx->prealloc_size_x = x_sz;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = ncolsp2;
+    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    // First dispatch initializes tmp_idx and does the first N passes where
+    // there is only communication between threads in the same workgroup.
+    {
+        vk_op_argsort_push_constants pc2 = pc;
+        pc2.outer_start = 0;
+        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
+        pc2.inner_start = 0;
+        pc2.inner_end = 100;
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
+    }
+    if (!use_small) {
+        ggml_vk_sync_buffers(ctx, subctx);
+        // Loop over outer/inner passes, synchronizing between each pass.
+        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
+            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
+                vk_op_argsort_push_constants pc2 = pc;
+                pc2.outer_start = outer;
+                pc2.outer_end = outer + 1;
+                pc2.inner_start = inner;
+                pc2.inner_end = inner + 1;
+                // When the inner idx is large enough, there's only communication
+                // within a workgroup. So the remaining inner iterations can all
+                // run in the same dispatch.
+                if (outer - inner < pipeline_idx) {
+                    pc2.inner_end = 100;
+                    inner = outer;
+                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
+                } else {
+                    // Smaller workgroup empirically seems to perform better
+                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
+                }
+                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
+                ggml_vk_sync_buffers(ctx, subctx);
            }
        }
+        ctx->prealloc_x_need_sync = true;
+    }
 }
 
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
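
To see how the splitting above translates into dispatch counts, here is a standalone sketch that mirrors the outer/inner loop and counts how many dispatches (and therefore pipeline barriers) a single sort requires. It is not from the patch: count_dispatches is a hypothetical helper, and it assumes the pipeline's workgroup log2 equals the device maximum (wg_log2), which is the common case when ncols exceeds the workgroup size.

#include <cstdint>
#include <cstdio>

// Count dispatches the way the host loop above splits bitonic passes:
// one dispatch covers every pass that only exchanges data inside a workgroup;
// each cross-workgroup pass gets its own dispatch with a barrier after it.
static uint32_t count_dispatches(uint32_t ncols_pad_log2, uint32_t wg_log2) {
    uint32_t dispatches = 1; // first dispatch: outer passes 0..min(ncols_pad_log2, wg_log2)
    for (uint32_t outer = wg_log2; outer < ncols_pad_log2; ++outer) {
        for (uint32_t inner = 0; inner < outer + 1; ++inner) {
            ++dispatches;
            if (outer - inner < wg_log2) { // remaining inner passes fit in one dispatch
                inner = outer;
            }
        }
    }
    return dispatches;
}

int main() {
    // Illustrative values: 1024-invocation workgroups (wg_log2 = 10).
    for (uint32_t log2n : {10u, 12u, 14u, 16u}) {
        std::printf("ncols=2^%u -> %u dispatches\n", log2n, count_dispatches(log2n, 10));
    }
}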
@@ -13721,7 +13802,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_LOG:
            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
        case GGML_OP_ARGSORT:
-            return op->ne[0] <= max_argsort_cols;
+            {
+                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
+                    return false;
+                }
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                auto device = ggml_vk_get_device(ctx->device);
+                // pipeline_argsort_large_f32 requires vulkan memory model.
+                if (device->vulkan_memory_model) {
+                    return true;
+                } else {
+                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);
+                }
+            }
        case GGML_OP_UPSCALE:
        case GGML_OP_ACC:
        case GGML_OP_CONCAT:

ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp

Lines changed: 20 additions & 21 deletions
@@ -4,28 +4,27 @@
 #include "types.glsl"
 
 layout(constant_id = 0) const int BLOCK_SIZE = 1024;
-layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
+layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
 #define ASC 0
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) buffer D {int data_d[];};
+layout (binding = 2) writeonly buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint ncols_padded;
+    uint ncols_padded_log2;
     uint nrows;
     uint order;
+    uint outer_start;
+    uint outer_end;
+    uint inner_start;
+    uint inner_end;
 } p;
 
-shared int dst_row[BLOCK_SIZE];
-shared A_TYPE a_sh[BLOCK_SIZE];
-
-void swap(uint idx0, uint idx1) {
-    int tmp = dst_row[idx0];
-    dst_row[idx0] = dst_row[idx1];
-    dst_row[idx1] = tmp;
-}
+shared ivec2 dst_row[BLOCK_SIZE];
 
 void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
@@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) {
     const uint row_offset = row * p.ncols;
 
     // initialize indices
-    dst_row[col] = col;
-    a_sh[col] = data_a[row_offset + col];
+    dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col]));
     barrier();
 
-    uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
+    uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
     [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
         uint num_inner_loop_iters = outer_idx + 1;
         [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
@@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) {
            int idx_0 = (col & k) == 0 ? col : ixj;
            int idx_1 = (col & k) == 0 ? ixj : col;
 
-            int sh_idx_0 = dst_row[idx_0];
-            int sh_idx_1 = dst_row[idx_1];
-            bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
-            bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
+            ivec2 sh_idx_0 = dst_row[idx_0];
+            ivec2 sh_idx_1 = dst_row[idx_1];
+            bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
+            bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
 
            if ((idx_0_oob ||
-                (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
-                swap(idx_0, idx_1);
+                (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
+                dst_row[idx_0] = sh_idx_1;
+                dst_row[idx_1] = sh_idx_0;
            }
 
            barrier();
@@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) {
 
    if (col < p.ncols) {
        if (p.order == ASC) {
-            data_d[row_offset + col] = dst_row[col];
+            data_d[row_offset + col] = dst_row[col].x;
        } else {
-            data_d[row_offset + p.ncols - col - 1] = dst_row[col];
+            data_d[row_offset + p.ncols - col - 1] = dst_row[col].x;
        }
    }
 }
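
As a side note, the ivec2 packing used above simply keeps a bit-cast copy of the sort key next to its index, so each compare-and-swap touches one shared-memory slot instead of chasing the index into a second array. A rough CPU analogue of the same idea (illustrative only, not part of the patch; std::bit_cast plays the role of floatBitsToInt/intBitsToFloat) looks like this:

#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> data = {0.5f, -1.0f, 3.25f, 2.0f};

    // Pack (index, bit-cast value) together, mirroring `shared ivec2 dst_row[]`.
    std::vector<std::pair<int32_t, int32_t>> row(data.size());
    for (size_t i = 0; i < data.size(); ++i) {
        row[i] = { int32_t(i), std::bit_cast<int32_t>(data[i]) };
    }

    // Compare by bit-casting the stored value back to float, like intBitsToFloat().
    std::sort(row.begin(), row.end(), [](const auto & a, const auto & b) {
        return std::bit_cast<float>(a.second) < std::bit_cast<float>(b.second);
    });

    for (const auto & [idx, bits] : row) {
        std::printf("%d ", idx); // indices of values in ascending order
    }
    std::printf("\n");
}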
