Skip to content

Commit 5b180c3

Browse files
authored
metal : initial Metal4 tensor API support (ggml-org#16634)
* metal : rework mat-mat multiplication * metal : initial Metal4 support * cont * metal : detect tensor support * cont : better ifdefs * metal : support tensors in mul_mm_id * metal : add env for disabling tensor API * tests : restore * metal : remove unused constants * metal : fix check for bfloat tensor support * cont : handle API incompatibilities * cont : handle even more incompatibilities * metal : use tensor API only on M5 and later
1 parent b7f9010 commit 5b180c3

File tree

4 files changed

+606
-136
lines changed

4 files changed

+606
-136
lines changed

ggml/src/ggml-metal/ggml-metal-context.m

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
// additional, inference-time compiled pipelines
3636
ggml_metal_pipelines_t pipelines_ext;
3737

38-
bool use_bfloat;
3938
bool use_fusion;
4039
bool use_concurrency;
4140
bool use_graph_optimize;
@@ -121,11 +120,10 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
121120
}
122121
}
123122

124-
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
123+
//const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
125124

126125
res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
127126

128-
res->use_bfloat = props_dev->has_bfloat;
129127
res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
130128
res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
131129

@@ -147,7 +145,6 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
147145

148146
memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));
149147

150-
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false");
151148
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
152149
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
153150
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
9595

9696
typedef struct ggml_metal_library * ggml_metal_library_t;
9797

98-
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev);
98+
ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
99+
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
100+
99101
void ggml_metal_library_free(ggml_metal_library_t lib);
100102

101103
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
@@ -193,6 +195,7 @@ struct ggml_metal_device_props {
193195
bool has_simdgroup_mm;
194196
bool has_unified_memory;
195197
bool has_bfloat;
198+
bool has_tensor;
196199
bool use_residency_sets;
197200
bool use_shared_buffers;
198201

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 205 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,9 @@
2121
#define GGML_METAL_HAS_RESIDENCY_SETS 1
2222
#endif
2323

24-
// overload of MTLGPUFamilyMetal3 (not available in some environments)
24+
// overload of MTLGPUFamilyMetalX (not available in some environments)
2525
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
26+
static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
2627

2728
// virtual address for GPU memory allocations
2829
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
@@ -261,6 +262,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
261262
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
262263
}
263264

265+
if (ggml_metal_device_get_props(dev)->has_tensor) {
266+
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
267+
}
268+
264269
#if GGML_METAL_EMBED_LIBRARY
265270
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
266271
#endif
@@ -298,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
298303
return res;
299304
}
300305

306+
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
307+
if (source == NULL) {
308+
GGML_LOG_ERROR("%s: source is NULL\n", __func__);
309+
return NULL;
310+
}
311+
312+
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
313+
id<MTLLibrary> library = nil;
314+
NSError * error = nil;
315+
316+
const int64_t t_start = ggml_time_us();
317+
318+
NSString * src = [[NSString alloc] initWithBytes:source
319+
length:strlen(source)
320+
encoding:NSUTF8StringEncoding];
321+
if (!src) {
322+
GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
323+
return NULL;
324+
}
325+
326+
@autoreleasepool {
327+
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
328+
329+
MTLCompileOptions * options = [MTLCompileOptions new];
330+
options.preprocessorMacros = prep;
331+
332+
library = [device newLibraryWithSource:src options:options error:&error];
333+
if (error) {
334+
if (verbose) {
335+
GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
336+
} else {
337+
GGML_LOG_ERROR("%s: error compiling source\n", __func__);
338+
}
339+
library = nil;
340+
}
341+
342+
[options release];
343+
}
344+
345+
[src release];
346+
347+
if (!library) {
348+
if (verbose) {
349+
GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
350+
}
351+
352+
return NULL;
353+
}
354+
355+
if (verbose) {
356+
GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
357+
}
358+
359+
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
360+
if (!res) {
361+
GGML_LOG_ERROR("%s: calloc failed\n", __func__);
362+
return NULL;
363+
}
364+
365+
res->obj = library;
366+
res->device = device;
367+
res->pipelines = ggml_metal_pipelines_init();
368+
369+
return res;
370+
}
371+
301372
void ggml_metal_library_free(ggml_metal_library_t lib) {
302373
if (!lib) {
303374
return;
@@ -345,23 +416,31 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
345416
if (!mtl_function) {
346417
ggml_critical_section_end();
347418

348-
GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
419+
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
349420
if (error) {
350-
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
421+
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
351422
}
352423

353424
return nil;
354425
}
355426

356427
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
357428

358-
ggml_metal_pipelines_add(lib->pipelines, name, res);
359-
360429
[mtl_function release];
361430

362431
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
363432
(int) res->obj.maxTotalThreadsPerThreadgroup,
364433
(int) res->obj.threadExecutionWidth);
434+
435+
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
436+
ggml_critical_section_end();
437+
438+
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
439+
440+
return nil;
441+
}
442+
443+
ggml_metal_pipelines_add(lib->pipelines, name, res);
365444
}
366445

