Skip to content

Commit 68b206b

Browse files
committed
SAM implementation without using CPU-only ops
1 parent 88032f4 commit 68b206b

File tree

1 file changed

+103
-6
lines changed

1 file changed

+103
-6
lines changed

tools/mtmd/clip.cpp

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -734,8 +734,8 @@ struct clip_graph {
734734

735735
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
736736

737-
struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W);
738-
struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H);
737+
struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W);
738+
struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H);
739739

740740
struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
741741

@@ -745,7 +745,7 @@ struct clip_graph {
745745
2, 1, 3));
746746
struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
747747

748-
struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
748+
struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W);
749749

750750
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
751751

@@ -837,9 +837,9 @@ struct clip_graph {
837837
ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1);
838838

839839
// torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
840-
global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3);
841-
global_features_1 = ggml_cont(ctx0, global_features_1);
840+
global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3));
842841
global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches);
842+
843843
// remove CLS token
844844
global_features_2 = ggml_view_2d(ctx0, global_features_2,
845845
n_embd, n_patches,
@@ -850,6 +850,7 @@ struct clip_graph {
850850
global_features = ggml_cont(ctx0, global_features);
851851
global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
852852
global_features = ggml_add(ctx0, global_features, model.fc_b);
853+
853854
global_features = build_global_local_features(ctx0,global_features);
854855
ggml_build_forward_expand(gf, global_features);
855856
return gf;
@@ -869,7 +870,6 @@ struct clip_graph {
869870
t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
870871
ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
871872
nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
872-
nl = ggml_cont(ctx0, nl);
873873

874874

875875
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
@@ -2466,6 +2466,103 @@ struct clip_graph {
24662466
return inpL;
24672467
}
24682468

2469+
// Adds SAM-style decomposed relative-position bias to raw attention scores,
// built from plain reshape/repeat/add ops so it runs on non-CPU backends
// (replaces the CPU-only ggml_add_rel_pos_inplace).
//
// Shapes (ggml ne order, as used by the caller):
//   attn:  [k_h*k_w, q_h*q_w, heads]  -- requires attn->ne[0] == q_size*q_size
//   rel_h: [q_h, q_w, k_h] per the original op's contract
//   rel_w: [q_h, q_w, k_w]
// q_size is the spatial side length (W == H at this call site).
//
// NOTE(review): despite the "_inplace" name, the adds below are the
// out-of-place ggml_add — the name is kept for parity with the op it replaces.
static ggml_tensor * add_rel_pos_inplace(
    ggml_context * ctx,
    ggml_tensor * attn,
    ggml_tensor * rel_w,
    ggml_tensor * rel_h,
    int q_size
) {

    // View the flattened query axis as (q_w, q_h) so the two bias terms can be
    // broadcast along separate axes.
    ggml_tensor *attn_4d =
        ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);

    // Height bias: one value per (q, k_h), repeated across the k_w axis (ne0).
    ggml_tensor *rel_h_4d =
        ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);

    ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d

    // Width bias: one value per (q, k_w), repeated across the k_h axis (ne1).
    ggml_tensor *rel_w_4d =
        ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);

    ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d

    // attn + rel_h + rel_w, then restore the caller's [k, q, heads] layout.
    ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
    result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);

    return result;
}
2500+
2501+
2502+
// Gathers per-pair relative positional embeddings (SAM's get_rel_pos) using
// only ops with non-CPU backend support: builds the relative-coordinate index
// table with arange/repeat/sub, then gathers rows with ggml_get_rows, instead
// of the CPU-only ggml_get_rel_pos.
//
// rel_pos: the learned relative-position table; the author's comment labels it
//          [L, C] with L = ne[0], C = ne[1].
// q_size / k_size: query / key spatial side lengths (equal at this call site,
//          so the PyTorch max(q/k, 1.0) rescaling is not needed).
//
// Returns the gathered table reshaped so that ne = [rel_pos->ne[0], q_size, k_size].
//
// NOTE(review): ggml_get_rows indexes along ne[1] (each gathered "row" has
// length ne[0]); if rel_pos is stored the usual ggml way (channels in ne[0],
// positions in ne[1]), the clamp bound below (ne[0]-1) limits indices by the
// channel count rather than the position count — confirm the tensor layout.
// NOTE(review): the code computes k - q + (k_size - 1) while SAM's reference
// computes q - k + (k_size - 1); verify this sign flip is an intentional
// consequence of the ggml dimension ordering.
static ggml_tensor * get_rel_pos(
    ggml_context * ctx,
    ggml_tensor * rel_pos, // [L, C]
    int q_size,
    int k_size
) {
    const int64_t L = rel_pos->ne[0]; // length (see layout note above)

    // -------------------------------------------------
    // 1) q_idx <- arange(0..q_size-1) [q_size]
    // 2) k_idx <- arange(0..k_size-1) [k_size]
    // ggml_arange already produces GGML_TYPE_F32, so no cast is required.
    // -------------------------------------------------
    ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f); // [q_size]
    ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f); // [k_size]

    // Broadcast both index vectors to the full [q_size, k_size] grid.
    ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
    q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]

    k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size);      // [1, k_size]
    k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]

    // -------------------------------------------------
    // relative_coords = k - q + (k_size - 1)   (see sign-flip note above)
    // -------------------------------------------------
    rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
    rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f);

    // -------------------------------------------------
    // clamp to [0, L-1] and cast to int32 (ggml_get_rows needs I32 indices)
    // -------------------------------------------------
    ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
    ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]

    // flatten to 1D for ggml_get_rows
    const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
    ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]

    // -------------------------------------------------
    // Gather one embedding row per (q, k) pair -> [qk, C]
    // -------------------------------------------------
    ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat);

    // reshape to the final output layout expected by the caller
    ggml_tensor * out = ggml_reshape_3d(ctx, gathered, rel_pos->ne[0],
                                        q_size,
                                        k_size);

    return out;
}
2565+
24692566
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
24702567
auto [c, w, h, b] = x->ne;
24712568
// same as

0 commit comments

Comments
 (0)