|
21 | 21 | #define GGML_METAL_HAS_RESIDENCY_SETS 1 |
22 | 22 | #endif |
23 | 23 |
|
24 | | -// overload of MTLGPUFamilyMetal3 (not available in some environments) |
| 24 | +// overload of MTLGPUFamilyMetalX (not available in some environments) |
25 | 25 | static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; |
| 26 | +static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; |
26 | 27 |
|
27 | 28 | // virtual address for GPU memory allocations |
28 | 29 | static atomic_uintptr_t g_addr_device = 0x000000400ULL; |
@@ -261,6 +262,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { |
261 | 262 | [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"]; |
262 | 263 | } |
263 | 264 |
|
| 265 | + if (ggml_metal_device_get_props(dev)->has_tensor) { |
| 266 | + [prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"]; |
| 267 | + } |
| 268 | + |
264 | 269 | #if GGML_METAL_EMBED_LIBRARY |
265 | 270 | [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; |
266 | 271 | #endif |
@@ -298,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { |
298 | 303 | return res; |
299 | 304 | } |
300 | 305 |
|
// compile a Metal library directly from shader source code at runtime
//
// dev     - the Metal device to compile for
// source  - NUL-terminated, UTF-8 encoded Metal source
// verbose - when true, log compiler diagnostics and compile timing
//
// returns a newly allocated library on success, or NULL on failure;
// the caller releases the result with ggml_metal_library_free
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
    if (source == NULL) {
        GGML_LOG_ERROR("%s: source is NULL\n", __func__);
        return NULL;
    }

    id<MTLDevice> device = ggml_metal_device_get_obj(dev);
    id<MTLLibrary> library = nil;
    NSError * error = nil;

    const int64_t t_start = ggml_time_us();

    NSString * src = [[NSString alloc] initWithBytes:source
                                              length:strlen(source)
                                            encoding:NSUTF8StringEncoding];
    if (!src) {
        GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
        return NULL;
    }

    @autoreleasepool {
        NSMutableDictionary * prep = [NSMutableDictionary dictionary];

        MTLCompileOptions * options = [MTLCompileOptions new];
        options.preprocessorMacros = prep;

        library = [device newLibraryWithSource:src options:options error:&error];

        // check the returned library rather than the error out-param:
        // newLibraryWithSource: may populate the error with warnings only,
        // while still returning a valid library
        if (!library) {
            if (verbose && error) {
                GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
            } else {
                GGML_LOG_ERROR("%s: error compiling source\n", __func__);
            }
        }

        [options release];
    }

    [src release];

    if (!library) {
        if (verbose) {
            GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
        }

        return NULL;
    }

    if (verbose) {
        GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
    }

    ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
    if (!res) {
        GGML_LOG_ERROR("%s: calloc failed\n", __func__);

        // release the compiled library (retained, from a new* method) so it
        // is not leaked on the allocation-failure path
        [library release];
        return NULL;
    }

    res->obj       = library;
    res->device    = device;
    res->pipelines = ggml_metal_pipelines_init();

    return res;
}
| 371 | + |
301 | 372 | void ggml_metal_library_free(ggml_metal_library_t lib) { |
302 | 373 | if (!lib) { |
303 | 374 | return; |
@@ -345,23 +416,31 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l |
345 | 416 | if (!mtl_function) { |
346 | 417 | ggml_critical_section_end(); |
347 | 418 |
|
348 | | - GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); |
| 419 | + GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); |
349 | 420 | if (error) { |
350 | | - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); |
| 421 | + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); |
351 | 422 | } |
352 | 423 |
|
353 | 424 | return nil; |
354 | 425 | } |
355 | 426 |
|
356 | 427 | res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; |
357 | 428 |
|
358 | | - ggml_metal_pipelines_add(lib->pipelines, name, res); |
359 | | - |
360 | 429 | [mtl_function release]; |
361 | 430 |
|
362 | 431 | GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, |
363 | 432 | (int) res->obj.maxTotalThreadsPerThreadgroup, |
364 | 433 | (int) res->obj.threadExecutionWidth); |
| 434 | + |
| 435 | + if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { |
| 436 | + ggml_critical_section_end(); |
| 437 | + |
| 438 | + GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); |
| 439 | + |
| 440 | + return nil; |
| 441 | + } |
| 442 | + |
| 443 | + ggml_metal_pipelines_add(lib->pipelines, name, res); |
365 | 444 | } |
366 | 445 |
|
367 | 446 | ggml_critical_section_end(); |
@@ -469,14 +548,133 @@ ggml_metal_device_t ggml_metal_device_init(void) { |
469 | 548 |
|
470 | 549 | dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; |
471 | 550 | dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6]; |
| 551 | + if (getenv("GGML_METAL_BF16_DISABLE") != NULL) { |
| 552 | + dev->props.has_bfloat = false; |
| 553 | + } |
| 554 | + |
| 555 | + dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML]; |
| 556 | + if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) { |
| 557 | + dev->props.has_tensor = false; |
| 558 | + } |
| 559 | + |
| 560 | + // note: disable the tensor API by default for old chips because with the current implementation it is not useful |
| 561 | + // - M2 Ultra: ~5% slower |
| 562 | + // - M4, M4 Max: no significant difference |
| 563 | + // |
| 564 | + // TODO: try to update the tensor API kernels to at least match the simdgroup performance |
| 565 | + if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL && |
| 566 | + ![[dev->mtl_device name] containsString:@"M5"] && |
| 567 | + ![[dev->mtl_device name] containsString:@"M6"]) { |
| 568 | + GGML_LOG_WARN("%s: tensor API disabled for pre-M5 device\n", __func__); |
| 569 | + dev->props.has_tensor = false; |
| 570 | + } |
| 571 | + |
| 572 | + // double-check that the tensor API compiles |
| 573 | + if (dev->props.has_tensor) { |
| 574 | + const char * src_tensor_f16 = "\n" |
| 575 | + "#include <metal_stdlib> \n" |
| 576 | + "#include <metal_tensor> \n" |
| 577 | + "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n" |
| 578 | + " \n" |
| 579 | + "using namespace metal; \n" |
| 580 | + "using namespace mpp::tensor_ops; \n" |
| 581 | + " \n" |
| 582 | + "kernel void dummy_kernel( \n" |
| 583 | + " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n" |
| 584 | + " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n" |
| 585 | + " device float * C [[buffer(2)]], \n" |
| 586 | + " uint2 tgid [[threadgroup_position_in_grid]]) \n" |
| 587 | + "{ \n" |
| 588 | + " auto tA = A.slice(0, (int)tgid.y); \n" |
| 589 | + " auto tB = B.slice((int)tgid.x, 0); \n" |
| 590 | + " \n" |
| 591 | + " matmul2d< \n" |
| 592 | + " matmul2d_descriptor(8, 8, dynamic_extent), \n" |
| 593 | + " execution_simdgroups<4>> mm; \n" |
| 594 | + " \n" |
| 595 | + " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" |
| 596 | + " \n" |
| 597 | + " auto sA = tA.slice(0, 0); \n" |
| 598 | + " auto sB = tB.slice(0, 0); \n" |
| 599 | + " mm.run(sB, sA, cT); \n" |
| 600 | + " \n" |
| 601 | + " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" |
| 602 | + " \n" |
| 603 | + " cT.store(tC); \n" |
| 604 | + "}"; |
| 605 | + |
| 606 | + GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__); |
| 607 | + ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false); |
| 608 | + if (lib == NULL) { |
| 609 | + GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); |
| 610 | + dev->props.has_tensor = false; |
| 611 | + } else { |
| 612 | + ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); |
| 613 | + if (!ppl) { |
| 614 | + GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); |
| 615 | + dev->props.has_tensor = false; |
| 616 | + } |
| 617 | + |
| 618 | + ggml_metal_library_free(lib); |
| 619 | + } |
| 620 | + } |
| 621 | + |
| 622 | + // try to compile a dummy kernel to determine if the tensor API is supported for bfloat |
| 623 | + if (dev->props.has_tensor && dev->props.has_bfloat) { |
| 624 | + const char * src_tensor_bf16 = "\n" |
| 625 | + "#include <metal_stdlib> \n" |
| 626 | + "#include <metal_tensor> \n" |
| 627 | + "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n" |
| 628 | + " \n" |
| 629 | + "using namespace metal; \n" |
| 630 | + "using namespace mpp::tensor_ops; \n" |
| 631 | + " \n" |
| 632 | + "kernel void dummy_kernel( \n" |
| 633 | + " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n" |
| 634 | + " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n" |
| 635 | + " device float * C [[buffer(2)]], \n" |
| 636 | + " uint2 tgid [[threadgroup_position_in_grid]]) \n" |
| 637 | + "{ \n" |
| 638 | + " auto tA = A.slice(0, (int)tgid.y); \n" |
| 639 | + " auto tB = B.slice((int)tgid.x, 0); \n" |
| 640 | + " \n" |
| 641 | + " matmul2d< \n" |
| 642 | + " matmul2d_descriptor(8, 8, dynamic_extent), \n" |
| 643 | + " execution_simdgroups<4>> mm; \n" |
| 644 | + " \n" |
| 645 | + " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" |
| 646 | + " \n" |
| 647 | + " auto sA = tA.slice(0, 0); \n" |
| 648 | + " auto sB = tB.slice(0, 0); \n" |
| 649 | + " mm.run(sB, sA, cT); \n" |
| 650 | + " \n" |
| 651 | + " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" |
| 652 | + " \n" |
| 653 | + " cT.store(tC); \n" |
| 654 | + "}"; |
| 655 | + |
| 656 | + GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__); |
| 657 | + ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false); |
| 658 | + if (lib == NULL) { |
| 659 | + GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); |
| 660 | + dev->props.has_bfloat = false; |
| 661 | + } else { |
| 662 | + ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); |
| 663 | + if (!ppl) { |
| 664 | + GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); |
| 665 | + dev->props.has_bfloat = false; |
| 666 | + } |
| 667 | + |
| 668 | + ggml_metal_library_free(lib); |
| 669 | + } |
| 670 | + } |
472 | 671 |
|
473 | 672 | dev->props.use_residency_sets = true; |
474 | 673 | #if defined(GGML_METAL_HAS_RESIDENCY_SETS) |
475 | 674 | dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; |
476 | 675 | #endif |
477 | 676 |
|
478 | 677 | dev->props.use_shared_buffers = dev->props.has_unified_memory; |
479 | | - |
480 | 678 | if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { |
481 | 679 | dev->props.use_shared_buffers = false; |
482 | 680 | } |
@@ -529,6 +727,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { |
529 | 727 | GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false"); |
530 | 728 | GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false"); |
531 | 729 | GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false"); |
| 730 | + GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false"); |
532 | 731 | GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false"); |
533 | 732 | GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false"); |
534 | 733 |
|
|
0 commit comments