@@ -4498,6 +4498,107 @@ static void ggml_compute_forward_out_prod_f32(
     }
 }
 
+static void ggml_compute_forward_out_prod_f16_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
+
+                // const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    ggml_fp16_t * s0 = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+                    float       * s1 = (float *)       ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float       * d  = (float *)       ((char *)  dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+                    for (int i = 0; i < ne0; ++i) {
+                        d[i] += GGML_CPU_FP16_TO_FP32(s0[i])*(*s1);
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_out_prod_q_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
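
For reference, the arithmetic the new kernel performs is the sum-of-outer-products formulation of dst = src0 * src1^T, exactly as the commented pseudocode in the patch describes. Below is a minimal sketch of the same computation for a single 2-D slice, assuming plain contiguous float arrays with no F16 conversion, threading, strides, or blocking; out_prod_ref is a hypothetical helper written for illustration, not part of ggml:

// dst[i0 + i1*ne00] = sum_k src0[i0 + k*ne00] * src1[i1 + k*ne10]
// src0: ne00 x ne01, src1: ne10 x ne01 (shared dim k = ne01 = ne11), dst: ne00 x ne10
#include <stdint.h>
static void out_prod_ref(int64_t ne00, int64_t ne10, int64_t ne01,
                         const float * src0, const float * src1, float * dst) {
    for (int64_t i = 0; i < ne00*ne10; ++i) {
        dst[i] = 0.0f; // dst[:,:] = 0, the role of ggml_vec_set_f32() in the kernel
    }
    for (int64_t k = 0; k < ne01; ++k) {            // shared dimension (i01 = i11)
        for (int64_t i1 = 0; i1 < ne10; ++i1) {     // one dst row (i1)
            const float s1 = src1[i1 + k*ne10];     // the scalar *s1 in the kernel
            for (int64_t i0 = 0; i0 < ne00; ++i0) { // the d[i] += ... inner loop
                dst[i0 + i1*ne00] += src0[i0 + k*ne00] * s1;
            }
        }
    }
}

The blck_0/blck_1 tiling in the patch only reorders these loops for cache locality, and the real kernel converts each src0 element from F16 via GGML_CPU_FP16_TO_FP32; the accumulation itself is identical.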
@@ -4620,9 +4721,8 @@ void ggml_compute_forward_out_prod(
             } break;
         case GGML_TYPE_F16:
             {
-                GGML_ABORT("fatal error"); // todo
-                // ggml_compute_forward_out_prod_f16_f32(params, dst);
-            }
+                ggml_compute_forward_out_prod_f16_f32(params, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_out_prod_f32(params, dst);
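
With the F16 case now dispatching instead of aborting, a minimal smoke test of the path could look as follows. This is a sketch against the public ggml C API (ggml_init, ggml_new_tensor_2d, ggml_out_prod, ggml_graph_compute_with_ctx all exist, but header names and init fields vary between ggml versions, and tensor data is left uninitialized here since only the dispatch is being exercised):

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // src0 in F16, src1 in F32: the combination that previously hit GGML_ABORT
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 3); // [4, 3]
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 5, 3); // [5, 3]

    struct ggml_tensor * out = ggml_out_prod(ctx, a, b); // F32, shape [4, 5]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4); // runs the new f16_f32 kernel

    ggml_free(ctx);
    return 0;
}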