Skip to content

Commit e38eb44

Browse files
committed
Add SOFT_MAX_BACK metal kernel.
Signed-off-by: Marcus Edel <[email protected]>
1 parent e8a84f6 commit e38eb44

File tree

3 files changed

+234
-8
lines changed

3 files changed

+234
-8
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ typedef struct {
256256
int32_t sect_1;
257257
int32_t sect_2;
258258
int32_t sect_3;
259+
float sin_sign;
259260
} ggml_metal_kargs_rope;
260261

261262
typedef struct {
@@ -589,6 +590,21 @@ typedef struct {
589590
int32_t n_head_log2;
590591
} ggml_metal_kargs_soft_max;
591592

593+
typedef struct {
594+
int32_t ne00;
595+
int32_t ne00_4;
596+
uint64_t nb01;
597+
uint64_t nb02;
598+
uint64_t nb03;
599+
uint64_t nb11;
600+
uint64_t nb12;
601+
uint64_t nb13;
602+
uint64_t nb1;
603+
uint64_t nb2;
604+
uint64_t nb3;
605+
float scale;
606+
} ggml_metal_kargs_soft_max_back;
607+
592608
typedef struct {
593609
int64_t ne00;
594610
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ - (void) dealloc {
250250
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
251251
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
252252
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
253+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK,
254+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4,
253255
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
254256
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
255257
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -1183,6 +1185,8 @@ @implementation GGMLMetalClass
11831185
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
11841186
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
11851187
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
1188+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK, soft_max_back, has_simdgroup_reduction);
1189+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4, soft_max_back_4, has_simdgroup_reduction);
11861190
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
11871191
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
11881192
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
@@ -1935,6 +1939,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19351939
case GGML_OP_SOFT_MAX:
19361940
case GGML_OP_GROUP_NORM:
19371941
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
1942+
case GGML_OP_SOFT_MAX_BACK:
1943+
if (!has_simdgroup_reduction ||
1944+
op->type != GGML_TYPE_F32 ||
1945+
op->src[0] == NULL || op->src[1] == NULL ||
1946+
op->src[0]->type != GGML_TYPE_F32 ||
1947+
op->src[1]->type != GGML_TYPE_F32 ||
1948+
!ggml_is_contiguous_1(op->src[0]) ||
1949+
!ggml_is_contiguous_1(op->src[1]) ||
1950+
!ggml_is_contiguous_1(op) ||
1951+
!ggml_are_same_shape(op, op->src[0]) ||
1952+
!ggml_are_same_shape(op, op->src[1])) {
1953+
return false;
1954+
}
1955+
1956+
float max_bias = 0.0f;
1957+
memcpy(&max_bias, ((const float *) op->op_params) + 1, sizeof(float));
1958+
if (max_bias != 0.0f) {
1959+
return false;
1960+
}
1961+
1962+
return true;
19381963
case GGML_OP_RMS_NORM:
19391964
case GGML_OP_L2_NORM:
19401965
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -1955,6 +1980,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19551980
case GGML_OP_NORM:
19561981
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
19571982
case GGML_OP_ROPE:
1983+
case GGML_OP_ROPE_BACK:
19581984
return true;
19591985
case GGML_OP_IM2COL:
19601986
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
@@ -3295,6 +3321,76 @@ static int ggml_metal_encode_node(
32953321

32963322
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
32973323

3324+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3325+
} break;
3326+
case GGML_OP_SOFT_MAX_BACK:
3327+
{
3328+
GGML_ASSERT(src0 != NULL);
3329+
GGML_ASSERT(src1 != NULL);
3330+
GGML_ASSERT(dstt == GGML_TYPE_F32);
3331+
GGML_ASSERT(src0t == GGML_TYPE_F32);
3332+
GGML_ASSERT(src1t == GGML_TYPE_F32);
3333+
GGML_ASSERT(ggml_are_same_shape(dst, src0));
3334+
GGML_ASSERT(ggml_are_same_shape(dst, src1));
3335+
GGML_ASSERT(ggml_is_contiguous_1(src0));
3336+
GGML_ASSERT(ggml_is_contiguous_1(src1));
3337+
GGML_ASSERT(ggml_is_contiguous_1(dst));
3338+
3339+
float scale = 1.0f;
3340+
float max_bias = 0.0f;
3341+
3342+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
3343+
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
3344+
3345+
GGML_ASSERT(max_bias == 0.0f);
3346+
3347+
const bool use_vec4 = (ne00 % 4) == 0;
3348+
3349+
id<MTLComputePipelineState> pipeline = use_vec4 ?
3350+
ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4].pipeline :
3351+
ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK ].pipeline;
3352+
3353+
int nth = 32; // SIMD width
3354+
3355+
if (use_vec4) {
3356+
const int ne00_4 = ne00/4;
3357+
while (nth < ne00_4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3358+
nth *= 2;
3359+
}
3360+
nth = MIN(nth, ne00_4);
3361+
} else {
3362+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3363+
nth *= 2;
3364+
}
3365+
nth = MIN(nth, ne00);
3366+
}
3367+
3368+
nth = MAX(1, nth);
3369+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3370+
3371+
ggml_metal_kargs_soft_max_back args = {
3372+
/*.ne00 =*/ ne00,
3373+
/*.ne00_4 =*/ ne00/4,
3374+
/*.nb01 =*/ nb01,
3375+
/*.nb02 =*/ nb02,
3376+
/*.nb03 =*/ nb03,
3377+
/*.nb11 =*/ nb11,
3378+
/*.nb12 =*/ nb12,
3379+
/*.nb13 =*/ nb13,
3380+
/*.nb1 =*/ nb1,
3381+
/*.nb2 =*/ nb2,
3382+
/*.nb3 =*/ nb3,
3383+
/*.scale =*/ scale,
3384+
};
3385+
3386+
[encoder setComputePipelineState:pipeline];
3387+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3388+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3389+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3390+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3391+
3392+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3393+
32983394
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
32993395
} break;
33003396
case GGML_OP_DIAG_MASK_INF:
@@ -4854,7 +4950,9 @@ static int ggml_metal_encode_node(
48544950
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
48554951
} break;
48564952
case GGML_OP_ROPE:
4953+
case GGML_OP_ROPE_BACK:
48574954
{
4955+
const bool is_backward = dst->op == GGML_OP_ROPE_BACK;
48584956

48594957
// make sure we have one or more position id(ne10) per token(ne02)
48604958
GGML_ASSERT(ne10 % ne02 == 0);
@@ -4892,6 +4990,8 @@ static int ggml_metal_encode_node(
48924990
const int sect_2 = ((const int32_t *) dst->op_params)[13];
48934991
const int sect_3 = ((const int32_t *) dst->op_params)[14];
48944992

4993+
const float sin_sign = is_backward ? -1.0f : 1.0f;
4994+
48954995
id<MTLComputePipelineState> pipeline = nil;
48964996

48974997
if (is_neox) {
@@ -4952,6 +5052,7 @@ static int ggml_metal_encode_node(
49525052
/* sect_1 =*/ sect_1,
49535053
/* sect_2 =*/ sect_2,
49545054
/* sect_3 =*/ sect_3,
5055+
/* sin_sign =*/ sin_sign,
49555056
};
49565057

49575058
[encoder setComputePipelineState:pipeline];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,6 +2049,107 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne
20492049
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
20502050
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
20512051

2052+
[[host_name("kernel_soft_max_back")]]
2053+
kernel void kernel_soft_max_back(
2054+
constant ggml_metal_kargs_soft_max_back & args,
2055+
device const char * src0,
2056+
device const char * src1,
2057+
device char * dst,
2058+
threadgroup float * shmem_f32 [[threadgroup(0)]],
2059+
uint3 tgpig[[threadgroup_position_in_grid]],
2060+
ushort3 tpitg[[thread_position_in_threadgroup]],
2061+
ushort sgitg[[simdgroup_index_in_threadgroup]],
2062+
ushort tiisg[[thread_index_in_simdgroup]],
2063+
ushort3 ntg[[threads_per_threadgroup]]) {
2064+
if (sgitg == 0) {
2065+
shmem_f32[tiisg] = 0.0f;
2066+
}
2067+
2068+
const int i01 = tgpig.x;
2069+
const int i02 = tgpig.y;
2070+
const int i03 = tgpig.z;
2071+
2072+
device const float * dy = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2073+
device const float * y = (device const float *) (src1 + i03*args.nb13 + i02*args.nb12 + i01*args.nb11);
2074+
device float * dx = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2075+
2076+
float sum = 0.0f;
2077+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2078+
sum += dy[i00] * y[i00];
2079+
}
2080+
2081+
sum = simd_sum(sum);
2082+
2083+
threadgroup_barrier(mem_flags::mem_threadgroup);
2084+
2085+
if (tiisg == 0) {
2086+
shmem_f32[sgitg] = sum;
2087+
}
2088+
2089+
threadgroup_barrier(mem_flags::mem_threadgroup);
2090+
2091+
sum = shmem_f32[tiisg];
2092+
sum = simd_sum(sum);
2093+
2094+
const float scale = args.scale;
2095+
2096+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2097+
dx[i00] = (dy[i00] - sum) * y[i00] * scale;
2098+
}
2099+
}
2100+
2101+
[[host_name("kernel_soft_max_back_4")]]
2102+
kernel void kernel_soft_max_back_4(
2103+
constant ggml_metal_kargs_soft_max_back & args,
2104+
device const char * src0,
2105+
device const char * src1,
2106+
device char * dst,
2107+
threadgroup float * shmem_f32 [[threadgroup(0)]],
2108+
uint3 tgpig[[threadgroup_position_in_grid]],
2109+
ushort3 tpitg[[thread_position_in_threadgroup]],
2110+
ushort sgitg[[simdgroup_index_in_threadgroup]],
2111+
ushort tiisg[[thread_index_in_simdgroup]],
2112+
ushort3 ntg[[threads_per_threadgroup]]) {
2113+
if (sgitg == 0) {
2114+
shmem_f32[tiisg] = 0.0f;
2115+
}
2116+
2117+
const int i01 = tgpig.x;
2118+
const int i02 = tgpig.y;
2119+
const int i03 = tgpig.z;
2120+
2121+
device const float4 * dy = (device const float4 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2122+
device const float4 * y = (device const float4 *) (src1 + i03*args.nb13 + i02*args.nb12 + i01*args.nb11);
2123+
device float4 * dx = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2124+
2125+
float sum = 0.0f;
2126+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
2127+
sum += dot(dy[i00], y[i00]);
2128+
}
2129+
2130+
sum = simd_sum(sum);
2131+
2132+
threadgroup_barrier(mem_flags::mem_threadgroup);
2133+
2134+
if (tiisg == 0) {
2135+
shmem_f32[sgitg] = sum;
2136+
}
2137+
2138+
threadgroup_barrier(mem_flags::mem_threadgroup);
2139+
2140+
sum = shmem_f32[tiisg];
2141+
sum = simd_sum(sum);
2142+
2143+
const float scale = args.scale;
2144+
const float4 sum4 = float4(sum);
2145+
2146+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
2147+
const float4 dy4 = dy[i00];
2148+
const float4 y4 = y[i00];
2149+
dx[i00] = (dy4 - sum4) * y4 * scale;
2150+
}
2151+
}
2152+
20522153
kernel void kernel_diag_mask_inf(
20532154
device const float * src0,
20542155
device float * dst,
@@ -3908,8 +4009,10 @@ kernel void kernel_rope_norm(
39084009
const float x0 = src[0];
39094010
const float x1 = src[1];
39104011

3911-
dst_data[0] = x0*cos_theta - x1*sin_theta;
3912-
dst_data[1] = x0*sin_theta + x1*cos_theta;
4012+
const float sin_theta_mod = sin_theta * args.sin_sign;
4013+
4014+
dst_data[0] = x0*cos_theta - x1*sin_theta_mod;
4015+
dst_data[1] = x0*sin_theta_mod + x1*cos_theta;
39134016
} else {
39144017
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
39154018
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
@@ -3961,8 +4064,10 @@ kernel void kernel_rope_neox(
39614064
const float x0 = src[0];
39624065
const float x1 = src[args.n_dims/2];
39634066

3964-
dst_data[0] = x0*cos_theta - x1*sin_theta;
3965-
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
4067+
const float sin_theta_mod = sin_theta * args.sin_sign;
4068+
4069+
dst_data[0] = x0*cos_theta - x1*sin_theta_mod;
4070+
dst_data[args.n_dims/2] = x0*sin_theta_mod + x1*cos_theta;
39664071
} else {
39674072
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
39684073
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
@@ -4032,8 +4137,10 @@ kernel void kernel_rope_multi(
40324137
const float x0 = src[0];
40334138
const float x1 = src[args.n_dims/2];
40344139

4035-
dst_data[0] = x0*cos_theta - x1*sin_theta;
4036-
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
4140+
const float sin_theta_mod = sin_theta * args.sin_sign;
4141+
4142+
dst_data[0] = x0*cos_theta - x1*sin_theta_mod;
4143+
dst_data[args.n_dims/2] = x0*sin_theta_mod + x1*cos_theta;
40374144
} else {
40384145
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
40394146
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
@@ -4099,8 +4206,10 @@ kernel void kernel_rope_vision(
40994206
const float x0 = src[0];
41004207
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
41014208

4102-
dst_data[0] = x0*cos_theta - x1*sin_theta;
4103-
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
4209+
const float sin_theta_mod = sin_theta * args.sin_sign;
4210+
4211+
dst_data[0] = x0*cos_theta - x1*sin_theta_mod;
4212+
dst_data[args.n_dims] = x0*sin_theta_mod + x1*cos_theta; // different from kernel_rope_multi
41044213
} else {
41054214
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
41064215
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

0 commit comments

Comments
 (0)