Merged
84 changes: 77 additions & 7 deletions examples/training/finetune-lora.cpp
@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "ggml-backend.h"

#include <cmath>
#include <cstdio>
@@ -54,6 +55,72 @@ static uint32_t parse_lora_modules(const std::string& modules_str) {
return target_modules;
}

static bool training_supports_out_prod_f16(const common_params & params) {
std::vector<ggml_backend_dev_t> devices;

if (!params.devices.empty()) {
devices.assign(params.devices.begin(), params.devices.end());
} else {
ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
if (gpu) {
devices.push_back(gpu);
}
}

if (devices.empty()) {
return true;
}

constexpr int64_t ne0 = 4;
constexpr int64_t ne1 = 3;
constexpr int64_t k = 2;

struct ggml_tensor src0 = {};
struct ggml_tensor src1 = {};
struct ggml_tensor dst = {};

src0.type = GGML_TYPE_F16;
src1.type = GGML_TYPE_F32;
dst.type = GGML_TYPE_F32;

src0.ne[0] = ne0; src0.ne[1] = k; src0.ne[2] = 1; src0.ne[3] = 1;
src1.ne[0] = ne1; src1.ne[1] = k; src1.ne[2] = 1; src1.ne[3] = 1;
dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1;

src0.nb[0] = sizeof(ggml_fp16_t);
src0.nb[1] = src0.nb[0] * ne0;
src0.nb[2] = src0.nb[1] * k;
src0.nb[3] = src0.nb[2] * 1;

src1.nb[0] = sizeof(float);
src1.nb[1] = src1.nb[0] * ne1;
src1.nb[2] = src1.nb[1] * k;
src1.nb[3] = src1.nb[2] * 1;

dst.nb[0] = sizeof(float);
dst.nb[1] = dst.nb[0] * ne0;
dst.nb[2] = dst.nb[1] * ne1;
dst.nb[3] = dst.nb[2] * 1;

dst.op = GGML_OP_OUT_PROD;
dst.src[0] = &src0;
dst.src[1] = &src1;

for (ggml_backend_dev_t dev : devices) {
if (dev == nullptr) {
continue;
}
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
continue;
}
if (!ggml_backend_dev_supports_op(dev, &dst)) {
return false;
}
}

return true;
}

static void print_lora_usage() {
printf("\nLoRA Fine-tuning Parameters:\n");
printf(" --lora-rank N LoRA rank (default: 8, range: 1-512)\n");
@@ -124,13 +191,16 @@ int main(int argc, char ** argv) {
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
params.use_mmap = false;
}
if (params.cache_type_k != GGML_TYPE_F32) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v != GGML_TYPE_F32) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
const bool supports_out_prod_f16 = training_supports_out_prod_f16(params);
if (!supports_out_prod_f16) {
if (params.cache_type_k != GGML_TYPE_F32) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v != GGML_TYPE_F32) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
}
}

common_init();
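Note: the helper above builds a dummy GGML_OP_OUT_PROD node (f16 src0 of shape 4x2, f32 src1 of shape 3x2, f32 dst of shape 4x3) and asks each configured GPU device whether it could execute it; if any device cannot, main() forces the K/V cache types to f32. For orientation, here is a minimal scalar sketch of what that op computes, assuming ggml's usual out_prod shape rule (dst->ne = { src0->ne[0], src1->ne[0] }, with src0->ne[1] == src1->ne[1] as the shared dimension); the function and layout below are illustrative, not part of this patch:

#include <cstdint>

// Reference outer-product accumulation matching the probe's shapes:
// dst[ne0 x ne1] = src0[ne0 x k] * src1[ne1 x k]^T, i.e.
// dst(i, j) = sum over c of src0(i, c) * src1(j, c)
static void out_prod_ref(const float * src0, const float * src1, float * dst,
                         int64_t ne0, int64_t ne1, int64_t k) {
    for (int64_t j = 0; j < ne1; ++j) {
        for (int64_t i = 0; i < ne0; ++i) {
            float sum = 0.0f;
            for (int64_t c = 0; c < k; ++c) {
                sum += src0[c*ne0 + i] * src1[c*ne1 + j];
            }
            dst[j*ne0 + i] = sum;
        }
    }
}

Data here is laid out ggml-style, with ne[0] contiguous, so element (i0, i1) lives at offset i1*ne[0] + i0.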
84 changes: 77 additions & 7 deletions examples/training/finetune.cpp
@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "ggml-backend.h"

#include <cmath>
#include <cstdio>
@@ -13,6 +14,72 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static bool training_supports_out_prod_f16(const common_params & params) {
std::vector<ggml_backend_dev_t> devices;

if (!params.devices.empty()) {
devices.assign(params.devices.begin(), params.devices.end());
} else {
ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
if (gpu) {
devices.push_back(gpu);
}
}

if (devices.empty()) {
return true;
}

constexpr int64_t ne0 = 4;
constexpr int64_t ne1 = 3;
constexpr int64_t k = 2;

struct ggml_tensor src0 = {};
struct ggml_tensor src1 = {};
struct ggml_tensor dst = {};

src0.type = GGML_TYPE_F16;
src1.type = GGML_TYPE_F32;
dst.type = GGML_TYPE_F32;

src0.ne[0] = ne0; src0.ne[1] = k; src0.ne[2] = 1; src0.ne[3] = 1;
src1.ne[0] = ne1; src1.ne[1] = k; src1.ne[2] = 1; src1.ne[3] = 1;
dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1;

src0.nb[0] = sizeof(ggml_fp16_t);
src0.nb[1] = src0.nb[0] * ne0;
src0.nb[2] = src0.nb[1] * k;
src0.nb[3] = src0.nb[2] * 1;

src1.nb[0] = sizeof(float);
src1.nb[1] = src1.nb[0] * ne1;
src1.nb[2] = src1.nb[1] * k;
src1.nb[3] = src1.nb[2] * 1;

dst.nb[0] = sizeof(float);
dst.nb[1] = dst.nb[0] * ne0;
dst.nb[2] = dst.nb[1] * ne1;
dst.nb[3] = dst.nb[2] * 1;

dst.op = GGML_OP_OUT_PROD;
dst.src[0] = &src0;
dst.src[1] = &src1;

for (ggml_backend_dev_t dev : devices) {
if (dev == nullptr) {
continue;
}
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
continue;
}
if (!ggml_backend_dev_supports_op(dev, &dst)) {
return false;
}
}

return true;
}

int main(int argc, char ** argv) {
common_params params;
params.escape = false;
@@ -26,13 +93,16 @@ int main(int argc, char ** argv) {
__func__);
params.use_mmap = false;
}
if (params.cache_type_k != GGML_TYPE_F32) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v != GGML_TYPE_F32) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
const bool supports_out_prod_f16 = training_supports_out_prod_f16(params);
if (!supports_out_prod_f16) {
if (params.cache_type_k != GGML_TYPE_F32) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v != GGML_TYPE_F32) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
}
}

common_init();
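The same dummy-tensor probe generalizes beyond f16. Below is a hedged sketch of a parametrized variant that checks an arbitrary non-quantized src0 type against OUT_PROD; the function name and the generalization are mine, not part of this patch, though ggml_type_size and ggml_backend_dev_supports_op are existing ggml APIs:

#include "ggml.h"
#include "ggml-backend.h"

// Probe whether `dev` can run OUT_PROD with a src0 of the given (non-quantized)
// type and an f32 src1, using the same dummy-tensor trick as the helper above.
static bool device_supports_out_prod(ggml_backend_dev_t dev, enum ggml_type src0_type) {
    constexpr int64_t ne0 = 4, ne1 = 3, k = 2;

    struct ggml_tensor src0 = {};
    struct ggml_tensor src1 = {};
    struct ggml_tensor dst  = {};

    src0.type = src0_type;
    src1.type = GGML_TYPE_F32;
    dst.type  = GGML_TYPE_F32;

    src0.ne[0] = ne0; src0.ne[1] = k;   src0.ne[2] = 1; src0.ne[3] = 1;
    src1.ne[0] = ne1; src1.ne[1] = k;   src1.ne[2] = 1; src1.ne[3] = 1;
    dst.ne[0]  = ne0; dst.ne[1]  = ne1; dst.ne[2]  = 1; dst.ne[3]  = 1;

    src0.nb[0] = ggml_type_size(src0_type);
    src0.nb[1] = src0.nb[0] * ne0; src0.nb[2] = src0.nb[1] * k;   src0.nb[3] = src0.nb[2];

    src1.nb[0] = sizeof(float);
    src1.nb[1] = src1.nb[0] * ne1; src1.nb[2] = src1.nb[1] * k;   src1.nb[3] = src1.nb[2];

    dst.nb[0] = sizeof(float);
    dst.nb[1] = dst.nb[0] * ne0;   dst.nb[2] = dst.nb[1] * ne1;   dst.nb[3] = dst.nb[2];

    dst.op     = GGML_OP_OUT_PROD;
    dst.src[0] = &src0;
    dst.src[1] = &src1;

    return ggml_backend_dev_supports_op(dev, &dst);
}

training_supports_out_prod_f16(params) above is effectively this probe specialized to GGML_TYPE_F16 and looped over the configured GPU devices.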
42 changes: 42 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -184,6 +184,33 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_cpy;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_out_prod;
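ggml_metal_kargs_out_prod follows the usual ggml-metal convention of shipping the shapes (ne*) and byte strides (nb*) of src0, src1 and dst to the shader. A rough sketch of how an encoder could populate it from an OUT_PROD node, assuming the standard ggml_tensor ne/nb fields; this is illustrative and not the actual ggml-metal encode path:

#include "ggml.h"
#include "ggml-metal-impl.h" // for ggml_metal_kargs_out_prod (assumed includable here)

// Illustrative only: copy shape/stride metadata of dst = out_prod(src0, src1)
// into the argument struct that would be bound to the Metal kernel.
static ggml_metal_kargs_out_prod make_out_prod_kargs(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];

    ggml_metal_kargs_out_prod args = {};

    args.ne00 = (int32_t) src0->ne[0]; args.ne01 = (int32_t) src0->ne[1];
    args.ne02 = (int32_t) src0->ne[2]; args.ne03 = (int32_t) src0->ne[3];
    args.nb00 = src0->nb[0]; args.nb01 = src0->nb[1];
    args.nb02 = src0->nb[2]; args.nb03 = src0->nb[3];

    args.ne10 = (int32_t) src1->ne[0]; args.ne11 = (int32_t) src1->ne[1];
    args.ne12 = (int32_t) src1->ne[2]; args.ne13 = (int32_t) src1->ne[3];
    args.nb10 = src1->nb[0]; args.nb11 = src1->nb[1];
    args.nb12 = src1->nb[2]; args.nb13 = src1->nb[3];

    args.ne0 = (int32_t) dst->ne[0]; args.ne1 = (int32_t) dst->ne[1];
    args.ne2 = (int32_t) dst->ne[2]; args.ne3 = (int32_t) dst->ne[3];
    args.nb0 = dst->nb[0]; args.nb1 = dst->nb[1];
    args.nb2 = dst->nb[2]; args.nb3 = dst->nb[3];

    return args;
}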

typedef struct {
int64_t ne10;
int64_t ne11;
@@ -439,6 +466,21 @@ typedef struct {
uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_rms_norm_back;
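ggml_metal_kargs_rms_norm_back carries the row length, strides and eps for the RMS-norm backward kernel. For orientation only, a scalar reference of the gradient such a kernel would compute per row of n elements (g is the incoming gradient, x the forward input); this is a sketch of the math, not the Metal implementation:

#include <cmath>
#include <cstdint>

// Reference RMS-norm backward for one row:
// forward:  y_i  = x_i / rms, with rms = sqrt(mean(x^2) + eps)
// backward: dx_i = (g_i - x_i * dot(g, x) / (n * rms^2)) / rms
static void rms_norm_back_row(const float * x, const float * g, float * dx,
                              int64_t n, float eps) {
    double sum_xx = 0.0, sum_gx = 0.0;
    for (int64_t i = 0; i < n; ++i) {
        sum_xx += (double) x[i] * x[i];
        sum_gx += (double) g[i] * x[i];
    }
    const double rms2  = sum_xx / n + eps;
    const double rms   = std::sqrt(rms2);
    const double scale = sum_gx / (n * rms2);
    for (int64_t i = 0; i < n; ++i) {
        dx[i] = (float) ((g[i] - x[i] * scale) / rms);
    }
}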

typedef struct {
int32_t ne00;
int32_t ne00_4;