@@ -250,6 +250,8 @@ - (void) dealloc {
250250 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
251251 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
252252 GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
253+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK,
254+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4,
253255 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
254256 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
255257 GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -1183,6 +1185,8 @@ @implementation GGMLMetalClass
11831185 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
11841186 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
11851187 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
1188+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK, soft_max_back, has_simdgroup_reduction);
1189+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4, soft_max_back_4, has_simdgroup_reduction);
11861190 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
11871191 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
11881192 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
@@ -1935,6 +1939,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19351939 case GGML_OP_SOFT_MAX:
19361940 case GGML_OP_GROUP_NORM:
19371941 return has_simdgroup_reduction && ggml_is_contiguous_rows (op->src [0 ]);
1942+ case GGML_OP_SOFT_MAX_BACK:
1943+ if (!has_simdgroup_reduction ||
1944+ op->type != GGML_TYPE_F32 ||
1945+ op->src [0 ] == NULL || op->src [1 ] == NULL ||
1946+ op->src [0 ]->type != GGML_TYPE_F32 ||
1947+ op->src [1 ]->type != GGML_TYPE_F32 ||
1948+ !ggml_is_contiguous_1 (op->src [0 ]) ||
1949+ !ggml_is_contiguous_1 (op->src [1 ]) ||
1950+ !ggml_is_contiguous_1 (op) ||
1951+ !ggml_are_same_shape (op, op->src [0 ]) ||
1952+ !ggml_are_same_shape (op, op->src [1 ])) {
1953+ return false ;
1954+ }
1955+
1956+ float max_bias = 0 .0f ;
1957+ memcpy (&max_bias, ((const float *) op->op_params ) + 1 , sizeof (float ));
1958+ if (max_bias != 0 .0f ) {
1959+ return false ;
1960+ }
1961+
1962+ return true ;
19381963 case GGML_OP_RMS_NORM:
19391964 case GGML_OP_L2_NORM:
19401965 return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
@@ -1955,6 +1980,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19551980 case GGML_OP_NORM:
19561981 return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
19571982 case GGML_OP_ROPE:
1983+ case GGML_OP_ROPE_BACK:
19581984 return true ;
19591985 case GGML_OP_IM2COL:
19601986 return ggml_is_contiguous (op->src [1 ]) && op->src [1 ]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
@@ -3295,6 +3321,76 @@ static int ggml_metal_encode_node(
32953321
32963322 [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
32973323
3324+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
3325+ } break ;
3326+ case GGML_OP_SOFT_MAX_BACK:
3327+ {
3328+ GGML_ASSERT (src0 != NULL );
3329+ GGML_ASSERT (src1 != NULL );
3330+ GGML_ASSERT (dstt == GGML_TYPE_F32);
3331+ GGML_ASSERT (src0t == GGML_TYPE_F32);
3332+ GGML_ASSERT (src1t == GGML_TYPE_F32);
3333+ GGML_ASSERT (ggml_are_same_shape (dst, src0));
3334+ GGML_ASSERT (ggml_are_same_shape (dst, src1));
3335+ GGML_ASSERT (ggml_is_contiguous_1 (src0));
3336+ GGML_ASSERT (ggml_is_contiguous_1 (src1));
3337+ GGML_ASSERT (ggml_is_contiguous_1 (dst));
3338+
3339+ float scale = 1 .0f ;
3340+ float max_bias = 0 .0f ;
3341+
3342+ memcpy (&scale, ((const int32_t *) dst->op_params ) + 0 , sizeof (scale));
3343+ memcpy (&max_bias, ((const int32_t *) dst->op_params ) + 1 , sizeof (max_bias));
3344+
3345+ GGML_ASSERT (max_bias == 0 .0f );
3346+
3347+ const bool use_vec4 = (ne00 % 4 ) == 0 ;
3348+
3349+ id <MTLComputePipelineState > pipeline = use_vec4 ?
3350+ ctx->kernels [GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK_4].pipeline :
3351+ ctx->kernels [GGML_METAL_KERNEL_TYPE_SOFT_MAX_BACK ].pipeline ;
3352+
3353+ int nth = 32 ; // SIMD width
3354+
3355+ if (use_vec4) {
3356+ const int ne00_4 = ne00/4 ;
3357+ while (nth < ne00_4 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3358+ nth *= 2 ;
3359+ }
3360+ nth = MIN (nth, ne00_4);
3361+ } else {
3362+ while (nth < ne00 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3363+ nth *= 2 ;
3364+ }
3365+ nth = MIN (nth, ne00);
3366+ }
3367+
3368+ nth = MAX (1 , nth);
3369+ nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
3370+
3371+ ggml_metal_kargs_soft_max_back args = {
3372+ /* .ne00 =*/ ne00,
3373+ /* .ne00_4 =*/ ne00/4 ,
3374+ /* .nb01 =*/ nb01,
3375+ /* .nb02 =*/ nb02,
3376+ /* .nb03 =*/ nb03,
3377+ /* .nb11 =*/ nb11,
3378+ /* .nb12 =*/ nb12,
3379+ /* .nb13 =*/ nb13,
3380+ /* .nb1 =*/ nb1,
3381+ /* .nb2 =*/ nb2,
3382+ /* .nb3 =*/ nb3,
3383+ /* .scale =*/ scale,
3384+ };
3385+
3386+ [encoder setComputePipelineState: pipeline];
3387+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3388+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3389+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3390+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3391+
3392+ [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
3393+
32983394 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
32993395 } break ;
33003396 case GGML_OP_DIAG_MASK_INF:
@@ -4854,7 +4950,9 @@ static int ggml_metal_encode_node(
48544950 [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
48554951 } break ;
48564952 case GGML_OP_ROPE:
4953+ case GGML_OP_ROPE_BACK:
48574954 {
4955+ const bool is_backward = dst->op == GGML_OP_ROPE_BACK;
48584956
48594957 // make sure we have one or more position id(ne10) per token(ne02)
48604958 GGML_ASSERT (ne10 % ne02 == 0 );
@@ -4892,6 +4990,8 @@ static int ggml_metal_encode_node(
48924990 const int sect_2 = ((const int32_t *) dst->op_params )[13 ];
48934991 const int sect_3 = ((const int32_t *) dst->op_params )[14 ];
48944992
4993+ const float sin_sign = is_backward ? -1 .0f : 1 .0f ;
4994+
48954995 id <MTLComputePipelineState > pipeline = nil ;
48964996
48974997 if (is_neox) {
@@ -4952,6 +5052,7 @@ static int ggml_metal_encode_node(
49525052 /* sect_1 =*/ sect_1,
49535053 /* sect_2 =*/ sect_2,
49545054 /* sect_3 =*/ sect_3,
5055+ /* sin_sign =*/ sin_sign,
49555056 };
49565057
49575058 [encoder setComputePipelineState: pipeline];
0 commit comments