Skip to content

Commit 065908c

Browse files
authored
metal : fix fusion across different encoders (ggml-org#14849)
* metal : fix fusion across different encoders ggml-ci * cont : add assertion ggml-ci
1 parent 4ec6291 commit 065908c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19551955
static int ggml_metal_encode_node(
19561956
ggml_backend_t backend,
19571957
int idx,
1958+
int idx_end,
19581959
id<MTLComputeCommandEncoder> encoder,
19591960
struct ggml_metal_mem_pool * mem_pool) {
19601961
struct ggml_backend_metal_context * ctx = backend->context;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
21812182
size_t offs_fuse;
21822183
id<MTLBuffer> id_fuse;
21832184

2184-
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
2185+
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2186+
// across splits. idx_end is the exclusive end of the current split (one past its last node)
2187+
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
21852188
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
21862189
break;
21872190
}
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
42884291
ops[1] = GGML_OP_MUL;
42894292
ops[2] = GGML_OP_ADD;
42904293

4291-
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
4294+
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
42924295
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
42934296
break;
42944297
}
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
62716274
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
62726275
}
62736276

6274-
const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
6277+
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
6278+
if (idx + res > node_end) {
6279+
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
6280+
"https://github.com/ggml-org/llama.cpp/pull/14849");
6281+
}
62756282

62766283
if (should_capture) {
62776284
[encoder popDebugGroup];

0 commit comments

Comments
 (0)