@@ -481,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
481481 GGML_METAL_KERNEL_TYPE_SQRT,
482482 GGML_METAL_KERNEL_TYPE_SIN,
483483 GGML_METAL_KERNEL_TYPE_COS,
484+ GGML_METAL_KERNEL_TYPE_NEG,
484485 GGML_METAL_KERNEL_TYPE_SUM_ROWS,
485486 GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
486487 GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ @implementation GGMLMetalClass
11591160 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true );
11601161 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
11611162 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
1163+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
11621164 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
11631165 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
11641166 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13201322 case GGML_UNARY_OP_GELU_QUICK:
13211323 case GGML_UNARY_OP_SILU:
13221324 case GGML_UNARY_OP_ELU:
1325+ case GGML_UNARY_OP_NEG:
13231326 return ggml_is_contiguous (op->src [0 ]) && op->src [0 ]->type == GGML_TYPE_F32;
13241327 default :
13251328 return false ;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
20102013
20112014 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
20122015 } break ;
2016+ case GGML_UNARY_OP_NEG:
2017+ {
2018+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_NEG].pipeline ;
2019+
2020+ [encoder setComputePipelineState: pipeline];
2021+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2022+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2023+
2024+ const int64_t n = ggml_nelements (dst);
2025+
2026+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2027+ } break ;
20132028 default :
20142029 {
20152030 GGML_LOG_WARN (" %s : node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments