
Commit fb0e501

Italo Nicola authored and makaveli10 committed
CPU: add support for fp16_fp32 OUT_PROD op
1 parent e9f5d88 commit fb0e501

2 files changed: 104 additions, 4 deletions


ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_GET_ROWS_BACK:
             return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+            return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
                 src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
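
With this check relaxed, the CPU backend now reports OUT_PROD as supported when src0 is fp16 and both src1 and dst are fp32. The following is a minimal usage sketch, not part of the commit, showing how such an op could be built through the public ggml C API; the shapes, the 16 MB context size, and the thread count are illustrative assumptions, and it assumes a build where ggml_graph_compute_with_ctx is exposed via ggml-cpu.h.

// Hypothetical usage sketch (not part of the commit): exercise the new
// fp16 x fp32 OUT_PROD path on the CPU backend via the public ggml C API.
// Shapes are illustrative; error handling is omitted for brevity.
#include "ggml.h"
#include "ggml-cpu.h"
#include <string.h>
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // src0 is fp16, src1 is fp32 -- the combination enabled by this commit.
    // out_prod requires the second dimensions to match (the shared dim, 4 here).
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8, 4);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 6, 4);
    memset(a->data, 0, ggml_nbytes(a));  // fill with real data in practice
    memset(b->data, 0, ggml_nbytes(b));

    // dst is fp32 with shape [8, 6]: dst[i0,i1] = sum_i01 a[i0,i01] * b[i1,i01]
    struct ggml_tensor * dst = ggml_out_prod(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 2);

    printf("dst: %lld x %lld, type %s\n",
           (long long) dst->ne[0], (long long) dst->ne[1], ggml_type_name(dst->type));

    ggml_free(ctx);
    return 0;
}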

ggml/src/ggml-cpu/ops.cpp

Lines changed: 103 additions & 3 deletions
@@ -4498,6 +4498,107 @@ static void ggml_compute_forward_out_prod_f32(
     }
 }
 
+static void ggml_compute_forward_out_prod_f16_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    ggml_fp16_t * s0 = (ggml_fp16_t *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float       * s1 = (float *)       ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float       * d  = (float *)       ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+                    for (int i = 0; i < ne0; ++i) {
+                        d[i] += GGML_CPU_FP16_TO_FP32(s0[i])*(*s1);
+                    }
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_out_prod_q_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
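
For orientation, the kernel added above is a blocked, multi-threaded form of a plain outer-product accumulation in which every fp16 element of src0 is widened to fp32 (GGML_CPU_FP16_TO_FP32) before the multiply-add. The sketch below is a hypothetical, simplified scalar reference of the same update, written here for 2-D contiguous tensors with no threading or block tiling; it is not the committed code, and the helper name out_prod_f16_f32_ref is made up for illustration.

// Simplified reference of the fp16 x fp32 OUT_PROD update (assumptions:
// 2-D contiguous tensors, dst already zeroed by the caller, no threading,
// no block tiling). Not the committed kernel, just its math.
#include "ggml.h"
#include <stdint.h>

// src0: ne00 x ne01 rows of fp16, src1: ne10 x ne01 rows of fp32,
// dst : ne00 x ne10 fp32.
void out_prod_f16_f32_ref(int64_t ne00, int64_t ne01, int64_t ne10,
                          const ggml_fp16_t * src0, const float * src1, float * dst) {
    for (int64_t i01 = 0; i01 < ne01; ++i01) {           // shared (contracted) dimension
        for (int64_t i1 = 0; i1 < ne10; ++i1) {          // rows of src1 -> dst dim 1
            const float s1 = src1[i01*ne10 + i1];
            for (int64_t i0 = 0; i0 < ne00; ++i0) {      // rows of src0 -> dst dim 0
                // dst[i0,i1] += fp32(src0[i0,i01]) * src1[i1,i01]
                dst[i1*ne00 + i0] += ggml_fp16_to_fp32(src0[i01*ne00 + i0]) * s1;
            }
        }
    }
}
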
@@ -4620,9 +4721,8 @@ void ggml_compute_forward_out_prod(
             } break;
         case GGML_TYPE_F16:
             {
-                GGML_ABORT("fatal error"); // todo
-                // ggml_compute_forward_out_prod_f16_f32(params, dst);
-            }
+                ggml_compute_forward_out_prod_f16_f32(params, dst);
+            } break;
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_out_prod_f32(params, dst);
