 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "ggml-backend.h"
 
 #include <cmath>
 #include <cstdio>
@@ -54,6 +55,72 @@ static uint32_t parse_lora_modules(const std::string& modules_str) { |
     return target_modules;
 }
 
+static bool training_supports_out_prod_f16(const common_params & params) {
+    std::vector<ggml_backend_dev_t> devices;
+
+    if (!params.devices.empty()) {
+        devices.assign(params.devices.begin(), params.devices.end());
+    } else {
+        ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+        if (gpu) {
+            devices.push_back(gpu);
+        }
+    }
+
+    if (devices.empty()) {
+        return true;
+    }
+
+    constexpr int64_t ne0 = 4;
+    constexpr int64_t ne1 = 3;
+    constexpr int64_t k = 2;
+
+    struct ggml_tensor src0 = {};
+    struct ggml_tensor src1 = {};
+    struct ggml_tensor dst = {};
+
+    src0.type = GGML_TYPE_F16;
+    src1.type = GGML_TYPE_F32;
+    dst.type = GGML_TYPE_F32;
+
+    src0.ne[0] = ne0; src0.ne[1] = k;   src0.ne[2] = 1; src0.ne[3] = 1;
+    src1.ne[0] = ne1; src1.ne[1] = k;   src1.ne[2] = 1; src1.ne[3] = 1;
+    dst.ne [0] = ne0; dst.ne [1] = ne1; dst.ne [2] = 1; dst.ne [3] = 1;
+
+    src0.nb[0] = sizeof(ggml_fp16_t);
+    src0.nb[1] = src0.nb[0] * ne0;
+    src0.nb[2] = src0.nb[1] * k;
+    src0.nb[3] = src0.nb[2] * 1;
+
+    src1.nb[0] = sizeof(float);
+    src1.nb[1] = src1.nb[0] * ne1;
+    src1.nb[2] = src1.nb[1] * k;
+    src1.nb[3] = src1.nb[2] * 1;
+
+    dst.nb[0] = sizeof(float);
+    dst.nb[1] = dst.nb[0] * ne0;
+    dst.nb[2] = dst.nb[1] * ne1;
+    dst.nb[3] = dst.nb[2] * 1;
+
+    dst.op = GGML_OP_OUT_PROD;
+    dst.src[0] = &src0;
+    dst.src[1] = &src1;
+
+    for (ggml_backend_dev_t dev : devices) {
+        if (dev == nullptr) {
+            continue;
+        }
+        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+            continue;
+        }
+        if (!ggml_backend_dev_supports_op(dev, &dst)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 static void print_lora_usage() {
     printf("\nLoRA Fine-tuning Parameters:\n");
     printf("  --lora-rank N    LoRA rank (default: 8, range: 1-512)\n");
@@ -124,13 +191,16 @@ int main(int argc, char ** argv) { |
         LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
         params.use_mmap = false;
     }
-    if (params.cache_type_k != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_k = GGML_TYPE_F32;
-    }
-    if (params.cache_type_v != GGML_TYPE_F32) {
-        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
-        params.cache_type_v = GGML_TYPE_F32;
+    const bool supports_out_prod_f16 = training_supports_out_prod_f16(params);
+    if (!supports_out_prod_f16) {
+        if (params.cache_type_k != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_k = GGML_TYPE_F32;
+        }
+        if (params.cache_type_v != GGML_TYPE_F32) {
+            LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+            params.cache_type_v = GGML_TYPE_F32;
+        }
     }
 
     common_init();
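For reference, the same capability probe can be written against ggml's metadata-only context instead of hand-filled tensor structs. This is a minimal sketch, not part of the patch: the helper name `probe_out_prod_f16` and the `8*ggml_tensor_overhead()` sizing are illustrative assumptions, while `ggml_init` with `no_alloc`, `ggml_new_tensor_2d`, `ggml_out_prod`, `ggml_backend_dev_supports_op`, and `ggml_free` are the public ggml API.

```cpp
// Sketch only: ask a device whether it can run OUT_PROD with an f16 src0.
// A no_alloc context creates tensor metadata without allocating any data.
#include "ggml.h"
#include "ggml-backend.h"

static bool probe_out_prod_f16(ggml_backend_dev_t dev) {  // hypothetical helper
    ggml_init_params ip = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead(),  // metadata only, sized generously
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,                      // do not allocate tensor data
    };
    ggml_context * ctx = ggml_init(ip);
    if (ctx == nullptr) {
        return false;
    }
    // src0: f16 [4, 2], src1: f32 [3, 2] -> OUT_PROD result: f32 [4, 3]
    ggml_tensor * src0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 2);
    ggml_tensor * src1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
    ggml_tensor * out  = ggml_out_prod(ctx, src0, src1);
    const bool ok = ggml_backend_dev_supports_op(dev, out);
    ggml_free(ctx);
    return ok;
}
```

Called with, for example, `ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU)`, this should give the same answer as the struct-based check in the patch; the trade-off is a small temporary context allocation versus the patch's allocation-free stack tensors.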