367446
ggml_critical_section_end();
@@ -469,14 +548,133 @@ ggml_metal_device_t ggml_metal_device_init(void) {
469548

470549
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
471550
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
551+
if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
552+
dev->props.has_bfloat = false;
553+
}
554+
555+
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
556+
if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
557+
dev->props.has_tensor = false;
558+
}
559+
560+
// note: disable the tensor API by default for old chips because with the current implementation it is not useful
561+
// - M2 Ultra: ~5% slower
562+
// - M4, M4 Max: no significant difference
563+
//
564+
// TODO: try to update the tensor API kernels to at least match the simdgroup performance
565+
if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL &&
566+
![[dev->mtl_device name] containsString:@"M5"] &&
567+
![[dev->mtl_device name] containsString:@"M6"]) {
568+
GGML_LOG_WARN("%s: tensor API disabled for pre-M5 device\n", __func__);
569+
dev->props.has_tensor = false;
570+
}
571+
572+
// double-check that the tensor API compiles
573+
if (dev->props.has_tensor) {
574+
const char * src_tensor_f16 = "\n"
575+
"#include <metal_stdlib> \n"
576+
"#include <metal_tensor> \n"
577+
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
578+
" \n"
579+
"using namespace metal; \n"
580+
"using namespace mpp::tensor_ops; \n"
581+
" \n"
582+
"kernel void dummy_kernel( \n"
583+
" tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
584+
" tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
585+
" device float * C [[buffer(2)]], \n"
586+
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
587+
"{ \n"
588+
" auto tA = A.slice(0, (int)tgid.y); \n"
589+
" auto tB = B.slice((int)tgid.x, 0); \n"
590+
" \n"
591+
" matmul2d< \n"
592+
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
593+
" execution_simdgroups<4>> mm; \n"
594+
" \n"
595+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
596+
" \n"
597+
" auto sA = tA.slice(0, 0); \n"
598+
" auto sB = tB.slice(0, 0); \n"
599+
" mm.run(sB, sA, cT); \n"
600+
" \n"
601+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
602+
" \n"
603+
" cT.store(tC); \n"
604+
"}";
605+
606+
GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
607+
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
608+
if (lib == NULL) {
609+
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
610+
dev->props.has_tensor = false;
611+
} else {
612+
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
613+
if (!ppl) {
614+
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
615+
dev->props.has_tensor = false;
616+
}
617+
618+
ggml_metal_library_free(lib);
619+
}
620+
}
621+
622+
// try to compile a dummy kernel to determine if the tensor API is supported for bfloat
623+
if (dev->props.has_tensor && dev->props.has_bfloat) {
624+
const char * src_tensor_bf16 = "\n"
625+
"#include <metal_stdlib> \n"
626+
"#include <metal_tensor> \n"
627+
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
628+
" \n"
629+
"using namespace metal; \n"
630+
"using namespace mpp::tensor_ops; \n"
631+
" \n"
632+
"kernel void dummy_kernel( \n"
633+
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
634+
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
635+
" device float * C [[buffer(2)]], \n"
636+
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
637+
"{ \n"
638+
" auto tA = A.slice(0, (int)tgid.y); \n"
639+
" auto tB = B.slice((int)tgid.x, 0); \n"
640+
" \n"
641+
" matmul2d< \n"
642+
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
643+
" execution_simdgroups<4>> mm; \n"
644+
" \n"
645+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
646+
" \n"
647+
" auto sA = tA.slice(0, 0); \n"
648+
" auto sB = tB.slice(0, 0); \n"
649+
" mm.run(sB, sA, cT); \n"
650+
" \n"
651+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
652+
" \n"
653+
" cT.store(tC); \n"
654+
"}";
655+
656+
GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
657+
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
658+
if (lib == NULL) {
659+
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
660+
dev->props.has_bfloat = false;
661+
} else {
662+
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
663+
if (!ppl) {
664+
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
665+
dev->props.has_bfloat = false;
666+
}
667+
668+
ggml_metal_library_free(lib);
669+
}
670+
}
472671

473672
dev->props.use_residency_sets = true;
474673
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
475674
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
476675
#endif
477676

478677
dev->props.use_shared_buffers = dev->props.has_unified_memory;
479-
480678
if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
481679
dev->props.use_shared_buffers = false;
482680
}
@@ -529,6 +727,7 @@ ggml_metal_device_t ggml_metal_device_init(void) {
529727
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
530728
GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
531729
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
730+
GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
532731
GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
533732
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
534733

0 commit comments

Comments (0)