@@ -13820,12 +13820,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
1382013820 return true;
1382113821}
1382213822
13823- // Check whether the tensors overlap in memory but are not equal.
13824- // Fusions can potenitally overwrite src tensors in ways that are not prevented
13825- // by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
13826- // to overlap if they are exactly equal.
13827- // XXX TODO this check is probably missing from several fusion optimizations.
13828- static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
13823+ // Check whether the tensors overlap in memory.
13824+ // Fusions can potentially overwrite src tensors in ways that are not prevented
13825+ // by ggml-alloc. If the fusion src is being applied in a way that's elementwise
13826+ // with the destination, then it's OK for them to overlap if they are exactly equal.
13827+ static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
1382913828 ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
1383013829 vk_buffer a_buf = a_buf_ctx->dev_buffer;
1383113830 ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
@@ -13836,7 +13835,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
1383613835 auto b_base = vk_tensor_offset(b) + b->view_offs;
1383713836 auto b_size = ggml_nbytes(b);
1383813837
13839- if (a_base == b_base && a_size == b_size) {
13838+ if (elementwise && a_base == b_base && a_size == b_size) {
1384013839 return false;
1384113840 }
1384213841
@@ -13874,13 +13873,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
1387413873 return false;
1387513874 }
1387613875
13877- // must not overwrite srcs in a way that's not elementwise
13878- ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
13879- if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
13880- ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
13881- return false;
13882- }
13883-
1388413876 // conditions for pipeline creation
1388513877 if (!(ctx->device->float_controls_rte_fp16 &&
1388613878 sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
@@ -13942,6 +13934,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
1394213934 return num_adds;
1394313935}
1394413936
13937+ static int32_t find_first_set(uint32_t x) { // Returns the zero-based index of the lowest set bit of x, or -1 if x == 0.
13938+ int32_t ret = 0; // counts the low-order zero bits shifted out so far
13939+ if (!x) { // no bit set: return early, since the loop below would never terminate for 0
13940+ return -1;
13941+ }
13942+ while (!(x & 1)) { // shift right until the lowest set bit reaches bit 0
13943+ x >>= 1;
13944+ ret++;
13945+ }
13946+ return ret; // number of shifts performed == index of the lowest set bit
13947+ }
13948+
1394513949static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
1394613950 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
1394713951 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -14040,6 +14044,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1404014044 total_mul_mat_bytes += bytes;
1404114045 }
1404214046
14047+ // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
14048+ // the fused result in an elementwise way. This affects whether the memory for
14049+ // the src is allowed to overlap the memory for the destination.
14050+ // The array is sized to handle the largest fusion (asserted later).
14051+ bool op_srcs_fused_elementwise[12];
14052+
1404314053 ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
1404414054 ctx->fused_topk_moe_scale = false;
1404514055 const char *fusion_string {};
@@ -14048,39 +14058,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1404814058 if (num_adds) {
1404914059 ctx->num_additional_fused_ops = num_adds - 1;
1405014060 fusion_string = "MULTI_ADD";
14061+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
1405114062 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
1405214063 ctx->num_additional_fused_ops = 2;
1405314064 fusion_string = "MUL_MAT_ADD_ADD";
14065+ op_srcs_fused_elementwise[0] = false;
14066+ op_srcs_fused_elementwise[1] = true;
14067+ op_srcs_fused_elementwise[2] = true;
1405414068 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
1405514069 ctx->num_additional_fused_ops = 1;
1405614070 fusion_string = "MUL_MAT_ADD";
14071+ op_srcs_fused_elementwise[0] = false;
14072+ op_srcs_fused_elementwise[1] = true;
1405714073 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
1405814074 ctx->num_additional_fused_ops = 2;
1405914075 fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
14076+ op_srcs_fused_elementwise[0] = false;
14077+ op_srcs_fused_elementwise[1] = true;
14078+ op_srcs_fused_elementwise[2] = true;
1406014079 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
1406114080 ctx->num_additional_fused_ops = 1;
1406214081 fusion_string = "MUL_MAT_ID_ADD_ID";
14082+ op_srcs_fused_elementwise[0] = false;
14083+ op_srcs_fused_elementwise[1] = true;
1406314084 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
1406414085 ctx->num_additional_fused_ops = 1;
1406514086 fusion_string = "MUL_MAT_ID_MUL";
14087+ op_srcs_fused_elementwise[0] = false;
14088+ op_srcs_fused_elementwise[1] = true;
1406614089 } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
1406714090 ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
1406814091 ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
1406914092 ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
1407014093 ctx->num_additional_fused_ops = 4;
1407114094 fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
14095+ op_srcs_fused_elementwise[0] = false;
14096+ op_srcs_fused_elementwise[1] = false;
14097+ op_srcs_fused_elementwise[2] = false;
14098+ op_srcs_fused_elementwise[3] = false;
14099+ op_srcs_fused_elementwise[4] = false;
1407214100 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
1407314101 ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
1407414102 ctx->num_additional_fused_ops = 2;
1407514103 fusion_string = "RMS_NORM_MUL_ROPE";
14104+ // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
14105+ op_srcs_fused_elementwise[0] = false;
14106+ op_srcs_fused_elementwise[1] = true;
14107+ op_srcs_fused_elementwise[2] = true;
1407614108 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
1407714109 ctx->num_additional_fused_ops = 1;
1407814110 fusion_string = "RMS_NORM_MUL";
14111+ // rms_norm is not elementwise, but there is one workgroup per row, and whole rows must be consumed and
14112+ // the scale factor computed before they are overwritten. So close enough.
14113+ op_srcs_fused_elementwise[0] = true;
14114+ op_srcs_fused_elementwise[1] = true;
1407914115 } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
1408014116 ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
1408114117 ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
1408214118 ctx->num_additional_fused_ops = 2;
1408314119 fusion_string = "ROPE_VIEW_SET_ROWS";
14120+ op_srcs_fused_elementwise[0] = false;
14121+ op_srcs_fused_elementwise[1] = false;
14122+ op_srcs_fused_elementwise[2] = false;
1408414123 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
1408514124 ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
1408614125 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -14089,6 +14128,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1408914128 ctx->fused_ops_write_mask |= 1 << 3;
1409014129 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
1409114130 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
14131+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
1409214132 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
1409314133 ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
1409414134 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
@@ -14097,6 +14137,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1409714137 ctx->fused_ops_write_mask |= 1 << 4;
1409814138 ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
1409914139 fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
14140+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
1410014141 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
1410114142 ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
1410214143 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
@@ -14105,6 +14146,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1410514146 ctx->fused_ops_write_mask |= 1 << 3;
1410614147 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
1410714148 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
14149+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
1410814150 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
1410914151 ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
1411014152 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
@@ -14113,18 +14155,81 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1411314155 ctx->fused_ops_write_mask |= 1 << 1;
1411414156 ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
1411514157 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
14158+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
1411614159 }
1411714160 if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
1411814161 // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
1411914162 if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
1412014163 ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
1412114164 ctx->fused_topk_moe_scale = true;
1412214165 ctx->num_additional_fused_ops++;
14166+ op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
1412314167 }
1412414168 }
1412514169 }
14170+ GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
1412614171 ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
1412714172
14173+ // Check whether fusion would overwrite src operands while they're still in use.
14174+ // If so, disable fusion.
14175+ if (ctx->num_additional_fused_ops) {
14176+ // There are up to two output nodes - topk_moe has two.
14177+ uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
14178+ ggml_tensor *output_nodes[2] {};
14179+ output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
14180+ if (bits) {
14181+ int output_idx = find_first_set(bits);
14182+ GGML_ASSERT(bits == (1u << output_idx));
14183+ output_nodes[1] = cgraph->nodes[i + output_idx];
14184+ }
14185+
14186+ bool need_disable = false;
14187+
14188+ // topk_moe often overwrites the source, but for a given row all the src values are
14189+ // loaded before anything is stored. If there's only one row, this is safe, so treat
14190+ // this as a special case.
14191+ bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
14192+ ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
14193+
14194+ if (!is_topk_moe_single_row) {
14195+ for (int j = 0; j < 2; ++j) {
14196+ ggml_tensor *dst = output_nodes[j];
14197+ if (!dst) {
14198+ continue;
14199+ }
14200+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
14201+ // the destination and the src is not an intermediate node that's being
14202+ // elided, then disable fusion.
14203+ for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
14204+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
14205+ ggml_tensor *src = cgraph->nodes[i + k]->src[s];
14206+ if (!src || src->op == GGML_OP_NONE) {
14207+ continue;
14208+ }
14209+ if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
14210+ bool found = false;
14211+ for (int n = 0; n < k; ++n) {
14212+ if (cgraph->nodes[i + n] == src) {
14213+ found = true;
14214+ break;
14215+ }
14216+ }
14217+ if (!found) {
14218+ need_disable = true;
14219+ }
14220+ }
14221+ }
14222+ }
14223+ }
14224+ }
14225+ if (need_disable) {
14226+ ctx->num_additional_fused_ops = 0;
14227+ ctx->fused_ops_write_mask = 1;
14228+ ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
14229+ ctx->fused_topk_moe_scale = false;
14230+ }
14231+ }
14232+
1412814233 // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
1412914234 bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
1413014235 bool submit = (submitted_nodes >= nodes_per_submit) ||
0 commit comments