Skip to content

Commit effe669

Browse files
committed
mtmd: minor changed
1 parent 7b8d735 commit effe669

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

tools/mtmd/clip.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -739,13 +739,14 @@ struct clip_graph {
739739

740740
struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
741741

742-
struct ggml_tensor * rel_w = ggml_cont(
743-
ctx0,
744-
ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
745-
2, 1, 3));
742+
struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0,
743+
ggml_mul_mat(ctx0,
744+
rw,
745+
ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
746+
0, 2, 1, 3));
746747
struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
747748

748-
struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W);
749+
struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
749750

750751
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
751752

@@ -2466,32 +2467,37 @@ struct clip_graph {
24662467
return inpL;
24672468
}
24682469

2469-
// attn: [k_h*k_w, q_h*q_w]
2470+
// attn: [q_h*q_w, k_h*k_w]
24702471
// rel_h: [q_h, q_w, k_h]
24712472
// rel_w: [q_h, q_w, k_w]
24722473

24732474
static ggml_tensor * add_rel_pos_inplace(
24742475
ggml_context * ctx,
24752476
ggml_tensor * attn,
24762477
ggml_tensor * rel_w,
2477-
ggml_tensor * rel_h,
2478-
int q_size
2478+
ggml_tensor * rel_h
24792479
) {
2480+
const int k_w = rel_w->ne[0];
2481+
const int k_h = rel_h->ne[0];
2482+
const int q_w = rel_h->ne[1];
2483+
const int q_h = rel_h->ne[2];
2484+
2485+
GGML_ASSERT(q_w == rel_w->ne[1]);
2486+
GGML_ASSERT(q_h == rel_w->ne[2]);
2487+
GGML_ASSERT(attn->ne[0] == k_h*k_w);
2488+
GGML_ASSERT(attn->ne[1] == q_h*q_w);
24802489

2481-
ggml_tensor *attn_4d =
2482-
ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
2490+
ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]);
24832491

2484-
ggml_tensor *rel_h_4d =
2485-
ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
2492+
ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]);
24862493

24872494
ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
24882495

2489-
ggml_tensor *rel_w_4d =
2490-
ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
2496+
ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]);
24912497

24922498
ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
24932499

2494-
ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
2500+
ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep));
24952501
result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
24962502

24972503

0 commit comments

Comments
 (0)