@@ -3385,11 +3385,11 @@ struct test_mul_mat : public test_case {
     const std::array<int64_t, 2> bs; // dims 3 and 4
     const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
     const std::array<int64_t, 4> per; // permutation of dimensions
-    const bool v; // whether a and b are non-contiguous views
+    const int64_t k_v; // size of k in memory, resulting in a non-contiguous view for k_v > k, no view for k_v == 0
     const uint32_t o; // number of outputs
 
     std::string vars() override {
-        return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, v, o);
+        return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, k_v, o);
     }
 
     double max_nmse_err() override {
@@ -3410,8 +3410,8 @@ struct test_mul_mat : public test_case {
             std::array<int64_t, 2> bs = {10, 10},
             std::array<int64_t, 2> nr = {2, 2},
             std::array<int64_t, 4> per = {0, 1, 2, 3},
-            bool v = false, uint32_t o = 1)
-        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v), o(o) {}
+            int64_t k_v = 0, uint32_t o = 1)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), k_v(k_v), o(o) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -3421,7 +3421,7 @@ struct test_mul_mat : public test_case {
         const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
         if (npermuted > 0) {
             GGML_ASSERT(npermuted == 2);
-            GGML_ASSERT(!v); // not handled
+            GGML_ASSERT(k_v == 0); // not handled
             GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
             GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
 
@@ -3445,29 +3445,21 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-            if (v) {
-                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
-                b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
+            const int64_t k_physical = k_v == 0 ? k : k_v;
+            a = ggml_new_tensor_4d(ctx, type_a, k_physical, m, bs[0], bs[1]);
+            b = ggml_new_tensor_4d(ctx, type_b, k_physical, n, bs[0]*nr[0], bs[1]*nr[1]);
 
-                if (!ggml_is_quantized(type_a)) {
-                    if (bs[1] == 1 && nr[1] == 1) {
-                        ggml_set_param(a);
-                    }
-                    ggml_set_param(b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(a);
                 }
+                ggml_set_param(b);
+            }
 
+            if (k_v != 0) {
+                GGML_ASSERT(k_v > k);
                 a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
                 b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
-            } else {
-                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
-                b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-
-                if (!ggml_is_quantized(type_a)) {
-                    if (bs[1] == 1 && nr[1] == 1) {
-                        ggml_set_param(a);
-                    }
-                    ggml_set_param(b);
-                }
             }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -6901,7 +6893,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, 64, 3));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
 
 #if 0
@@ -6927,7 +6919,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (uint32_t k = 0; k < 2; ++k) {
                         for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
                             test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, bs2}, {nr, 1}, {0, 2, 1, 3}));
-                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, bs2}, {nr, 1}, {0, 1, 2, 3}, true));
+                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, bs2}, {nr, 1}, {0, 1, 2, 3}, 2*1056 + k));
                         }
                     }
                 }
             }
@@ -7432,7 +7424,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
 
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
 
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
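
What the new k_v parameter exercises: the tensor is allocated with rows of k_physical = k_v elements, and ggml_view_4d then exposes only the first k elements of each row while keeping the parent's strides, so the operand passed to the matrix multiplication is non-contiguous. Below is a minimal standalone sketch of that effect, not part of the patch, assuming only the public ggml API; the file name and the identifier a_full are illustrative.

// contiguity_sketch.cpp: illustrative only, not from the diff above.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t k = 32, k_v = 64, m = 16; // logical vs. physical row length

    // Allocate rows of k_v elements, then expose only the first k of each row,
    // mirroring what test_mul_mat now does when k_v > k.
    struct ggml_tensor * a_full = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, k_v, m, 1, 1);
    struct ggml_tensor * a      = ggml_view_4d(ctx, a_full, k, m, 1, 1,
                                               a_full->nb[1], a_full->nb[2], a_full->nb[3], 0);

    // The view keeps the parent's row stride (k_v floats), so it is not contiguous.
    printf("contiguous: %d, nb[1] = %zu bytes (k*sizeof(float) = %zu)\n",
           (int) ggml_is_contiguous(a), a->nb[1], (size_t) k*sizeof(float));

    ggml_free(ctx);
    return 0;
}

Compared with the old boolean v, which hard-coded a physical row length of k*2, the int64_t k_v lets individual test cases choose the physical stride directly (64, 2*1056 + k, and 2*16416 in the cases above), with k_v == 0 keeping the fully contiguous path.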