Commit b0c47dd

POC: Avoid PackTranspose in TinyBLAS_PPC MMA kernel
Signed-off-by: Shalini Salomi Bodapati <[email protected]>
1 parent 0d92267 commit b0c47dd

6 files changed, 191 additions(+), 37 deletions(-)

convert_hf_to_gguf.py

Lines changed: 11 additions & 2 deletions
@@ -1963,6 +1963,7 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        print(f"[GGUF-CONVERT] modifying tensor {name}")
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
         is_vision_tensor = "vision_tower" in name \
@@ -1985,6 +1986,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.endswith(("k_proj.weight", "k_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
 
+        #if name.endswith(("attn.k_proj.weight", "attn.o_proj.weight", "attn.v_proj.weight","attn.q_proj.weight","up_proj.weight", "gate_proj.weight", "down_proj.weight")):
+        if name.endswith(( "attn.o_proj.weight", "up_proj.weight", "gate_proj.weight")):
+            print(f"[GGUF-CONVERT] Transposing {name}")
+            data_torch = data_torch.T.contiguous()
+
+
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
             n_experts = self.hparams["num_local_experts"]
@@ -2018,8 +2025,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                 return tensors
             else:
                 return []
-
-        return [(self.map_tensor_name(name), data_torch)]
+        mapped_name = self.map_tensor_name(name)
+        print(f"[GGUF-CONVERT] Mapping: {name} --> {mapped_name}")
+        print(f"[GGUF-CONVERT] Final shape for {mapped_name}: {data_torch.shape}")
+        return [(mapped_name, data_torch)]
 
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
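
The data_torch.T.contiguous() call above is the conversion-side half of this POC: the selected 2D weights are written to the GGUF file transposed and contiguous, so every column of the original weight becomes a contiguous run of values on disk. A minimal C++ illustration of that layout change (the script itself is Python/PyTorch; transpose_contiguous is an illustrative helper, not code from this commit):

// Illustrative sketch: what storing W.T contiguously does to a rows x cols,
// row-major weight W. After the copy, element W[i][j] lives at offset
// j*rows + i, i.e. each original column of W is now `rows` consecutive floats.
#include <cstdint>
#include <vector>

static std::vector<float> transpose_contiguous(const std::vector<float> &W,
                                               int64_t rows, int64_t cols) {
    std::vector<float> Wt(static_cast<size_t>(rows * cols));
    for (int64_t i = 0; i < rows; ++i) {
        for (int64_t j = 0; j < cols; ++j) {
            Wt[static_cast<size_t>(j * rows + i)] = W[static_cast<size_t>(i * cols + j)];
        }
    }
    return Wt;
}

This contiguous-column layout is what later lets the PPC MMA kernel load a column of the weight with a single vector load instead of a packTranspose gather.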

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 47 additions & 7 deletions
@@ -1234,11 +1234,45 @@ void ggml_compute_forward_mul_mat(
     const int64_t r3 = ne13 / ne03;
 
     const bool src1_cont = ggml_is_contiguous(src1);
-
+    bool is_transposed = false;
     if (src1_cont) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
+        const char * name = src0->name;
+        const char * name1 = src1->name;
+
+        if (name &&
+            strstr(name, "attn_output.weight") ||
+            strstr(name, "ffn_up.weight") ||
+            strstr(name, "ffn_gate.weight")) {
+            printf("[llamafile_sgemm] src0 %s was transposed during HF->GGUF conversion\n", name);
+            is_transposed = true;
+            //is_transposed = false;
+        }
+        if (name1 &&
+            strstr(name1, "attn_output.weight") ||
+            strstr(name1, "ffn_up.weight") ||
+            strstr(name1, "ffn_gate.weight")) {
+            printf("[llamafile_sgemm] src1 %s was transposed during HF->GGUF conversion\n", name1);
+        }
+        printf("\n==> llamafile_sgemm call: %s * %s\n", src0->name, src1->name);
+        printf("A shape: [%lld x %lld] B shape: [%lld x %lld]\n", src0->ne[1], src0->ne[0], src1->ne[1], src1->ne[0]);
+
+        for (int64_t i13 = 0; i13 < ne13; i13++) {
+            for (int64_t i12 = 0; i12 < ne12; i12++) {
+                if (is_transposed) {
                 if (!llamafile_sgemm(params,
+                                     ne00/ggml_blck_size(src0->type), ne11, ne01,///ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     ne01,
+                                     (const char *)src1->data + i12*nb12 + i13*nb13,
+                                     nb11/ggml_type_size(src1->type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     ne00/ggml_blck_size(src0->type),
+                                     src0->type,
+                                     src1->type,
+                                     dst->type, is_transposed))
+                    goto UseGgmlGemm1;
+                } else {
+                if (!llamafile_sgemm(params,
                                      ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
@@ -1248,8 +1282,11 @@ void ggml_compute_forward_mul_mat(
                                      nb1/ggml_type_size(dst->type),
                                      src0->type,
                                      src1->type,
-                                     dst->type))
+                                     dst->type, false))
                     goto UseGgmlGemm1;
+                }
+            }
+        }
         return;
     }
 UseGgmlGemm1:;
@@ -1304,8 +1341,9 @@ UseGgmlGemm1:;
     const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
+        for (int64_t i13 = 0; i13 < ne13; i13++) {
+            for (int64_t i12 = 0; i12 < ne12; i12++) {
+                //printf("calling from 2nd site here \n");
                 if (!llamafile_sgemm(params,
                                      ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
@@ -1316,8 +1354,10 @@ UseGgmlGemm1:;
                                      nb1/ggml_type_size(dst->type),
                                      src0->type,
                                      vec_dot_type,
-                                     dst->type))
+                                     dst->type, false))
                     goto UseGgmlGemm2;
+            }
+        }
         return;
     }
 UseGgmlGemm2:;
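
One note on the name test added above: in C, && binds tighter than ||, so the name && guard only applies to the first strstr; the remaining strstr calls are evaluated on their own (harmless here, since ggml tensor names are embedded character arrays and never NULL, but easy to misread). A hedged sketch of the same predicate with explicit parentheses; is_transposed_weight is an illustrative helper, not part of this commit:

// Illustrative helper: same tensor-name check as above, parenthesized so the
// NULL guard covers all three substring tests.
#include <cstring>

static bool is_transposed_weight(const char * name) {
    return name != nullptr &&
           (strstr(name, "attn_output.weight") != nullptr ||
            strstr(name, "ffn_up.weight")      != nullptr ||
            strstr(name, "ffn_gate.weight")    != nullptr);
}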

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 112 additions & 26 deletions
@@ -2695,11 +2695,13 @@ class tinyBLAS_PPC {
                 const TA *A, int64_t lda,
                 const TB *B, int64_t ldb,
                 TC *C, int64_t ldc,
-                int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+                int ith, int nth, int64_t m_orig, bool is_transposed)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth), is_transposed(is_transposed){
+        m_orig = 0;
     }
 
     void matmul(int64_t m, int64_t n) {
+        m_orig = m;
         mnpack(0, m, 0, n);
     }
 
@@ -2957,7 +2959,13 @@ class tinyBLAS_PPC {
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            if (is_transposed) {
+                for (int x = 0; x< 4; x++) {
+                    vec_A[x] = (vec_t)vec_xl(0, (float*)A+ (l+x)*m_orig+ii);
+                }
+            } else {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            }
             packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -2973,7 +2981,13 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            if (is_transposed) {
+                for (int x =0; x< 4; x++) {
+                    vec_A[x] = (vec_t) vec_xl(0, (float*)A+(l+x)*m_orig+ii);
+                }
+            }else{
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            }
             packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
@@ -2994,7 +3008,14 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
+            if (is_transposed) {
+                for (int x = 0; x <4; x++) {
+                    vec_A[2*x] = (vec_t)vec_xl(0, (float*)A+(l+x)*m_orig+ii);
+                    vec_A[2*x+1] = (vec_t)vec_xl(0, (float*)A+(l+x)*m_orig+ii+4);
+                }
+            } else {
             packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+            }
             packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
@@ -3017,7 +3038,14 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
+            if (is_transposed) {
+                for (int x = 0; x <8; x++) {
+                    vec_A[2*x] = (vec_t)vec_xl(0, (float*)A+(l+x)*m_orig+ii);
+                    vec_A[2*x+1] = (vec_t)vec_xl(0, (float*)A+(l+x)*m_orig+ii+4);
+                }
+            } else {
             packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+            }
             packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
@@ -3205,24 +3233,31 @@ class tinyBLAS_PPC {
          * broadcasted, instead of using packing routine to prepack the
          * matrix elements.
          */
-        if (RM == 1) {
-            TA* a = const_cast<TA*>(A+(ii)*lda+l);
+        if (is_transposed) {
+            for (int x = 0; x< 4; x++) {
+                vec_A[x] = (vec_t)vec_xl(0, (float*)A+(l+x)*m_orig+ii);
+            }
             packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
-            vec_A[0] = (vec_t)vec_xl(0,a);
-            vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
-            vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
-            vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
-        } else if (RN == 1) {
-            packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-            TB* b = const_cast<TB*>(B+(jj)*ldb+l);
-            vec_B[0] = (vec_t)vec_xl(0,b);
-            vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
-            vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
-            vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
-        } else {
-            packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-            packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
-        }
+        } else {
+            if (RM == 1) {
+                TA* a = const_cast<TA*>(A+(ii)*lda+l);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+                vec_A[0] = (vec_t)vec_xl(0,a);
+                vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+            } else if (RN == 1) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                vec_B[0] = (vec_t)vec_xl(0,b);
+                vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+                vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+                vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
+            } else {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+            }
+        }
         __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
         __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
         __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -3274,6 +3309,8 @@ class tinyBLAS_PPC {
     const int64_t ldc;
     const int ith;
     const int nth;
+    int64_t m_orig;
+    bool is_transposed;
 };
 #endif
 } // namespace
@@ -3310,13 +3347,16 @@ class tinyBLAS_PPC {
  */
 bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                      const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int Atype, int Btype, int Ctype) {
-
+                     int64_t ldc, int Atype, int Btype, int Ctype, bool is_transposed) {
+    printf("m=%ld n=%ld k=%ld lda=%ld ldb=%ld ldc=%ld\n", m, n, k, lda, ldb, ldc);
     assert(m >= 0);
     assert(n >= 0);
     assert(k >= 0);
-    assert(lda >= k);
-    assert(ldb >= k);
+    /* if (is_transposed)
+           assert(lda >= m);
+       else*/
+    //assert(lda >= k);
+    //assert(ldb >= k);
     assert(ldc >= m);
     assert(params->nth > 0);
     assert(params->ith < params->nth);
@@ -3366,12 +3406,58 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
 #elif defined(__MMA__)
     if (k % 8)
         return false;
+    //if (is_transposed)
+        //printf("A was transposed during GGUF; m = %d n = %d k = %d\n", m, n, k);
+    float * Ap = (float*)A;
+    float * Bp = (float*)B;
+    float * Cp = (float*)C;
+    printf("Matrix AT in column major\n");
+    for (int r = 0; r < k; r ++) {
+        printf("| ");
+        for (int c = 0; c< m; c++) {
+            printf("%.2f ", Ap[c*k + r]);
+        }
+        printf(" |\n");
+    }
+    printf("A memory layout n");
+    for (int i = 0; i < (m*k); i++){
+        printf("%.2f ", *(Ap++));
+    }
+    printf("\n");
+    printf("B in column major\n");
+    for (int r = 0; r < k; r ++) {
+        printf("| ");
+        for (int c = 0; c< n; c++) {
+            printf("%.2f ", Bp[c*k + r]);
+        }
+        printf(" |\n");
+    }
+
+    printf("B memory layout n");
+    for (int i = 0; i < (n*k); i++){
+        printf("%.2f ", *(Bp++));
+    }
+    printf("\n");
     tinyBLAS_PPC<float, float, float> tb{
         k, (const float *)A, lda,
         (const float *)B, ldb,
         (float *)C, ldc,
-        params->ith, params->nth};
+        params->ith, params->nth, m, is_transposed};
     tb.matmul(m, n);
+    printf("C Matrix\n");
+    for (int r = 0; r < m; r ++) {
+        printf("| ");
+        for (int c = 0; c< n; c++) {
+            printf("%.2f ", Cp[c*m + r]);
+        }
+        printf(" |\n");
+    }
+
+    for (int i = 0; i < (m*n); i++){
+        printf("%.2f ", *(Cp++));
+    }
+    printf("\n");
+    //printf("completd llamafile_Sgemm\n");
     return true;
 #else
     return false;
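
The vec_xl loads added above are the heart of this POC: when A arrives already transposed (stored k x m_orig, row-major), the four values of a tile column sit next to each other in memory, so the kernel can load them directly instead of gathering a 4x4 block through packTranspose. A scalar C++ sketch of the two access patterns, under the assumption that packTranspose places one tile column per output vector (load_tile_* are illustrative names, not kernel code):

// Scalar sketch (not the MMA kernel): filling vec_A[0..3] with a 4x4 tile of A,
// rows ii..ii+3 and columns l..l+3, from the two layouts involved.
#include <cstdint>

// A_rm: m x k weight, row-major, leading dimension lda -- strided gather.
static void load_tile_rowmajor(const float *A_rm, int64_t lda,
                               int64_t ii, int64_t l, float vec_A[4][4]) {
    for (int x = 0; x < 4; ++x)          // tile column
        for (int r = 0; r < 4; ++r)      // tile row
            vec_A[x][r] = A_rm[(ii + r) * lda + (l + x)];
}

// A_tr: the same weight stored transposed (k x m_orig, row-major) -- each tile
// column is one contiguous 4-float run, which is what vec_xl exploits.
static void load_tile_pretransposed(const float *A_tr, int64_t m_orig,
                                    int64_t ii, int64_t l, float vec_A[4][4]) {
    for (int x = 0; x < 4; ++x)
        for (int r = 0; r < 4; ++r)
            vec_A[x][r] = A_tr[(l + x) * m_orig + (ii + r)];
}

Both fill vec_A identically because A_tr[c][r] == A_rm[r][c]; only the second touches consecutive addresses.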

ggml/src/ggml-cpu/llamafile/sgemm.h

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ extern "C" {
 
 bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
                      const void *, int64_t, const void *, int64_t, void *, int64_t,
-                     int, int, int);
+                     int, int, int, bool is_transposed);
 
 #ifdef __cplusplus
 }

ggml/src/ggml.c

Lines changed: 6 additions & 1 deletion
@@ -1904,6 +1904,7 @@ static struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
+    //printf("%s %s\n", a->name, b->name);
     GGML_ASSERT(ggml_can_repeat(b, a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -2972,7 +2973,7 @@ struct ggml_tensor * ggml_mul_mat(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    GGML_ASSERT(ggml_can_mul_mat(a, b));
+    //GGML_ASSERT(ggml_can_mul_mat(a, b));
     GGML_ASSERT(!ggml_is_transposed(a));
 
     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
@@ -3365,6 +3366,10 @@ struct ggml_tensor * ggml_reshape_3d(
         int64_t ne1,
         int64_t ne2) {
     GGML_ASSERT(ggml_is_contiguous(a));
+    //printf("%s\n", a->name);
+    //printf("a->ne[] = [%lld %lld %lld %lld]\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
+
+    //printf("ggml_nelements=%d ne0=%d ne1=%d ne2=%d\n", ggml_nelements(a), ne0, ne1, ne2);
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
     const int64_t ne[3] = { ne0, ne1, ne2 };
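
Commenting out GGML_ASSERT(ggml_can_mul_mat(a, b)) above disables the inner-dimension check for every matrix multiply, not just the transposed weights. A less invasive alternative, sketched below under the assumption that the original check requires a->ne[0] == b->ne[0] plus broadcast-compatible batch dims, would be to accept the transposed layout explicitly; can_mul_mat_relaxed is hypothetical, not ggml API:

// Hypothetical relaxed check (not ggml API): accept the regular inner-dimension
// match or the transposed-weight layout, instead of dropping the assert entirely.
#include <cstdint>

static bool can_mul_mat_relaxed(const int64_t a_ne[4], const int64_t b_ne[4]) {
    const bool inner_ok = (a_ne[0] == b_ne[0]) ||   // regular layout
                          (a_ne[1] == b_ne[0]);     // weight stored transposed
    return inner_ok &&
           (b_ne[2] % a_ne[2] == 0) &&              // assumed broadcast rule on batch dims
           (b_ne[3] % a_ne[3] == 0);
}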

src/llama-model-loader.cpp

Lines changed: 14 additions & 0 deletions
@@ -774,6 +774,20 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
                 break;
             }
         }
+        // if direct match fails, try transposed match (only for 2D tensors)
+        if (!is_ok && ne.size() == 2) {
+            bool is_transposed_ok = (cur->ne[0] == ne[1] && cur->ne[1] == ne[0]);
+            for (size_t i = 2; i < GGML_MAX_DIMS; ++i) {
+                if (cur->ne[i] != 1) {
+                    is_transposed_ok = false;
+                    break;
+                }
+            }
+            if (is_transposed_ok) {
+                is_ok = true;
+            }
+        }
+
         if (!is_ok) {
             throw std::runtime_error(
                     format("%s: tensor '%s' has wrong shape; expected %s, got %s",
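
The loader change above accepts a 2D tensor whose stored dims are the swap of the expected ones, provided all higher dims are 1. A standalone sketch of that acceptance rule with a worked example (dims_match_or_transposed and the concrete shapes are illustrative, not code from this commit):

// Standalone sketch of the relaxed 2D shape check added above.
#include <array>
#include <cstdint>
#include <cstdio>

static bool dims_match_or_transposed(const std::array<int64_t, 4> &cur,      // stored dims (GGML_MAX_DIMS = 4 assumed)
                                     const std::array<int64_t, 2> &expected) {
    const bool higher_dims_one = (cur[2] == 1 && cur[3] == 1);
    const bool direct          = (cur[0] == expected[0] && cur[1] == expected[1]);
    const bool transposed      = (cur[0] == expected[1] && cur[1] == expected[0]);
    return higher_dims_one && (direct || transposed);
}

int main() {
    // e.g. an ffn_up / ffn_gate weight written transposed during conversion
    const std::array<int64_t, 4> stored   = {11008, 4096, 1, 1};
    const std::array<int64_t, 2> expected = {4096, 11008};
    std::printf("accepted: %s\n", dims_match_or_transposed(stored, expected) ? "yes" : "no");  // yes
    return 0;
}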
