@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19551955static int ggml_metal_encode_node (
19561956 ggml_backend_t backend,
19571957 int idx,
1958+ int idx_end,
19581959 id <MTLComputeCommandEncoder > encoder,
19591960 struct ggml_metal_mem_pool * mem_pool) {
19601961 struct ggml_backend_metal_context * ctx = backend->context ;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
21812182 size_t offs_fuse;
21822183 id <MTLBuffer > id_fuse;
21832184
2184- for (n_fuse = 0 ; n_fuse <= 6 ; ++n_fuse) {
2185+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2186+ // across splits. idx_end indicates the last node in the current split
2187+ for (n_fuse = 0 ; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
21852188 if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
21862189 break ;
21872190 }
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
42884291 ops[1 ] = GGML_OP_MUL;
42894292 ops[2 ] = GGML_OP_ADD;
42904293
4291- for (n_fuse = 0 ; n_fuse <= 1 ; ++n_fuse) {
4294+ for (n_fuse = 0 ; n_fuse <= 1 && idx + n_fuse + 1 < idx_end ; ++n_fuse) {
42924295 if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
42934296 break ;
42944297 }
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
62716274 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
62726275 }
62736276
6274- const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6277+ const int res = ggml_metal_encode_node (backend, idx, node_end, encoder, mem_pool);
6278+ if (idx + res > node_end) {
6279+ GGML_ABORT (" fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s " ,
6280+ " https://github.com/ggml-org/llama.cpp/pull/14849" );
6281+ }
62756282
62766283 if (should_capture) {
62776284 [encoder popDebugGroup ];
0 commit comments