@@ -9580,16 +9580,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
-// off1 = offset in i11 and i1
-// cne1 = ne11 and ne1
-// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
-// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst,
-              int64_t off1, int64_t cne1) {
+              struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -9657,9 +9652,9 @@ static void ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
-                const void  * x = (char *)            src0->data +             i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
-                      float * d = (float *) ((char *)  dst->data + off1*nb1  + i12*nb2  + i13*nb3);
+                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                      float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                     float * const wdata    = params->wdata;
@@ -9676,7 +9671,7 @@ static void ggml_compute_forward_mul_mat(
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        cne1, ne01, ne10,
+                        ne1, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
@@ -9717,8 +9712,8 @@ static void ggml_compute_forward_mul_mat(
     const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-    const int64_t nr0 = ne01;           // src0 rows
-    const int64_t nr1 = cne1*ne12*ne13; // src1 rows
+    const int64_t nr0 = ne01;          // src0 rows
+    const int64_t nr1 = ne1*ne12*ne13; // src1 rows
 
     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
@@ -9760,9 +9755,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*cne1));
-                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
-                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
+                const int64_t i13 = (ir1/(ne12*ne1));
+                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
+                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
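+                // ir1 is a flat row index over (i13, i12, i11) in row-major order; the
+                // divisions and remainders above recover the individual indices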
 
                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9802,28 +9797,191 @@ static void ggml_compute_forward_mul_mat(
 
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
+        const struct ggml_tensor * ids,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
-        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
-        return;
-    }
+    const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
+    enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
+    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
 
-    const struct ggml_tensor * ids = src0;
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
     const int n_as = ggml_get_op_params_i32(dst, 1);
 
-    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+    char * wdata_src1_end = (src1->type == vec_dot_type) ?
+            (char *) params->wdata :
+            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    int64_t * matrix_rows       = matrix_row_counts + n_as;     // [n_as][ne11]
+
+    #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
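+
+    // wdata layout: [src1 converted to vec_dot_type (only when src1->type != vec_dot_type),
+    //               padded to sizeof(int64_t)][matrix_row_counts : n_as][matrix_rows : n_as*ne11]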
 
-        GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    if (params->type == GGML_TASK_INIT) {
+        char * wdata = params->wdata;
+        if (src1->type != vec_dot_type) {
+            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+            assert(params->wsize >= ne11*ne12*ne13*row_size);
+            assert(src1->type == GGML_TYPE_F32);
+
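+            // convert every f32 row of src1 to vec_dot_type, packed contiguously into wdata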
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
+                }
+            }
+        }
+
+        // initialize matrix_row_counts
+        GGML_ASSERT(wdata == wdata_src1_end);
+        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+        // group rows by src0 matrix
+        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+            const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+
+            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+            MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
+            matrix_row_counts[row_id] += 1;
+        }
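+
+        // illustrative example: with n_as = 2 and selected row ids [1, 0, 1], this yields
+        // matrix_row_counts = {1, 2}, matrix_rows[0] = {1}, matrix_rows[1] = {0, 2}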
+
+        return;
     }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // compute each matrix multiplication in sequence
+    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+
+        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        const int64_t nr0 = ne01;           // src0 rows
+        const int64_t nr1 = cne1*ne12*ne13; // src1 rows
+
+        //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+
+        // distribute the thread work across the inner or outer loop based on which one is larger
+
+        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+        const int64_t ith0 = ith % nth0;
+        const int64_t ith1 = ith / nth0;
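+        // the nth threads form an nth0 x nth1 grid; (ith0, ith1) are this thread's coordinates in it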
+
+        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+        const int64_t ir010 = dr0*ith0;
+        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+        const int64_t ir110 = dr1*ith1;
+        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+        //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
+
+        // threads with no work simply yield (not sure if it helps)
+        if (ir010 >= ir011 || ir110 >= ir111) {
+            sched_yield();
+            continue;
+        }
+
+        assert(ne12 % ne02 == 0);
+        assert(ne13 % ne03 == 0);
+
+        // block-tiling attempt
+        const int64_t blck_0 = 16;
+        const int64_t blck_1 = 16;
+
+        // attempt to reduce false-sharing (does not seem to make a difference)
+        float tmp[16];
+
+        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+                    const int64_t  i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
+                    const int64_t  i12 = (ir1 - i13*ne12*cne1)/cne1;
+                    const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
+                    const int64_t  i11 = MMID_MATRIX_ROW(cur_a, _i11);
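+                    // _i11 indexes into this matrix's row group; MMID_MATRIX_ROW maps it back
+                    // to the original row index in src1/dst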
+
+                    // broadcast src0 into src1
+                    const int64_t i03 = i13/r3;
+                    const int64_t i02 = i12/r2;
+
+                    const int64_t i1 = i11;
+                    const int64_t i2 = i12;
+                    const int64_t i3 = i13;
+
+                    const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+
+                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                    //       the original src1 data pointer, so we should index using the indices directly
+                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                    const char * src1_col = (const char *) wdata +
+                        (src1_cont || src1->type != vec_dot_type
+                        ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12 + i13*nb13));
+
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                    //}
+
+                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                        vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
+                    }
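+                    // write back only the rows of the block that were actually computed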
+                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
+            }
+        }
+    }
+
+    #undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
@@ -14191,7 +14349,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
+                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -15991,7 +16149,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                // FIXME: blas
                 n_tasks = n_threads;
             } break;
         case GGML_OP_OUT_PROD:
@@ -16325,20 +16482,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_MUL_MAT_ID:
                 {
-                    const struct ggml_tensor * a = node->src[2];
-                    const struct ggml_tensor * b = node->src[1];
-                    const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                    if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
-                        if (a->type != GGML_TYPE_F32) {
-                            // here we need memory just for single 2D matrix from src0
-                            cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
-                        }
-                    } else
-#endif
-                    if (b->type != vec_dot_type) {
-                        cur = ggml_row_size(vec_dot_type, ggml_nelements(b));
+                    const struct ggml_tensor * src0 = node->src[2];
+                    const struct ggml_tensor * src1 = node->src[1];
+                    const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
+                    if (src1->type != vec_dot_type) {
+                        cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
                     }
+                    const int n_as = ggml_get_op_params_i32(node, 1);
+                    cur = GGML_PAD(cur, sizeof(int64_t));        // align
+                    cur += n_as * sizeof(int64_t);               // matrix_row_counts
+                    cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
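+                    // note: this must mirror the wdata layout consumed by ggml_compute_forward_mul_mat_id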
                 } break;
             case GGML_OP_OUT_PROD:
                 {