diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp index c12e119613b..296c236f276 100644 --- a/examples/training/finetune-lora.cpp +++ b/examples/training/finetune-lora.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "ggml-backend.h" #include #include @@ -54,6 +55,72 @@ static uint32_t parse_lora_modules(const std::string& modules_str) { return target_modules; } +static bool training_supports_out_prod_f16(const common_params & params) { + std::vector devices; + + if (!params.devices.empty()) { + devices.assign(params.devices.begin(), params.devices.end()); + } else { + ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + if (gpu) { + devices.push_back(gpu); + } + } + + if (devices.empty()) { + return true; + } + + constexpr int64_t ne0 = 4; + constexpr int64_t ne1 = 3; + constexpr int64_t k = 2; + + struct ggml_tensor src0 = {}; + struct ggml_tensor src1 = {}; + struct ggml_tensor dst = {}; + + src0.type = GGML_TYPE_F16; + src1.type = GGML_TYPE_F32; + dst.type = GGML_TYPE_F32; + + src0.ne[0] = ne0; src0.ne[1] = k; src0.ne[2] = 1; src0.ne[3] = 1; + src1.ne[0] = ne1; src1.ne[1] = k; src1.ne[2] = 1; src1.ne[3] = 1; + dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1; + + src0.nb[0] = sizeof(ggml_fp16_t); + src0.nb[1] = src0.nb[0] * ne0; + src0.nb[2] = src0.nb[1] * k; + src0.nb[3] = src0.nb[2] * 1; + + src1.nb[0] = sizeof(float); + src1.nb[1] = src1.nb[0] * ne1; + src1.nb[2] = src1.nb[1] * k; + src1.nb[3] = src1.nb[2] * 1; + + dst.nb[0] = sizeof(float); + dst.nb[1] = dst.nb[0] * ne0; + dst.nb[2] = dst.nb[1] * ne1; + dst.nb[3] = dst.nb[2] * 1; + + dst.op = GGML_OP_OUT_PROD; + dst.src[0] = &src0; + dst.src[1] = &src1; + + for (ggml_backend_dev_t dev : devices) { + if (dev == nullptr) { + continue; + } + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + continue; + } + if (!ggml_backend_dev_supports_op(dev, &dst)) { + return false; + } + } + + return true; +} + static void print_lora_usage() { printf("\nLoRA Fine-tuning Parameters:\n"); printf(" --lora-rank N LoRA rank (default: 8, range: 1-512)\n"); @@ -124,13 +191,16 @@ int main(int argc, char ** argv) { LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); params.use_mmap = false; } - if (params.cache_type_k != GGML_TYPE_F32) { - LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); - params.cache_type_k = GGML_TYPE_F32; - } - if (params.cache_type_v != GGML_TYPE_F32) { - LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); - params.cache_type_v = GGML_TYPE_F32; + const bool supports_out_prod_f16 = training_supports_out_prod_f16(params); + if (!supports_out_prod_f16) { + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } } common_init(); diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index 561e61f8a21..c826f154f49 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include 
"ggml-backend.h" #include #include @@ -13,6 +14,72 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static bool training_supports_out_prod_f16(const common_params & params) { + std::vector devices; + + if (!params.devices.empty()) { + devices.assign(params.devices.begin(), params.devices.end()); + } else { + ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + if (gpu) { + devices.push_back(gpu); + } + } + + if (devices.empty()) { + return true; + } + + constexpr int64_t ne0 = 4; + constexpr int64_t ne1 = 3; + constexpr int64_t k = 2; + + struct ggml_tensor src0 = {}; + struct ggml_tensor src1 = {}; + struct ggml_tensor dst = {}; + + src0.type = GGML_TYPE_F16; + src1.type = GGML_TYPE_F32; + dst.type = GGML_TYPE_F32; + + src0.ne[0] = ne0; src0.ne[1] = k; src0.ne[2] = 1; src0.ne[3] = 1; + src1.ne[0] = ne1; src1.ne[1] = k; src1.ne[2] = 1; src1.ne[3] = 1; + dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1; + + src0.nb[0] = sizeof(ggml_fp16_t); + src0.nb[1] = src0.nb[0] * ne0; + src0.nb[2] = src0.nb[1] * k; + src0.nb[3] = src0.nb[2] * 1; + + src1.nb[0] = sizeof(float); + src1.nb[1] = src1.nb[0] * ne1; + src1.nb[2] = src1.nb[1] * k; + src1.nb[3] = src1.nb[2] * 1; + + dst.nb[0] = sizeof(float); + dst.nb[1] = dst.nb[0] * ne0; + dst.nb[2] = dst.nb[1] * ne1; + dst.nb[3] = dst.nb[2] * 1; + + dst.op = GGML_OP_OUT_PROD; + dst.src[0] = &src0; + dst.src[1] = &src1; + + for (ggml_backend_dev_t dev : devices) { + if (dev == nullptr) { + continue; + } + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + continue; + } + if (!ggml_backend_dev_supports_op(dev, &dst)) { + return false; + } + } + + return true; +} + int main(int argc, char ** argv) { common_params params; params.escape = false; @@ -26,13 +93,16 @@ int main(int argc, char ** argv) { __func__); params.use_mmap = false; } - if (params.cache_type_k != GGML_TYPE_F32) { - LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); - params.cache_type_k = GGML_TYPE_F32; - } - if (params.cache_type_v != GGML_TYPE_F32) { - LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); - params.cache_type_v = GGML_TYPE_F32; + const bool supports_out_prod_f16 = training_supports_out_prod_f16(params); + if (!supports_out_prod_f16) { + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } } common_init(); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 651943fa923..310b28954de 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -184,6 +184,33 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_cpy; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_out_prod; + typedef struct { int64_t ne10; 
int64_t ne11; @@ -439,6 +466,21 @@ typedef struct { uint64_t nbf3[3]; } ggml_metal_kargs_rms_norm; +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float eps; +} ggml_metal_kargs_rms_norm_back; + typedef struct { int32_t ne00; int32_t ne00_4; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index e76fb712631..2f7039f1fb1 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -215,6 +215,14 @@ - (void) dealloc { GGML_METAL_KERNEL_TYPE_REPEAT_F16, GGML_METAL_KERNEL_TYPE_REPEAT_I32, GGML_METAL_KERNEL_TYPE_REPEAT_I16, + GGML_METAL_KERNEL_TYPE_OUT_PROD_F32, + GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32, + GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16, + GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16, + GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_CLAMP, @@ -229,6 +237,8 @@ - (void) dealloc { GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_SILU_BACK, + GGML_METAL_KERNEL_TYPE_SILU_BACK_4, GGML_METAL_KERNEL_TYPE_ELU, GGML_METAL_KERNEL_TYPE_ABS, GGML_METAL_KERNEL_TYPE_SGN, @@ -278,6 +288,7 @@ - (void) dealloc { GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, + GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, @@ -1137,6 +1148,14 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F32, out_prod_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32, out_prod_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16, out_prod_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16, out_prod_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32, out_prod_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16, out_prod_q8_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32, out_prod_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16, out_prod_q4_0_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); @@ -1151,6 +1170,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_BACK, silu_back, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_BACK_4, silu_back_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, 
sgn, true); @@ -1200,6 +1221,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK, rms_norm_back, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); @@ -1853,6 +1875,37 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_DIV: case GGML_OP_ADD_ID: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_OUT_PROD: + if (op->type != GGML_TYPE_F32) { + return false; + } + + { + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return true; + } + + if (src0_type == GGML_TYPE_F32 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) { + return true; + } + + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + + if (src0_type == GGML_TYPE_Q8_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) { + return true; + } + + if (src0_type == GGML_TYPE_Q4_0 && (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_F16)) { + return true; + } + } + + return false; case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: @@ -1860,6 +1913,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return true; case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SILU_BACK: + return op->type == GGML_TYPE_F32 && + op->src[0] != NULL && op->src[1] != NULL && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + ggml_is_contiguous_1(op->src[0]) && + ggml_is_contiguous_1(op->src[1]) && + ggml_is_contiguous_1(op) && + ggml_are_same_shape(op, op->src[0]) && + ggml_are_same_shape(op, op->src[1]); case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -1875,6 +1938,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_RMS_NORM_BACK: + return has_simdgroup_reduction && + op->type == GGML_TYPE_F32 && + op->src[0] != NULL && op->src[1] != NULL && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->ne[0] % 4 == 0 && + ggml_is_contiguous_1(op->src[0]) && + ggml_is_contiguous_1(op->src[1]) && + ggml_is_contiguous_1(op) && + ggml_are_same_shape(op, op->src[0]) && + ggml_are_same_shape(op, op->src[1]); case GGML_OP_ARGMAX: return true; case GGML_OP_NORM: @@ -2365,6 +2440,80 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_OUT_PROD: + { + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q4_0); + GGML_ASSERT(src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16); + + id pipeline = nil; + + if (src0t == GGML_TYPE_Q8_0) { + if (src1t == GGML_TYPE_F32) { 
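// Reference semantics for GGML_OP_OUT_PROD (rough sketch, assuming contiguous f32 data): for each
// (i2, i3) slice, with src0 broadcast along dims 2/3 where needed,
//
//     dst[i0, i1] = sum over k < ne01 of src0[i0, k] * src1[i1, k]
//
// or, as a plain C loop using the local ne* names available in this function:
//
//     for (int64_t i1 = 0; i1 < ne1; ++i1)
//         for (int64_t i0 = 0; i0 < ne0; ++i0) {
//             float acc = 0.0f;
//             for (int64_t k = 0; k < ne01; ++k)
//                 acc += src0[k*ne00 + i0] * src1[k*ne10 + i1];
//             dst[i1*ne0 + i0] = acc;
//         }
//
// The kernels selected here implement this directly, with the quantized variants dequantizing src0
// on the fly; OUT_PROD is what the MUL_MAT backward pass is built on, hence its role in training.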
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F32].pipeline; + } else if (src1t == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q8_0_F16].pipeline; + } else { + GGML_ABORT("fatal error"); + } + } else if (src0t == GGML_TYPE_Q4_0) { + if (src1t == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F32].pipeline; + } else if (src1t == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_Q4_0_F16].pipeline; + } else { + GGML_ABORT("fatal error"); + } + } else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F32].pipeline; + } else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F32].pipeline; + } else if (src0t == GGML_TYPE_F32 && src1t == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F32_F16].pipeline; + } else if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_OUT_PROD_F16_F16].pipeline; + } else { + GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_out_prod args = { + (int32_t) ne00, + (int32_t) ne01, + (int32_t) ne02, + (int32_t) ne03, + nb00, + nb01, + nb02, + nb03, + (int32_t) ne10, + (int32_t) ne11, + (int32_t) ne12, + (int32_t) ne13, + nb10, + nb11, + nb12, + nb13, + (int32_t) ne0, + (int32_t) ne1, + (int32_t) ne2, + (int32_t) ne3, + nb0, + nb1, + nb2, + nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + const int threads = ne0 < 1 ? 
1 : (int) ne0; + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, threads); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ADD_ID: { GGML_ASSERT(src0t == GGML_TYPE_F32); @@ -2575,6 +2724,37 @@ static int ggml_metal_encode_node( const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SILU_BACK: + { + GGML_ASSERT(src0 != NULL); + GGML_ASSERT(src1 != NULL); + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(ggml_are_same_shape(dst, src0)); + GGML_ASSERT(ggml_are_same_shape(dst, src1)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_BACK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_BACK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_UNARY: @@ -4508,6 +4688,59 @@ static int ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_RMS_NORM_BACK: + { + GGML_ASSERT(src0 != NULL); + GGML_ASSERT(src1 != NULL); + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ggml_are_same_shape(dst, src0)); + GGML_ASSERT(ggml_are_same_shape(dst, src1)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + ggml_metal_kargs_rms_norm_back args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + }; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_BACK].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); + nth = MIN(nth, ne00/4); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder setThreadgroupMemoryLength:2*32*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_L2_NORM: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 77be3c5c9d8..3f37f59f8bb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ 
-1087,6 +1087,160 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
 template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
 template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+template <typename src0_t, typename src1_t>
+kernel void kernel_out_prod_impl(
+        constant ggml_metal_kargs_out_prod & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    const int dps2 = args.ne02 > 0 ? args.ne2 / args.ne02 : 1;
+    const int dps3 = args.ne03 > 0 ? args.ne3 / args.ne03 : 1;
+
+    const int i02 = args.ne02 > 0 ? i2 / dps2 : 0;
+    const int i03 = args.ne03 > 0 ? i3 / dps3 : 0;
+
+    device const char * src0_base = src0 + i02*args.nb02 + i03*args.nb03;
+    device const char * src1_base = src1 + i1*args.nb10 + i2*args.nb12 + i3*args.nb13;
+    device char * dst_base = dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        float acc = 0.0f;
+
+        for (int i01 = 0; i01 < args.ne01; ++i01) {
+            device const char * src0_row = src0_base + i01*args.nb01;
+            const float v0 = (float) *((device const src0_t *)(src0_row + i0*args.nb00));
+            const float v1 = (float) *((device const src1_t *)(src1_base + i01*args.nb11));
+
+            acc += v0 * v1;
+        }
+
+        *((device float *)(dst_base + i0*args.nb0)) = acc;
+    }
+}
+
+typedef decltype(kernel_out_prod_impl<float, float>) kernel_out_prod_f32_t;
+typedef decltype(kernel_out_prod_impl<half,  float>) kernel_out_prod_f16_f32_t;
+typedef decltype(kernel_out_prod_impl<float, half>)  kernel_out_prod_f32_f16_t;
+typedef decltype(kernel_out_prod_impl<half,  half>)  kernel_out_prod_f16_t;
+
+template [[host_name("kernel_out_prod_f32")]]     kernel kernel_out_prod_f32_t     kernel_out_prod_impl<float, float>;
+template [[host_name("kernel_out_prod_f16_f32")]] kernel kernel_out_prod_f16_f32_t kernel_out_prod_impl<half,  float>;
+template [[host_name("kernel_out_prod_f32_f16")]] kernel kernel_out_prod_f32_f16_t kernel_out_prod_impl<float, half>;
+template [[host_name("kernel_out_prod_f16")]]     kernel kernel_out_prod_f16_t     kernel_out_prod_impl<half,  half>;
+
+template <typename src1_t>
+kernel void kernel_out_prod_q8_0_impl(
+        constant ggml_metal_kargs_out_prod & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    const int dps2 = args.ne02 > 0 ? args.ne2 / args.ne02 : 1;
+    const int dps3 = args.ne03 > 0 ? args.ne3 / args.ne03 : 1;
+
+    const int i02 = args.ne02 > 0 ? i2 / dps2 : 0;
+    const int i03 = args.ne03 > 0 ? i3 / dps3 : 0;
+
+    device const char * src0_base = src0 + i02*args.nb02 + i03*args.nb03;
+    device const char * src1_base = src1 + i1*args.nb10 + i2*args.nb12 + i3*args.nb13;
+    device char * dst_base = dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int ib = i0 / QK8_0;
+        const int ix = i0 % QK8_0;
+
+        float acc = 0.0f;
+
+        for (int i01 = 0; i01 < args.ne01; ++i01) {
+            device const char * src0_row_char = src0_base + i01*args.nb01;
+            device const block_q8_0 * src0_row = (device const block_q8_0 *) src0_row_char;
+            const block_q8_0 blk = src0_row[ib];
+
+            const float v0 = (float) blk.d * (float) blk.qs[ix];
+
+            device const src1_t * src1_row = (device const src1_t *)(src1_base + i01*args.nb11);
+            const float v1 = (float) src1_row[0];
+
+            acc += v0 * v1;
+        }
+
+        *((device float *)(dst_base + i0*args.nb0)) = acc;
+    }
+}
+
+typedef decltype(kernel_out_prod_q8_0_impl<float>) kernel_out_prod_q8_0_f32_t;
+typedef decltype(kernel_out_prod_q8_0_impl<half>)  kernel_out_prod_q8_0_f16_t;
+
+template [[host_name("kernel_out_prod_q8_0_f32")]] kernel kernel_out_prod_q8_0_f32_t kernel_out_prod_q8_0_impl<float>;
+template [[host_name("kernel_out_prod_q8_0_f16")]] kernel kernel_out_prod_q8_0_f16_t kernel_out_prod_q8_0_impl<half>;
+
+template <typename src1_t>
+kernel void kernel_out_prod_q4_0_impl(
+        constant ggml_metal_kargs_out_prod & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    const int dps2 = args.ne02 > 0 ? args.ne2 / args.ne02 : 1;
+    const int dps3 = args.ne03 > 0 ? args.ne3 / args.ne03 : 1;
+
+    const int i02 = args.ne02 > 0 ? i2 / dps2 : 0;
+    const int i03 = args.ne03 > 0 ? i3 / dps3 : 0;
+
+    device const char * src0_base = src0 + i02*args.nb02 + i03*args.nb03;
+    device const char * src1_base = src1 + i1*args.nb10 + i2*args.nb12 + i3*args.nb13;
+    device char * dst_base = dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int ib = i0 / QK4_0;
+        const int ix = i0 % QK4_0;
+
+        float acc = 0.0f;
+
+        for (int i01 = 0; i01 < args.ne01; ++i01) {
+            device const char * src0_row_char = src0_base + i01*args.nb01;
+            device const block_q4_0 * src0_row = (device const block_q4_0 *) src0_row_char;
+            const block_q4_0 blk = src0_row[ib];
+
+            const uint8_t q = blk.qs[ix / 2];
+            const int nibble = (ix & 1) ?
(q >> 4) : (q & 0x0F); + const float v0 = ((float) blk.d) * ((float) nibble - 8.0f); + + device const src1_t * src1_row = (device const src1_t *)(src1_base + i01*args.nb11); + const float v1 = (float) src1_row[0]; + + acc += v0 * v1; + } + + *((device float *)(dst_base + i0*args.nb0)) = acc; + } +} + +typedef decltype(kernel_out_prod_q4_0_impl) kernel_out_prod_q4_0_f32_t; +typedef decltype(kernel_out_prod_q4_0_impl) kernel_out_prod_q4_0_f16_t; + +template [[host_name("kernel_out_prod_q4_0_f32")]] kernel kernel_out_prod_q4_0_f32_t kernel_out_prod_q4_0_impl; +template [[host_name("kernel_out_prod_q4_0_f16")]] kernel kernel_out_prod_q4_0_f16_t kernel_out_prod_q4_0_impl; + // assumption: src1 is a row // broadcast src1 into src0 template @@ -1374,6 +1528,28 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } +kernel void kernel_silu_back( + device const float * grad, + device const float * src1, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + const float dy = grad[tpig]; + const float x = src1[tpig]; + const float s = 1.0f/(1.0f + exp(-x)); + dst[tpig] = dy*s*(1.0f + x*(1.0f - s)); +} + +kernel void kernel_silu_back_4( + device const float4 * grad, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 dy = grad[tpig]; + const float4 x = src1[tpig]; + const float4 s = 1.0f/(1.0f + exp(-x)); + dst[tpig] = dy*s*(1.0f + x*(1.0f - s)); +} + kernel void kernel_elu( device const float * src0, device float * dst, @@ -2547,6 +2723,73 @@ template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; +kernel void kernel_rms_norm_back( + constant ggml_metal_kargs_rms_norm_back & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + threadgroup float * shmem_xx = shmem_f32; + threadgroup float * shmem_xdz = shmem_f32 + 32; + + if (sgitg == 0) { + shmem_xx[tiisg] = 0.0f; + shmem_xdz[tiisg] = 0.0f; + } + + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; + + device const float4 * dz = (device const float4 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device const float4 * x = (device const float4 *) (src1 + i03*args.nb13 + i02*args.nb12 + i01*args.nb11); + + float sum_xx = 0.0f; + float sum_xdz = 0.0f; + + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + const float4 x4 = x[i00]; + const float4 dz4 = dz[i00]; + sum_xx += dot(x4, x4); + sum_xdz += dot(x4, dz4); + } + + sum_xx = simd_sum(sum_xx); + sum_xdz = simd_sum(sum_xdz); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_xx[sgitg] = sum_xx; + shmem_xdz[sgitg] = sum_xdz; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum_xx = shmem_xx[tiisg]; + sum_xdz = shmem_xdz[tiisg]; + + sum_xx = simd_sum(sum_xx); + sum_xdz = simd_sum(sum_xdz); + + const float mean_eps = sum_xx/args.ne00 + args.eps; + const float rrms = rsqrt(mean_eps); + const float sum_eps = sum_xx + args.eps*args.ne00; + const float scale = -sum_xdz/sum_eps; + + device float4 
* y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + float4 dx = dz[i00] + x[i00]*scale; + y[i00] = dx*rrms; + } +} + kernel void kernel_l2_norm( constant ggml_metal_kargs_l2_norm & args, device const char * src0, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 540254f31da..a4eb0142919 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -216,7 +216,7 @@ class vk_memory_logger; class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); -static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t mul_mat_vec_max_cols = 16; static constexpr uint32_t p021_max_gqa_ratio = 8; enum vk_device_architecture { @@ -557,6 +557,7 @@ struct vk_device_struct { vk_pipeline pipeline_out_prod_f16_f32; vk_pipeline pipeline_out_prod_q4_0; vk_pipeline pipeline_out_prod_q8_0; + vk_pipeline pipeline_out_prod_tq2_0; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -2798,6 +2799,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3063,6 +3065,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ2_0][i], ("mul_mat_vec_tq2_0_f32_f32_"+std::to_string(i+1)).c_str(), mul_mat_vec_tq2_0_f32_f32_len, mul_mat_vec_tq2_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, 
{device->subgroup_size, 1*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); @@ -3149,6 +3152,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); @@ -3429,6 +3433,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_out_prod_tq2_0, "out_prod_tq2_0", out_prod_tq2_0_len, out_prod_tq2_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); // TODO: should we have device->subgroup_size here or 0? 
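// TQ2_0 recap for the pipelines above (sketch, based on block_tq2_0 in types.comp): a block packs
// QUANT_K = 256 ternary values as 2-bit fields, four per byte in qs[64], plus one fp16 scale d.
// Element r of a block is decoded roughly as
//
//     const uint8_t b = qs[(r % 32) + 32*(r / 128)];          // byte within the block
//     const int     q = (b >> (2*((r % 128) / 32))) & 3;      // 2-bit field
//     const float   x = float(d) * (float(q) - 1.0f);         // {0,1,2} -> {-1,0,+1}, then scaled
//
// which is the element ordering used by mul_mat_vec_tq2_0.comp and out_prod_tq2_0.comp added later
// in this patch.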
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); @@ -4668,6 +4673,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4739,6 +4745,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4796,6 +4803,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4875,6 +4883,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4921,6 +4930,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -7797,6 +7807,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32; if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0; if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0; + if (src0->type == GGML_TYPE_TQ2_0) return ctx->device->pipeline_out_prod_tq2_0; } if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_out_prod_f16_f32; @@ -12308,6 +12319,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -12502,8 +12514,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } if ( - src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 || - src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32 + (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || + (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) ) { return true; } @@ -12547,6 +12559,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ2_0: return true; default: return false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index d3127fbd986..16f56e25135 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -434,6 +434,30 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_TQ2_0) +// TQ2_0 ternary dequantization: {0,1,2} -> {-1,0,+1} via (q-1) mapping +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + const uint c0 = (vui >> 0) & 3; + const uint c1 = (vui >> 2) & 3; + const float q0 = 
float(c0) - 1.0f; + const float q1 = float(c1) - 1.0f; + return vec2(q0, q1); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + const uint c0 = (vui >> 0) & 3; + const uint c1 = (vui >> 2) & 3; + const uint c2 = (vui >> 4) & 3; + const uint c3 = (vui >> 6) & 3; + const float q0 = float(c0) - 1.0f; + const float q1 = float(c1) - 1.0f; + const float q2 = float(c2) - 1.0f; + const float q3 = float(c3) - 1.0f; + return vec4(q0, q1, q2, q3); +} +#endif + #if defined(DATA_A_MXFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -461,7 +485,7 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 706540fd851..e0c36d30f07 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor } #endif +#if defined(DATA_A_TQ2_0) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 { + block_tq2_0 block; +}; + +float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx / 4; + const uint iqs_offset = idx % 4; + const uint vui = uint(bl.block.qs[iqs]); + const uint c = (vui >> (2 * iqs_offset)) & 3; + const float q = float(c) - 1.0f; + float16_t ret = d * float16_t(q); + return ret; +} +#endif + #if defined(DATA_A_MXFP4) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { block_mxfp4 block; @@ -715,6 +734,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_TQ2_0) +#define dequantFuncA dequantFuncTQ2_0 #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp new file mode 100644 index 00000000000..f2fafcb3d49 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp @@ -0,0 +1,36 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +layout (push_constant) uniform parameter { + uint ne; +} p; + +layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i = gl_GlobalInvocationID.x * 4; + + if 
(i >= p.ne) { + return; + } + + const uint ib = i / QUANT_K; // block index + const uint iqs = (i % QUANT_K) / 4; // quant index within block (byte index) + const uint bit_pos_base = (i % 4) * 2; // bit position within byte + + const float d = float(data_a[ib].d); + + for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) { + const uint local_iqs = ((i + j) % QUANT_K) / 4; // byte index for this element + const uint bit_pos = ((i + j) % 4) * 2; // bit position for this element + const uint vui = uint(data_a[ib].qs[local_iqs]); + const uint q = (vui >> bit_pos) & 3; + data_b[i + j] = D_TYPE(d * (float(q) - 1.0f)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp new file mode 100644 index 00000000000..e49f8f3139b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp @@ -0,0 +1,66 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + const uint tid = gl_LocalInvocationID.x; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = tid; i < num_blocks_per_row; i += gl_WorkGroupSize.x) { + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row; + const float d = float(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < 64; j += 32) { + [[unroll]] for (uint l = 0; l < 4; ++l) { + [[unroll]] for (uint k = 0; k < 32; ++k) { + // Extract quantized value: ((x[i].qs[j + k] >> (l*2)) & 3) - 1 + const uint q_byte = uint(data_a[ib0 + i].qs[j + k]); + const uint shift = l * 2; + const uint q = (q_byte >> shift) & 3; + const FLOAT_TYPE dequant_val = FLOAT_TYPE(d * (float(q) - 1.0f)); // CPU kernel: (q-1)*d + + // y-data access pattern: y[i].qs[j*4 + l*32 + k] + const uint b_idx = i * QUANT_K + j * 4 + l * 32 + k; + if (b_idx < p.ncols) { + [[unroll]] for (uint jcol = 0; jcol < NUM_COLS; ++jcol) { + temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[jcol * p.batch_stride_b + b_offset + b_idx]); + } + } + } + } + } + } + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f6a7761ffa0..4b0ab7b96d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -450,6 +450,22 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); +#elif defined(DATA_A_TQ2_0) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx (like 
Q2_K) + const uint iqs = idx % 128; // 0..127 + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // Q2_K indexing pattern + const uint qsshift = ((iqs % 64) / 16) * 2; // Q2_K shift: 0,2,4,6 + + const float d = float(data_a[ib].d); + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const vec2 v = d * (vec2((qs >> qsshift) & 3) - 1.0f); // (q-1)*d + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_tq2_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_tq2_0.comp new file mode 100644 index 00000000000..e877a6cd36c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_tq2_0.comp @@ -0,0 +1,58 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +const uint num_threads = 256; +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint it = 0; it < num_iter; ++it) { + if (idx < p.ne) { + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + float acc = 0.0f; + + for (uint k = 0; k < p.ne01; k += 1) { + const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01; + const uint ib = a_block_base + (i0 / QUANT_K); + const uint r = (i0 % QUANT_K); + const uint iqs = (r % 32u) + 32u * (r / 128u); + const uint sub = (r % 128u) / 32u; + + const vec4 v = dequantize4(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + + float qv = (sub == 0u) ? v.x : (sub == 1u) ? v.y : (sub == 2u) ? 
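// worked example of the iqs/sub decomposition above (TQ2_0 layout sketch): for r = i0 % QUANT_K = 100,
// iqs = 100 % 32 + 32*(100/128) = 4 and sub = (100 % 128)/32 = 3, so the value sits in byte qs[4],
// bits 6..7, i.e. the .w component returned by dequantize4(ib, 4, 0).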
v.z : v.w; + const float a_val = qv * dm.x + dm.y; + + const uint b_idx = src1_idx(i1, k, i2, i3); + const float b = data_b[get_boffset() + b_idx]; + acc += a_val * b; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = acc; + } + idx += num_threads; + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index c2acc803f68..ab96f0485b4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1355,6 +1355,22 @@ struct block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif +// TQ2_0 +#define QUANT_K_TQ2_0 256 +#define QUANT_R_TQ2_0 4 + +struct block_tq2_0 +{ + uint8_t qs[QUANT_K_TQ2_0/QUANT_R_TQ2_0]; // 256/4 = 64 bytes + float16_t d; +}; + +#if defined(DATA_A_TQ2_0) +#define QUANT_K QUANT_K_TQ2_0 +#define QUANT_R QUANT_R_TQ2_0 +#define A_TYPE block_tq2_0 +#endif + #define QUANT_K_MXFP4 32 #define QUANT_R_MXFP4 2 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index eb9a4476dcf..61ebc6c61c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -50,6 +50,7 @@ const std::vector type_names = { "q5_0", "q5_1", "q8_0", + "tq2_0", "q2_k", "q3_k", "q4_k", @@ -504,6 +505,9 @@ void process_shaders() { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + if (tname == "tq2_0") { + shader = "mul_mat_vec_tq2_0.comp"; + } string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -714,6 +718,7 @@ void process_shaders() { string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("out_prod_q4_0", "out_prod_q4_0.comp", merge_maps(base_dict, {{"DATA_A_Q4_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("out_prod_q8_0", "out_prod_q8_0.comp", merge_maps(base_dict, {{"DATA_A_Q8_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_tq2_0", "out_prod_tq2_0.comp", merge_maps(base_dict, {{"DATA_A_TQ2_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index adf91ab6f9e..4852c4d148d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5555,7 +5555,8 @@ static const ggml_type all_types[] = { GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + // GGML_TYPE_TQ1_0, + GGML_TYPE_TQ2_0, // TODO: implement for all backends GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
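// With GGML_TYPE_TQ2_0 enabled in all_types above, the type-sweep cases built from this table are
// now generated for TQ2_0 as well, so the new Vulkan kernels get exercised by the normal test run;
// for example (backend and op filters are illustrative):
//
//     ./bin/test-backend-ops test -b Vulkan0 -o MUL_MAT
//     ./bin/test-backend-ops test -b Vulkan0 -o OUT_PROD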