@@ -739,13 +739,14 @@ struct clip_graph {
739739
740740 struct ggml_tensor * q_r = ggml_reshape_4d (ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
741741
742- struct ggml_tensor * rel_w = ggml_cont (
743- ctx0,
744- ggml_permute (ctx0, ggml_mul_mat (ctx0, rw, ggml_cont (ctx0, ggml_permute (ctx0, q_r, 0 , 2 , 1 , 3 ))), 0 ,
745- 2 , 1 , 3 ));
742+ struct ggml_tensor * rel_w = ggml_cont (ctx0,ggml_permute (ctx0,
743+ ggml_mul_mat (ctx0,
744+ rw,
745+ ggml_cont (ctx0, ggml_permute (ctx0, q_r, 0 , 2 , 1 , 3 ))),
746+ 0 , 2 , 1 , 3 ));
746747 struct ggml_tensor * rel_h = ggml_mul_mat (ctx0, rh, q_r);
747748
748- struct ggml_tensor * attn = add_rel_pos_inplace (ctx0, KQ_scaled, rel_w, rel_h, W );
749+ struct ggml_tensor * attn = add_rel_pos_inplace (ctx0, KQ_scaled, rel_w, rel_h);
749750
750751 struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace (ctx0, attn);
751752
@@ -2466,32 +2467,37 @@ struct clip_graph {
24662467 return inpL;
24672468 }
24682469
2469- // attn: [k_h*k_w, q_h*q_w ]
2470+ // attn: [q_h*q_w, k_h*k_w ]
24702471 // rel_h: [q_h, q_w, k_h]
24712472 // rel_w: [q_h, q_w, k_w]
24722473
24732474 static ggml_tensor * add_rel_pos_inplace (
24742475 ggml_context * ctx,
24752476 ggml_tensor * attn,
24762477 ggml_tensor * rel_w,
2477- ggml_tensor * rel_h,
2478- int q_size
2478+ ggml_tensor * rel_h
24792479 ) {
2480+ const int k_w = rel_w->ne [0 ];
2481+ const int k_h = rel_h->ne [0 ];
2482+ const int q_w = rel_h->ne [1 ];
2483+ const int q_h = rel_h->ne [2 ];
2484+
2485+ GGML_ASSERT (q_w == rel_w->ne [1 ]);
2486+ GGML_ASSERT (q_h == rel_w->ne [2 ]);
2487+ GGML_ASSERT (attn->ne [0 ] == k_h*k_w);
2488+ GGML_ASSERT (attn->ne [1 ] == q_h*q_w);
24802489
2481- ggml_tensor *attn_4d =
2482- ggml_reshape_4d (ctx, attn, q_size,q_size, attn->ne [1 ], attn->ne [2 ]);
2490+ ggml_tensor *attn_4d = ggml_reshape_4d (ctx, attn, k_w, k_h, attn->ne [1 ], attn->ne [2 ]);
24832491
2484- ggml_tensor *rel_h_4d =
2485- ggml_reshape_4d (ctx, rel_h, 1 , q_size, attn->ne [1 ], attn->ne [2 ]);
2492+ ggml_tensor *rel_h_4d = ggml_reshape_4d (ctx, rel_h, 1 , k_h, attn->ne [1 ], attn->ne [2 ]);
24862493
24872494 ggml_tensor *rel_h_rep = ggml_repeat (ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
24882495
2489- ggml_tensor *rel_w_4d =
2490- ggml_reshape_4d (ctx, rel_w, q_size, 1 , attn->ne [1 ], attn->ne [2 ]);
2496+ ggml_tensor *rel_w_4d = ggml_reshape_4d (ctx, rel_w, k_w, 1 , attn->ne [1 ], attn->ne [2 ]);
24912497
24922498 ggml_tensor *rel_w_rep = ggml_repeat (ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
24932499
2494- ggml_tensor * result = ggml_add (ctx, attn_4d, ggml_add (ctx, rel_h_rep, rel_w_rep));
2500+ ggml_tensor * result = ggml_add_inplace (ctx, attn_4d, ggml_add_inplace (ctx, rel_h_rep, rel_w_rep));
24952501 result = ggml_reshape_3d (ctx, result, attn->ne [0 ], attn->ne [1 ], attn->ne [2 ]);
24962502
24972503
0 commit comments