Commit 1a3b914

Merge pull request #35 from zoq/metal_backend
Metal backend
2 parents a5810ed + e8a84f6 commit 1a3b914

15 files changed: +933, -19 lines

examples/training/finetune-lora.cpp

Lines changed: 77 additions & 7 deletions
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "ggml-backend.h"
 
 #include <cstring>
 #include <vector>
@@ -55,6 +56,72 @@ static uint32_t parse_lora_modules(const std::string& modules_str) {
     return target_modules;
 }
 
+static bool training_supports_out_prod_f16(const common_params & params) {
+    std::vector<ggml_backend_dev_t> devices;
+
+    if (!params.devices.empty()) {
+        devices.assign(params.devices.begin(), params.devices.end());
+    } else {
+        ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+        if (gpu) {
+            devices.push_back(gpu);
+        }
+    }
+
+    if (devices.empty()) {
+        return true;
+    }
+
+    constexpr int64_t ne0 = 4;
+    constexpr int64_t ne1 = 3;
+    constexpr int64_t k   = 2;
+
+    struct ggml_tensor src0 = {};
+    struct ggml_tensor src1 = {};
+    struct ggml_tensor dst  = {};
+
+    src0.type = GGML_TYPE_F16;
+    src1.type = GGML_TYPE_F32;
+    dst.type  = GGML_TYPE_F32;
+
+    src0.ne[0] = ne0; src0.ne[1] = k;   src0.ne[2] = 1; src0.ne[3] = 1;
+    src1.ne[0] = ne1; src1.ne[1] = k;   src1.ne[2] = 1; src1.ne[3] = 1;
+    dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1;
+
+    src0.nb[0] = sizeof(ggml_fp16_t);
+    src0.nb[1] = src0.nb[0] * ne0;
+    src0.nb[2] = src0.nb[1] * k;
+    src0.nb[3] = src0.nb[2] * 1;
+
+    src1.nb[0] = sizeof(float);
+    src1.nb[1] = src1.nb[0] * ne1;
+    src1.nb[2] = src1.nb[1] * k;
+    src1.nb[3] = src1.nb[2] * 1;
+
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = dst.nb[0] * ne0;
+    dst.nb[2] = dst.nb[1] * ne1;
+    dst.nb[3] = dst.nb[2] * 1;
+
+    dst.op     = GGML_OP_OUT_PROD;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    for (ggml_backend_dev_t dev : devices) {
+        if (dev == nullptr) {
+            continue;
+        }
+        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+            continue;
+        }
+        if (!ggml_backend_dev_supports_op(dev, &dst)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 static void print_lora_usage() {
     printf("\n----- LoRA Fine-tuning Parameters -----\n");
     printf(" --lora-rank N LoRA rank (default: 8, range: 1-512)\n");
@@ -380,13 +447,16 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
         params.use_mmap = false;
     }
-    if (params.cache_type_k != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_k = GGML_TYPE_F32;
-    }
-    if (params.cache_type_v != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_v = GGML_TYPE_F32;
+    const bool supports_out_prod_f16 = training_supports_out_prod_f16(params);
+    if (!supports_out_prod_f16) {
+        if (params.cache_type_k != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_k = GGML_TYPE_F32;
+        }
+        if (params.cache_type_v != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_v = GGML_TYPE_F32;
+        }
     }
 
     common_init();
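Note on the probing pattern: the added training_supports_out_prod_f16() never allocates backend memory; it fills in stack-allocated ggml_tensor descriptors for a tiny F16/F32 OUT_PROD node and asks each selected GPU device whether it could execute that node. A minimal standalone sketch of the same idea (hypothetical probe_gpu_out_prod_f16 helper and main, assuming a ggml build where ggml_backend_load_all() registers the available backends) could look like this:

#include "ggml.h"
#include "ggml-backend.h"

#include <cstdio>

static bool probe_gpu_out_prod_f16(ggml_backend_dev_t dev) {
    // Dummy operands: src0 is a 4x2 F16 matrix, src1 is a 3x2 F32 matrix,
    // dst is the 4x3 F32 outer-product result -- same shapes as the probe above.
    struct ggml_tensor src0 = {};
    struct ggml_tensor src1 = {};
    struct ggml_tensor dst  = {};

    src0.type = GGML_TYPE_F16;
    src1.type = GGML_TYPE_F32;
    dst.type  = GGML_TYPE_F32;

    src0.ne[0] = 4; src0.ne[1] = 2; src0.ne[2] = 1; src0.ne[3] = 1;
    src1.ne[0] = 3; src1.ne[1] = 2; src1.ne[2] = 1; src1.ne[3] = 1;
    dst.ne [0] = 4; dst.ne [1] = 3; dst.ne [2] = 1; dst.ne [3] = 1;

    // Contiguous strides derived from the element sizes, ggml-style.
    src0.nb[0] = sizeof(ggml_fp16_t);
    src1.nb[0] = sizeof(float);
    dst.nb [0] = sizeof(float);
    for (int i = 1; i < 4; ++i) {
        src0.nb[i] = src0.nb[i - 1] * src0.ne[i - 1];
        src1.nb[i] = src1.nb[i - 1] * src1.ne[i - 1];
        dst.nb [i] = dst.nb [i - 1] * dst.ne [i - 1];
    }

    dst.op     = GGML_OP_OUT_PROD;
    dst.src[0] = &src0;
    dst.src[1] = &src1;

    // No buffers are allocated; the device only inspects the node description.
    return ggml_backend_dev_supports_op(dev, &dst);
}

int main() {
    ggml_backend_load_all(); // register the available backends (CPU, Metal, CUDA, ...)

    ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
    if (gpu == nullptr) {
        std::printf("no GPU device registered; the CPU path handles F16 OUT_PROD\n");
        return 0;
    }

    std::printf("F16 OUT_PROD on %s: %s\n",
                ggml_backend_dev_name(gpu),
                probe_gpu_out_prod_f16(gpu) ? "supported" : "not supported");
    return 0;
}

The same gate then decides whether the K/V cache types need to be forced back to F32, as in the hunk above.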

examples/training/finetune.cpp

Lines changed: 77 additions & 7 deletions
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "ggml-backend.h"
 
 #include <cmath>
 #include <cstdio>
@@ -13,6 +14,72 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static bool training_supports_out_prod_f16(const common_params & params) {
+    std::vector<ggml_backend_dev_t> devices;
+
+    if (!params.devices.empty()) {
+        devices.assign(params.devices.begin(), params.devices.end());
+    } else {
+        ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+        if (gpu) {
+            devices.push_back(gpu);
+        }
+    }
+
+    if (devices.empty()) {
+        return true;
+    }
+
+    constexpr int64_t ne0 = 4;
+    constexpr int64_t ne1 = 3;
+    constexpr int64_t k   = 2;
+
+    struct ggml_tensor src0 = {};
+    struct ggml_tensor src1 = {};
+    struct ggml_tensor dst  = {};
+
+    src0.type = GGML_TYPE_F16;
+    src1.type = GGML_TYPE_F32;
+    dst.type  = GGML_TYPE_F32;
+
+    src0.ne[0] = ne0; src0.ne[1] = k;   src0.ne[2] = 1; src0.ne[3] = 1;
+    src1.ne[0] = ne1; src1.ne[1] = k;   src1.ne[2] = 1; src1.ne[3] = 1;
+    dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1;
+
+    src0.nb[0] = sizeof(ggml_fp16_t);
+    src0.nb[1] = src0.nb[0] * ne0;
+    src0.nb[2] = src0.nb[1] * k;
+    src0.nb[3] = src0.nb[2] * 1;
+
+    src1.nb[0] = sizeof(float);
+    src1.nb[1] = src1.nb[0] * ne1;
+    src1.nb[2] = src1.nb[1] * k;
+    src1.nb[3] = src1.nb[2] * 1;
+
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = dst.nb[0] * ne0;
+    dst.nb[2] = dst.nb[1] * ne1;
+    dst.nb[3] = dst.nb[2] * 1;
+
+    dst.op     = GGML_OP_OUT_PROD;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    for (ggml_backend_dev_t dev : devices) {
+        if (dev == nullptr) {
+            continue;
+        }
+        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+            continue;
+        }
+        if (!ggml_backend_dev_supports_op(dev, &dst)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
     params.escape = false;
@@ -26,13 +93,16 @@ int main(int argc, char ** argv) {
             __func__);
         params.use_mmap = false;
     }
-    if (params.cache_type_k != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_k = GGML_TYPE_F32;
-    }
-    if (params.cache_type_v != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_v = GGML_TYPE_F32;
+    const bool supports_out_prod_f16 = training_supports_out_prod_f16(params);
+    if (!supports_out_prod_f16) {
+        if (params.cache_type_k != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_k = GGML_TYPE_F32;
+        }
+        if (params.cache_type_v != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_v = GGML_TYPE_F32;
+        }
     }
 
     common_init();

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 42 additions & 0 deletions
@@ -184,6 +184,33 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_cpy;
 
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne10;
+    int32_t ne11;
+    int32_t ne12;
+    int32_t ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne0;
+    int32_t ne1;
+    int32_t ne2;
+    int32_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_out_prod;
+
 typedef struct {
     int64_t ne10;
     int64_t ne11;
@@ -439,6 +466,21 @@ typedef struct {
     uint64_t nbf3[3];
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t ne00;
+    int32_t ne00_4;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float eps;
+} ggml_metal_kargs_rms_norm_back;
+
 typedef struct {
     int32_t ne00;
     int32_t ne00_4;
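Like the other kargs blocks in this header, ggml_metal_kargs_out_prod simply carries the operand shapes (ne*) and byte strides (nb*) of src0, src1, and dst into the kernel as a single constant-argument struct. The encoder-side code that fills it lives in the Metal backend sources, which are not shown in this excerpt; a rough, field-for-field sketch of that convention (hypothetical make_out_prod_args helper, assuming dst is the GGML_OP_OUT_PROD node being encoded) would be:

// Sketch only -- not the actual backend code; it mirrors the /*.field =*/
// packing style used elsewhere in ggml-metal.
static ggml_metal_kargs_out_prod make_out_prod_args(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // F16 operand
    const struct ggml_tensor * src1 = dst->src[1]; // F32 operand

    ggml_metal_kargs_out_prod args = {
        /*.ne00 =*/ (int32_t) src0->ne[0],
        /*.ne01 =*/ (int32_t) src0->ne[1],
        /*.ne02 =*/ (int32_t) src0->ne[2],
        /*.ne03 =*/ (int32_t) src0->ne[3],
        /*.nb00 =*/ src0->nb[0],
        /*.nb01 =*/ src0->nb[1],
        /*.nb02 =*/ src0->nb[2],
        /*.nb03 =*/ src0->nb[3],
        /*.ne10 =*/ (int32_t) src1->ne[0],
        /*.ne11 =*/ (int32_t) src1->ne[1],
        /*.ne12 =*/ (int32_t) src1->ne[2],
        /*.ne13 =*/ (int32_t) src1->ne[3],
        /*.nb10 =*/ src1->nb[0],
        /*.nb11 =*/ src1->nb[1],
        /*.nb12 =*/ src1->nb[2],
        /*.nb13 =*/ src1->nb[3],
        /*.ne0  =*/ (int32_t) dst->ne[0],
        /*.ne1  =*/ (int32_t) dst->ne[1],
        /*.ne2  =*/ (int32_t) dst->ne[2],
        /*.ne3  =*/ (int32_t) dst->ne[3],
        /*.nb0  =*/ dst->nb[0],
        /*.nb1  =*/ dst->nb[1],
        /*.nb2  =*/ dst->nb[2],
        /*.nb3  =*/ dst->nb[3],
    };
    return args;
}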

0 commit comments