@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
10901090 return res;
10911091}
10921092
1093- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1094- assert (op->op == GGML_OP_RMS_NORM);
1095-
1096- GGML_ASSERT (op->src [0 ]->ne [0 ] % 4 == 0 );
1097- GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
1098-
1099- char base[256 ];
1100- char name[256 ];
1101-
1102- switch (n_fuse) {
1103- case 1 : snprintf (base, 256 , " kernel_rms_norm_f32" ); break ;
1104- case 2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32" ); break ;
1105- case 3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32" ); break ;
1106- default : GGML_ABORT (" fatal error" );
1107- }
1108-
1109- snprintf (name, 256 , " %s" , base);
1110-
1111- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
1112- if (res) {
1113- return res;
1114- }
1115-
1116- res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
1117-
1118- ggml_metal_pipeline_set_smem (res, 32 *sizeof (float ));
1119-
1120- return res;
1121- }
1122-
11231093ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const ggml_tensor * op) {
11241094 assert (op->op == GGML_OP_L2_NORM);
11251095
@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
11671137 return res;
11681138}
11691139
1170- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const ggml_tensor * op) {
1171- assert (op->op == GGML_OP_NORM);
1140+ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse ) {
1141+ assert (op->op == GGML_OP_NORM || op-> op == GGML_OP_RMS_NORM );
11721142
1173- GGML_ASSERT (op->src [0 ]->ne [0 ] % 4 == 0 );
1174- GGML_ASSERT (ggml_is_contiguous_1 (op->src [0 ]));
1143+ GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
11751144
11761145 char base[256 ];
11771146 char name[256 ];
11781147
1179- snprintf (base, 256 , " kernel_norm_f32" );
1148+ const char * suffix = " " ;
1149+ if (op->ne [0 ] % 4 == 0 ) {
1150+ suffix = " _4" ;
1151+ }
1152+
1153+ switch (op->op ) {
1154+ case GGML_OP_NORM:
1155+ switch (n_fuse) {
1156+ case 1 : snprintf (base, 256 , " kernel_norm_f32%s" , suffix); break ;
1157+ case 2 : snprintf (base, 256 , " kernel_norm_mul_f32%s" , suffix); break ;
1158+ case 3 : snprintf (base, 256 , " kernel_norm_mul_add_f32%s" , suffix); break ;
1159+ default : GGML_ABORT (" fatal error" );
1160+ } break ;
1161+ case GGML_OP_RMS_NORM:
1162+ switch (n_fuse) {
1163+ case 1 : snprintf (base, 256 , " kernel_rms_norm_f32%s" , suffix); break ;
1164+ case 2 : snprintf (base, 256 , " kernel_rms_norm_mul_f32%s" , suffix); break ;
1165+ case 3 : snprintf (base, 256 , " kernel_rms_norm_mul_add_f32%s" , suffix); break ;
1166+ default : GGML_ABORT (" fatal error" );
1167+ } break ;
1168+ default : GGML_ABORT (" fatal error" );
1169+ }
1170+
11801171 snprintf (name, 256 , " %s" , base);
11811172
11821173 ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline (lib, name);
0 commit comments