diff --git a/examples/training/CMakeLists.txt b/examples/training/CMakeLists.txt index 64afe6ddc647a..08d7ab2479055 100644 --- a/examples/training/CMakeLists.txt +++ b/examples/training/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} finetune.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +set(TARGET llama-finetune-lora) +add_executable(${TARGET} finetune-lora.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) \ No newline at end of file diff --git a/examples/training/README.md b/examples/training/README.md index df425279266e4..ed255a0e1af3d 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -1,5 +1,6 @@ # llama.cpp/examples/training +## finetune This directory contains examples related to language model training using llama.cpp/GGML. So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP. Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory. @@ -15,3 +16,67 @@ export model_name=llama_3.2-1b && export quantization=f32 ``` The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. + + +## finetune-lora + +LoRA (Low-Rank Adaptation) fine-tuning for efficient model training. This approach trains only a small set of additional parameters while keeping +the base model frozen, making it memory-efficient. + +### Basic Usage + +```sh +# Create new LoRA adapter with default settings (rank=8, alpha=16, attention modules) +./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 + +# Custom LoRA parameters(creates new lora adapter and trains it from scratch) +./build/bin/llama-finetune-lora -m model.gguf -f dataset.txt -ngl 999 -c 512 -b 512 -ub 512 \ + --lora-rank 16 --lora-alpha 32 --lora-modules "attn_q,attn_k,attn_v,attn_o" + +# Fine-tune existing LoRA adapter +./build/bin/llama-finetune-lora -m base_model.gguf -f dataset.txt --lora existing_adapter.gguf \ + --output-adapter improved_adapter.gguf -ngl 999 -c 512 -b 512 -ub 512 +``` + + +### Parameters + +#### LoRA Configuration +- `--lora-rank N` - LoRA rank (default: 8) + - Lower rank = smaller adapter, less capacity + - Higher rank = larger adapter, more capacity +- `--lora-alpha N` - LoRA alpha scaling factor (default: 16.0) + - Controls adaptation strength + - Common rule: alpha = 2 × rank +- `--lora-modules MODULES` - Target modules as comma-separated list + - Available: `attn_q`, `attn_k`, `attn_v`, `attn_o`, `ffn_gate`, `ffn_up`, `ffn_down`, `embed`, `output`, `all` + - Default: `attn_q,attn_k,attn_v,attn_o` (attention modules) +- `--output-adapter PATH` - Output adapter filename (default: auto-generated) + +#### Standard Parameters +- `-m MODEL` - Base model file (.gguf) +- `-f FILE` - Training dataset +- `-ngl N` - GPU layers (use 999 for full GPU training) +- `-c N` - Context length (512 recommended for mobile) + + +### Using Trained Adapters + +After training, you'll get a small adapter file. Use it with the original base model: + +```sh +./build/bin/llama-cli -m base_model.gguf --lora trained_adapter.gguf -ngl 999 +``` + +### Troubleshooting + +- **Out of memory**: Reduce context length (`-c 256`), lower rank, or use fewer target modules +- **Poor quality**: Increase rank, add more target modules, or train longer +- **Large adapter**: Reduce rank or limit target modules + +### Help + +Run with `--help` or `-h` to see all available parameters: +```sh +./build/bin/llama-finetune-lora --help +``` diff --git a/examples/training/finetune-lora.cpp b/examples/training/finetune-lora.cpp new file mode 100644 index 0000000000000..8e3a1026b6c91 --- /dev/null +++ b/examples/training/finetune-lora.cpp @@ -0,0 +1,262 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + + +static uint32_t parse_lora_modules(const std::string& modules_str) { + if (modules_str.empty()) { + return LLAMA_LORA_TARGET_ATTN_Q | LLAMA_LORA_TARGET_ATTN_K | LLAMA_LORA_TARGET_ATTN_V | LLAMA_LORA_TARGET_ATTN_O; + } + + static const std::map module_map = { + {"attn_q", LLAMA_LORA_TARGET_ATTN_Q}, + {"attn_k", LLAMA_LORA_TARGET_ATTN_K}, + {"attn_v", LLAMA_LORA_TARGET_ATTN_V}, + {"attn_o", LLAMA_LORA_TARGET_ATTN_O}, + {"ffn_gate", LLAMA_LORA_TARGET_FFN_GATE}, + {"ffn_up", LLAMA_LORA_TARGET_FFN_UP}, + {"ffn_down", LLAMA_LORA_TARGET_FFN_DOWN}, + {"output", LLAMA_LORA_TARGET_OUTPUT}, + {"all", LLAMA_LORA_TARGET_ALL} + }; + + uint32_t target_modules = 0; + std::stringstream ss(modules_str); + std::string module; + + while (std::getline(ss, module, ',')) { + module.erase(0, module.find_first_not_of(" \t")); + module.erase(module.find_last_not_of(" \t") + 1); + + auto it = module_map.find(module); + if (it != module_map.end()) { + target_modules |= it->second; + LOG_INF("Added target module: %s\n", module.c_str()); + } else { + LOG_ERR("Unknown LoRA target module: %s\n", module.c_str()); + LOG_ERR("Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n"); + return 0; + } + } + + return target_modules; +} + +static void print_lora_usage() { + printf("\nLoRA Fine-tuning Parameters:\n"); + printf(" --lora-rank N LoRA rank (default: 8, range: 1-512)\n"); + printf(" --lora-alpha N LoRA alpha scaling factor (default: 16.0, range: 0.1-1000.0)\n"); + printf(" --lora-modules MODULES Target modules as comma-separated list (default: attn_q,attn_k,attn_v,attn_o)\n"); + printf(" Available modules: attn_q, attn_k, attn_v, attn_o, ffn_gate, ffn_up, ffn_down, output, all\n"); + printf(" Examples: \"attn_q,attn_v\" or \"all\" or \"attn_q,attn_k,attn_v,attn_o,ffn_gate,ffn_up,ffn_down\"\n"); + printf(" --output-adapter PATH Output path for trained adapter (default: auto-generated)\n"); + printf("\nExamples:\n"); + printf(" # Train with rank=16, alpha=32, all attention modules\n"); + printf(" %s -m model.gguf -f dataset.txt --lora-rank 16 --lora-alpha 32 --lora-modules attn_q,attn_k,attn_v,attn_o\n", "finetune-lora"); + printf("\n # Fine-tune existing adapter with all modules\n"); + printf(" %s -m model.gguf -f dataset.txt --lora existing.gguf --output-adapter improved.gguf\n", "finetune-lora"); + printf("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + int32_t lora_rank = 8; + float lora_alpha = 16.0f; + std::string lora_modules_str; + std::string output_adapter_path; + + params.escape = false; + + auto remove_arg_pair = [&](int i) { + for (int j = i; j < argc - 2; j++) { + argv[j] = argv[j + 2]; + } + argc -= 2; + }; + + for (int i = 1; i < argc - 1; i++) { + if (strcmp(argv[i], "--lora-rank") == 0) { + lora_rank = std::atoi(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lora-alpha") == 0) { + lora_alpha = std::atof(argv[i + 1]); + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--lora-modules") == 0) { + lora_modules_str = argv[i + 1]; + remove_arg_pair(i); + i--; + } else if (strcmp(argv[i], "--output-adapter") == 0) { + output_adapter_path = argv[i + 1]; + remove_arg_pair(i); + i--; + } + } + + LOG_INF("Using LoRA parameters: rank=%d, alpha=%.1f\n", lora_rank, lora_alpha); + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + print_lora_usage(); + } + } + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { + print_lora_usage(); + return 1; + } + + if (params.use_mmap) { + LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); + params.use_mmap = false; + } + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + common_init_result llama_init = common_init_from_params(params); + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + uint32_t target_modules = parse_lora_modules(lora_modules_str); + if (target_modules == 0) { + return 1; + } + + struct llama_lora_training_params lora_params = { + /*target_modules =*/ target_modules, + /*rank =*/ lora_rank, + /*alpha =*/ lora_alpha, + /*dropout =*/ 0.0f, + /*init_std =*/ 0.02f, + }; + + bool has_existing_lora = !params.lora_adapters.empty(); + struct llama_adapter_lora * trained_adapter = nullptr; + + if (has_existing_lora) { + LOG_INF("Finetuning existing LoRA adapters\n"); + LOG_INF("Found %zu existing LoRA adapters to train\n", params.lora_adapters.size());\ + trained_adapter = params.lora_adapters[0].ptr; + if (!trained_adapter) { + LOG_ERR("Existing LoRA adapter is null\n"); + return 1; + } + } else { + LOG_INF("Target modules: Q=%s, K=%s, V=%s, O=%s, GATE=%s, UP=%s, DOWN=%s, OUTPUT=%s\n", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_Q) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_K) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_V) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_ATTN_O) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_GATE) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_UP) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_FFN_DOWN) ? "yes" : "no", + (lora_params.target_modules & LLAMA_LORA_TARGET_OUTPUT) ? "yes" : "no"); + + LOG_INF("LoRA configuration: rank=%d, alpha=%.1f (scaling=%.3f)\n", + lora_params.rank, lora_params.alpha, lora_params.alpha / lora_params.rank); + + trained_adapter = llama_lora_training_init(ctx.get(), model.get(), &lora_params); + if (!trained_adapter) { + LOG_ERR("%s: LoRA training initialization failed\n", __func__); + return 1; + } + } + + constexpr float val_split = 0.05f; + + std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + + struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); + optimizer_params.adamw.alpha = 1e-5f; // learning rate + + struct llama_opt_params lopt_params { + /*n_ctx_train =*/ 0, + /*param_filter =*/ llama_opt_param_filter_lora, + /*param_filter_ud =*/ nullptr, + /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, + /*get_opt_pars_ud =*/ &optimizer_params, + }; + llama_opt_init(ctx.get(), model.get(), lopt_params); + + const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_eval = ggml_opt_result_init(); + + for (int epoch = 0; epoch < 2; ++epoch) { + llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); + fprintf(stderr, "\n"); + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_eval); + } + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_eval); + + std::string adapter_filename; + if (!output_adapter_path.empty()) { + adapter_filename = output_adapter_path; + } else if (has_existing_lora) { + adapter_filename = "finetuned-lora-adapter.gguf"; + LOG_INF("Finetuned existing lora adapter, saving as: %s\n", adapter_filename.c_str()); + } else { + adapter_filename = "trained-lora-adapter.gguf"; + LOG_INF("Saving new lora adapter: %s\n", adapter_filename.c_str()); + } + + if (trained_adapter) { + if (llama_lora_save_adapter(trained_adapter, adapter_filename.c_str(), model.get())) { + std::ifstream adapter_file(adapter_filename, std::ios::binary | std::ios::ate); + if (adapter_file.is_open()) { + std::streamsize adapter_size = adapter_file.tellg(); + LOG_INF("LoRA adapter saved: %s (%.2f MB)\n", + adapter_filename.c_str(), adapter_size / (1024.0 * 1024.0)); + adapter_file.close(); + } + } else { + LOG_ERR("Failed to save LoRA adapter\n"); + } + } else { + LOG_ERR("No trained adapter available for saving\n"); + } + + llama_backend_free(); + + return 0; +} diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index 23bede49b1362..764e0327d8777 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -93,4 +93,4 @@ int main(int argc, char ** argv) { llama_backend_free(); return 0; -} +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index c9daa4c39e83e..91b1004b5cf3c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -442,7 +442,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_GET_ROWS_BACK: return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6581d27adde2e..b696c610b26b2 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4498,6 +4498,107 @@ static void ggml_compute_forward_out_prod_f32( } } +static void ggml_compute_forward_out_prod_f16_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0); + } + ggml_barrier(params->threadpool); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + ggml_fp16_t * s0 = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + for (int i = 0; i < ne0; ++i) { + d[i] += GGML_CPU_FP16_TO_FP32(s0[i])*(*s1); + } + } + } + } + } +} + static void ggml_compute_forward_out_prod_q_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4620,9 +4721,8 @@ void ggml_compute_forward_out_prod( } break; case GGML_TYPE_F16: { - GGML_ABORT("fatal error"); // todo - // ggml_compute_forward_out_prod_f16_f32(params, dst); - } + ggml_compute_forward_out_prod_f16_f32(params, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_out_prod_f32(params, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 50a977c30762c..0894496552c3c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3202,7 +3202,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } } break; case GGML_OP_OUT_PROD: - return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + // return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + return op->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index c9b2b699c6a55..48280d4749dd1 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -1,4 +1,5 @@ #include "out-prod.cuh" +#include "convert.cuh" #include @@ -8,10 +9,61 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + const bool src0_is_quantized = (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16); + const bool src1_is_quantized = (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16); + + // if (src0_is_quantized || src1_is_quantized) { + // printf("DEBUG: OUT_PROD with quantized tensors - src0_quantized=%d, src1_quantized=%d\n", + // src0_is_quantized, src1_is_quantized); + // fflush(stdout); + // } + + // GGML_ASSERT(src0->type == GGML_TYPE_F32); + // GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + // temp buffers + float * src0_f32 = nullptr; + float * src1_f32 = nullptr; + bool allocated_src0 = false; + bool allocated_src1 = false; + cudaStream_t stream = ctx.stream(); + + if (src0_is_quantized) { + const size_t src0_size = ggml_nelements(src0) * sizeof(float); + CUDA_CHECK(cudaMallocAsync(&src0_f32, src0_size, stream)); + allocated_src0 = true; + + // Dequantize + auto dequantize_fn = ggml_get_to_fp32_cuda(src0->type); + if (dequantize_fn) { + dequantize_fn(src0->data, src0_f32, ggml_nelements(src0), stream); + } else { + CUDA_CHECK(cudaFreeAsync(src0_f32, stream)); + GGML_ABORT("Unsupported quant type for src0"); + } + } else { + src0_f32 = (float *) src0->data; + } + + if (src1_is_quantized) { + const size_t src1_size = ggml_nelements(src1) * sizeof(float); + CUDA_CHECK(cudaMallocAsync(&src1_f32, src1_size, stream)); + allocated_src1 = true; + + auto dequantize_fn = ggml_get_to_fp32_cuda(src1->type); + if (dequantize_fn) { + dequantize_fn(src1->data, src1_f32, ggml_nelements(src0), stream); + } else { + CUDA_CHECK(cudaFreeAsync(src1_f32, stream)); + GGML_ABORT("Unsupported quant type for src1"); + } + } else { + src1_f32 = (float *) src1->data; + } + + GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne1 == ne10); @@ -22,11 +74,14 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ne2 == src1->ne[2]); GGML_ASSERT(ne3 == src1->ne[3]); - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; + // const float * src0_d = (const float *) src0->data; + // const float * src1_d = (const float *) src1->data; + + // Use dequantized data + const float * src0_d = src0_f32; + const float * src1_d = src1_f32; float * dst_d = (float *) dst->data; - cudaStream_t stream = ctx.stream(); cublasHandle_t handle = ctx.cublas_handle(); const float alpha = 1.0f; @@ -34,19 +89,32 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK(cublasSetStream(handle, stream)); - const int64_t lda = nb01 / sizeof(float); + // const int64_t lda = nb01 / sizeof(float); + const int64_t lda = allocated_src0 ? ne00 : (nb01 / sizeof(float)); const int64_t ldc = nb1 / sizeof(float); const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; - const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); - GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + // const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + const int64_t ldb = allocated_src1 ? + (src1_T ? ne10 : ne11) : + ((src1_T ? nb10 : nb11) / sizeof(float)); + + // GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + // Only assert for non dequantized src1 + if (!allocated_src1) { + GGML_ASSERT((src1_T ? nb11 : nb10) == sizeof(float)); + } // data strides in dimensions 2/3 - const size_t s02 = nb02 / sizeof(float); - const size_t s03 = nb03 / sizeof(float); - const size_t s12 = nb12 / sizeof(float); - const size_t s13 = nb13 / sizeof(float); + // const size_t s02 = nb02 / sizeof(float); + // const size_t s03 = nb03 / sizeof(float); + // const size_t s12 = nb12 / sizeof(float); + // const size_t s13 = nb13 / sizeof(float); + const size_t s02 = allocated_src0 ? (ne00 * ne01) : nb02 / sizeof(float); + const size_t s03 = allocated_src0 ? (ne00 * ne01 * ne02): nb03 / sizeof(float); + const size_t s12 = allocated_src1 ? (ne10 * ne11) : nb12 / sizeof(float); + const size_t s13 = allocated_src1 ? (ne10 * ne11 * ne12) : nb13 / sizeof(float); const size_t s2 = nb2 / sizeof(float); const size_t s3 = nb3 / sizeof(float); @@ -65,4 +133,16 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { &beta, dst_d + i3 *s3 + i2 *s2, ldc)); } } + + if (allocated_src0) { + CUDA_CHECK(cudaFreeAsync(src0_f32, stream)); + // printf("DEBUG: Freed dequantized src0 buffer\n"); + } + if (allocated_src1) { + CUDA_CHECK(cudaFreeAsync(src1_f32, stream)); + // // printf("DEBUG: Freed dequantized src1 buffer\n"); + } + + // printf("DEBUG: CUDA OUT_PROD completed successfully\n"); + fflush(stdout); } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3019a545d58ed..3e03d37bbd3b0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -464,6 +464,7 @@ struct vk_device_struct { vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_cross_entropy_loss_back_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; @@ -473,6 +474,10 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_out_prod_f32; + vk_pipeline pipeline_out_prod_f16_f32; + vk_pipeline pipeline_out_prod_q4_0; + vk_pipeline pipeline_out_prod_q8_0; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -569,8 +574,8 @@ struct vk_buffer_struct { } VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); - device->device.freeMemory(device_memory); device->device.destroyBuffer(buffer); + device->device.freeMemory(device_memory); } }; @@ -2911,6 +2916,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_cross_entropy_loss_back_f32, "cross_entropy_loss_back_f32", cross_entropy_loss_back_f32_len, cross_entropy_loss_back_f32_data, "main", 4, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -2934,6 +2941,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } + ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -6691,6 +6704,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_diag_mask_inf_f32; } return nullptr; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cross_entropy_loss_back_f32; + } + return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); @@ -6745,6 +6763,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } + case GGML_OP_OUT_PROD: + if (dst->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32; + if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0; + if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_out_prod_f16_f32; + } + return nullptr; case GGML_OP_ARGSORT: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argsort_f32; @@ -6829,6 +6857,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { switch (op) { case GGML_OP_CPY: case GGML_OP_GET_ROWS: + case GGML_OP_OUT_PROD: case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -6915,7 +6944,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || op == GGML_OP_OUT_PROD || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; @@ -7073,6 +7102,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + // For cross entropy loss back, we need one workgroup per row of logits (src1) + const uint32_t nr = ggml_nrows(src1); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; case GGML_OP_RMS_NORM: elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; break; @@ -7149,6 +7190,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_UPSCALE: case GGML_OP_UNARY: case GGML_OP_GLU: + case GGML_OP_OUT_PROD: case GGML_OP_CONV_2D_DW: { uint32_t ne = ggml_nelements(dst); @@ -7787,6 +7829,18 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } +static void ggml_vk_cross_entropy_loss_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const int64_t nclasses = src1->ne[0]; + const int64_t nrows = ggml_nrows(src1); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_CROSS_ENTROPY_LOSS_BACK, { + (uint32_t)nclasses, + (uint32_t)nrows, + 0.0f, + 0.0f + }, dryrun); +} + static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; @@ -7894,6 +7948,24 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_out_prod(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + const int64_t r2 = src1->ne[2] / src0->ne[2]; + const int64_t r3 = src1->ne[3] / src0->ne[3]; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_OUT_PROD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, (float) r2, (int32_t) r3 + }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -9044,12 +9116,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_OUT_PROD: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -9117,6 +9191,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: + case GGML_OP_OUT_PROD: case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -9156,6 +9231,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_OUT_PROD: + ggml_vk_out_prod(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_ADD: ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9280,6 +9359,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + break; + + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + // std::cerr << "*** GGML_VK_BUILD_GRAPH: CROSS_ENTROPY_LOSS_BACK case hit, calling ggml_vk_cross_entropy_loss_back" << std::endl; + // std::cout << "*** GGML_VK_BUILD_GRAPH: CROSS_ENTROPY_LOSS_BACK case hit, calling ggml_vk_cross_entropy_loss_back" << std::endl; + // fflush(stdout); fflush(stderr); + ggml_vk_cross_entropy_loss_back(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; case GGML_OP_SOFT_MAX: ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9444,6 +9531,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: @@ -9457,6 +9545,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGMAX: + case GGML_OP_OUT_PROD: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: @@ -10580,11 +10669,28 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_GROUP_NORM: case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_OUT_PROD: { + const ggml_type t0 = op->src[0]->type; + const ggml_type t1 = op->src[1]->type; + const ggml_type td = op->type; + if (td != GGML_TYPE_F32 || t1 != GGML_TYPE_F32) { + return false; + } + switch (t0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: @@ -10619,6 +10725,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return false; } @@ -11030,6 +11138,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_OUT_PROD) { + tensor_clone = ggml_out_prod(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp new file mode 100644 index 0000000000000..920279aee314f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cross_entropy_loss_back.comp @@ -0,0 +1,92 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.comp" +#include "types.comp" + +#define FLOAT_TYPE float + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // Grad(scalar) +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // logits => raw model outputs(unnormalized scored) +layout (binding = 2) readonly buffer C {C_TYPE data_c[];}; // true labels(one hot encoded) +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; // output gradients + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +void main() { + const uint nclasses = p.KX; + const uint nrows = p.KY; + + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + if (row >= nrows) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint warp_size = gl_WorkGroupSize.x; + + const uint logits_offset = row * nclasses; + const uint labels_offset = row * nclasses; + const uint dst_offset = row * nclasses; + + // Gradient scaling (grad / batch_size) + const FLOAT_TYPE d_by_nrows = FLOAT_TYPE(data_a[0]) / FLOAT_TYPE(nrows); + + // Get max value per thread + FLOAT_TYPE thread_max = FLOAT_TYPE(uintBitsToFloat(0xFF800000)); // -INFINITY + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_max = max(thread_max, val); + } + + vals[tid] = thread_max; + barrier(); + + // Get global maximum for the row(batch) + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + const FLOAT_TYPE row_max = vals[0]; + barrier(); + + // Compute sum of exp(logits - max) for softmax normalization + FLOAT_TYPE thread_sum = FLOAT_TYPE(0.0); + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE val = FLOAT_TYPE(data_b[logits_offset + i]); + thread_sum += exp(val - row_max); + } + + vals[tid] = thread_sum; + barrier(); + + [[unroll]] + for (uint s = warp_size / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE row_sum = vals[0]; + const FLOAT_TYPE sm_scale = FLOAT_TYPE(1.0) / row_sum; + barrier(); + + // Compute final gradients: (softmax - labels) * d_by_nrows + for (uint i = tid; i < nclasses; i += warp_size) { + FLOAT_TYPE logit = FLOAT_TYPE(data_b[logits_offset + i]); + FLOAT_TYPE softmax_val = exp(logit - row_max) * sm_scale; + + FLOAT_TYPE label = FLOAT_TYPE(data_c[labels_offset + i]); + + data_d[dst_offset + i] = D_TYPE((softmax_val - label) * d_by_nrows); + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp new file mode 100644 index 0000000000000..31a7f40db694a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod.comp @@ -0,0 +1,56 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "generic_binary_head.comp" +#include "types.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + FLOAT_TYPE acc = FLOAT_TYPE(0.0); + + for (uint i01 = 0; i01 < p.ne01; ++i01) { + uint a_idx = src0_idx(i0, i01, i2 / broadcast2, i3 / broadcast3); + uint b_idx = src1_idx(i1, i01, i2, i3); + + FLOAT_TYPE a_val = FLOAT_TYPE(data_a[get_aoffset() + a_idx]); + FLOAT_TYPE b_val = FLOAT_TYPE(data_b[get_boffset() + b_idx]); + + acc += a_val * b_val; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = D_TYPE(acc); + + idx += num_threads; + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp new file mode 100644 index 0000000000000..1d83771b1d910 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q4_0.comp @@ -0,0 +1,57 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +const uint num_threads = 256; +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint it = 0; it < num_iter; ++it) { + if (idx < p.ne) { + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + float acc = 0.0f; + + for (uint k = 0; k < p.ne01; k += 1) { + const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01; + const uint ib = a_block_base + (i0 / QUANT_K); + const uint iqs = i0 % (QUANT_K / QUANT_R); + const uint upper = (i0 % QUANT_K) / (QUANT_K / QUANT_R); + const uint lower = 1 - upper; + + const vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + + const float a_val = (v.x * lower + v.y * upper) * dm.x + dm.y; + + const uint b_idx = src1_idx(i1, k, i2, i3); + const float b = data_b[get_boffset() + b_idx]; + acc += a_val * b; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = acc; + } + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp new file mode 100644 index 0000000000000..58acaae127622 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/out_prod_q8_0.comp @@ -0,0 +1,54 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" +#include "dequant_funcs.comp" + +const uint num_threads = 256; +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) { + i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20)); + const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; + i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20)); + const uint i22_offset = i22*p.ne21*p.ne20; + i21 = (idx - i23_offset - i22_offset) / p.ne20; + i20 = idx - i23_offset - i22_offset - i21*p.ne20; +} + +void main() { + // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx + const uint num_iter = 2; + + const uint broadcast2 = uint(p.param2); + const uint broadcast3 = p.param3; + + uint idx = get_idx(); + + [[unroll]] for (uint it = 0; it < num_iter; ++it) { + if (idx < p.ne) { + uint i0, i1, i2, i3; + get_dst_indices(idx, i0, i1, i2, i3); + + float acc = 0.0f; + + for (uint k = 0; k < p.ne01; k += 1) { + const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01; + const uint ib = a_block_base + (i0 / QUANT_K) * p.nb00; + const uint iqs = (i0 % QUANT_K) / QUANT_R; + + const vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + const float a_val = v.x * dm.x + dm.y; + + const uint b_idx = src1_idx(i1, k, i2, i3); + const float b = data_b[get_boffset() + b_idx]; + acc += a_val * b; + } + + uint d_idx = dst_idx(i0, i1, i2, i3); + data_d[get_doffset() + d_idx] = acc; + } + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 809c0bd9bd305..eb66152e18784 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -613,6 +613,8 @@ void process_shaders() { string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cross_entropy_loss_back_f32", "cross_entropy_loss_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -633,6 +635,11 @@ void process_shaders() { string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("out_prod_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_f16_f32", "out_prod.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_q4_0", "out_prod_q4_0.comp", merge_maps(base_dict, {{"DATA_A_Q4_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("out_prod_q8_0", "out_prod_q8_0.comp", merge_maps(base_dict, {{"DATA_A_Q8_0", "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); diff --git a/include/llama.h b/include/llama.h index 1c3a1cd1b4e7d..5accb65e5a0e3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1442,6 +1442,44 @@ extern "C" { int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + + // LoRA training parameters + enum llama_lora_target_module { + LLAMA_LORA_TARGET_ATTN_Q = 1 << 0, + LLAMA_LORA_TARGET_ATTN_K = 1 << 1, + LLAMA_LORA_TARGET_ATTN_V = 1 << 2, + LLAMA_LORA_TARGET_ATTN_O = 1 << 3, + LLAMA_LORA_TARGET_FFN_GATE = 1 << 4, + LLAMA_LORA_TARGET_FFN_UP = 1 << 5, + LLAMA_LORA_TARGET_FFN_DOWN = 1 << 6, + LLAMA_LORA_TARGET_OUTPUT = 1 << 7, + LLAMA_LORA_TARGET_ALL = 0x1FF, + }; + + struct llama_lora_training_params { + uint32_t target_modules; + int32_t rank; + float alpha; + float dropout; + float init_std; + }; + + // Initialize LoRA training with the given parameters + // Creates LoRA tensors and adds them to the model context + LLAMA_API struct llama_adapter_lora * llama_lora_training_init( + struct llama_context * ctx, + struct llama_model * model, + const struct llama_lora_training_params * params + ); + + // LoRA parameter filter (returns true for LoRA tensors only) + LLAMA_API bool llama_opt_param_filter_lora(const struct ggml_tensor * tensor, void * userdata); + + LLAMA_API bool llama_lora_save_adapter( + const struct llama_adapter_lora * adapter, + const char * filename, + const struct llama_model * model + ); #ifdef __cplusplus } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8f9cd652447ab..6aaac7875203d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(llama llama-io.cpp llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp + llama-lora-training.cpp llama-memory.cpp llama-memory-hybrid.cpp llama-memory-recurrent.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 1af19caa39dab..3eeb7dae5fe03 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2030,6 +2030,23 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); } } + + // Set LoRA params as trainable if any? + for (const auto & adapter_pair : loras) { + llama_adapter_lora * adapter = adapter_pair.first; + if (adapter) { + // Register lora tensors as params for training + for (const auto & tensor_pair : adapter->ab_map) { + const llama_adapter_lora_weight & weight = tensor_pair.second; + if (weight.a) { + llama_set_param(weight.a, param_filter, param_filter_ud); + } + if (weight.b) { + llama_set_param(weight.b, param_filter, param_filter_ud); + } + } + } + } } void llama_context::opt_epoch_iter( diff --git a/src/llama-lora-training.cpp b/src/llama-lora-training.cpp new file mode 100644 index 0000000000000..fc6750e27b4c2 --- /dev/null +++ b/src/llama-lora-training.cpp @@ -0,0 +1,359 @@ +#include "llama-lora-training.h" + +#include +#include +#include +#include +#include +#include + + +ggml_context * llama_lora_create_context(size_t mem_size) { + struct ggml_init_params init_params = { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + return ggml_init(init_params); +} + +bool llama_lora_validate_training_params(const struct llama_lora_training_params * params) { + if (!params) { + LLAMA_LOG_ERROR("LoRA training validation: params is null\n"); + return false; + } + + if (params->rank <= 0 || params->rank > 1024) { + LLAMA_LOG_ERROR("LoRA training validation: invalid rank %d (must be 1-1024)\n", params->rank); + return false; + } + + if (params->alpha <= 0.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid alpha %f (must be > 0)\n", params->alpha); + return false; + } + + if (params->dropout < 0.0f || params->dropout > 1.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid dropout %f (must be [0, 1])\n", params->dropout); + return false; + } + + if (params->init_std <= 0.0f || params->init_std > 1.0f) { + LLAMA_LOG_ERROR("LoRA training validation: invalid init_std %f (must be (0, 1])\n", params->init_std); + return false; + } + + if (params->target_modules == 0) { + LLAMA_LOG_ERROR("LoRA training validation: no target modules specified\n"); + return false; + } + + return true; +} + +bool llama_lora_create_tensor_pair( + struct ggml_context * lora_ctx, + const char * base_name, + const struct ggml_tensor * base_tensor, + int32_t rank, + struct ggml_tensor ** lora_a, + struct ggml_tensor ** lora_b) { + + if (!lora_ctx || !base_name || !base_tensor || !lora_a || !lora_b) { + return false; + } + + // Get base tensor dim + const int64_t d0 = base_tensor->ne[0]; // input dim + const int64_t d1 = base_tensor->ne[1]; // output dim + + char lora_a_name[256], lora_b_name[256]; + snprintf(lora_a_name, sizeof(lora_a_name), "%s.lora_a", base_name); + snprintf(lora_b_name, sizeof(lora_b_name), "%s.lora_b", base_name); + + // LoRA A: [d0, rank] - projects input to low rank + *lora_a = ggml_new_tensor_2d(lora_ctx, GGML_TYPE_F32, d0, rank); + ggml_set_name(*lora_a, lora_a_name); + + // LoRA B: [rank, d1] - projects from low rank to output + *lora_b = ggml_new_tensor_2d(lora_ctx, GGML_TYPE_F32, rank, d1); + ggml_set_name(*lora_b, lora_b_name); + + return true; +} + +static bool is_tensor_on_device(const struct ggml_tensor * tensor) { + return tensor->buffer && !ggml_backend_buffer_is_host(tensor->buffer); +} + +static void init_tensor_guassian(struct ggml_tensor * tensor, float std_dev) { + const size_t n_elements = ggml_nelements(tensor); + std::vector data(n_elements); + + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution dist(0.0f, std_dev); + + for (size_t i = 0; i < n_elements; i++) { + data[i] = dist(gen); + } + + if (is_tensor_on_device(tensor)) { + ggml_backend_tensor_set(tensor, data.data(), 0, n_elements * sizeof(float)); + } else { + std::copy(data.begin(), data.end(), (float *)tensor->data); + } +} + +static void init_tensor_zeros(struct ggml_tensor * tensor) { + const size_t n_elements = ggml_nelements(tensor); + + if (is_tensor_on_device(tensor)) { + std::vector zeros(n_elements, 0.0f); + ggml_backend_tensor_set(tensor, zeros.data(), 0, n_elements * sizeof(float)); + } else { + std::fill_n((float *)tensor->data, n_elements, 0.0f); + } +} + +void llama_lora_init_tensor_weights(struct ggml_tensor * lora_a, struct ggml_tensor * lora_b, float init_std) { + if (!lora_a || !lora_b) return; + + // LoRA initialization: A ~ N(0, init_std), B = 0 + init_tensor_guassian(lora_a, init_std); + init_tensor_zeros(lora_b); +} + +bool llama_lora_allocate_buffers( + struct llama_adapter_lora * adapter, + struct llama_model * model) { + + if (!adapter || !model) { + return false; + } + + std::map ctx_map; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); // fallback to CPU + + // Find any layer tensor to determine the correct backend + for (const auto & tensor_pair : model->tensors_by_name) { + const std::string & name = tensor_pair.first; + struct ggml_tensor * tensor = tensor_pair.second; + + if (name.find("blk.") != std::string::npos && tensor && tensor->buffer) { + buft = ggml_backend_buffer_get_type(tensor->buffer); + break; + } + } + + if (adapter->ctxs.empty()) { + LLAMA_LOG_ERROR("No contexts found in adapter\n"); + return false; + } + ggml_context * lora_ctx = adapter->ctxs[0].get(); + + ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, buft) }; + if (!buf) { + LLAMA_LOG_ERROR("Failed to allocate buffer for LoRA adapter\n"); + return false; + } + LLAMA_LOG_INFO("LoRA buffer size = %.2f MiB\n", ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0); + adapter->bufs.emplace_back(std::move(buf)); + + return true; +} + +struct llama_adapter_lora * llama_lora_create_adapter( + struct llama_model * model, + const struct llama_lora_training_params * params) { + + // Create a new LoRA adapter instance + llama_adapter_lora * adapter = new llama_adapter_lora(); + try { + adapter->alpha = params->alpha; + + // Create LoRA tensors and populate ab_map + // Create GGML context for LoRA tensors + const size_t estimated_lora_mem = 256 * 1024 * 1024; // 256MB should be enough for most LoRA configs + ggml_context * lora_ctx = llama_lora_create_context(estimated_lora_mem); + if (!lora_ctx) { + throw std::runtime_error("Failed to create LoRA context"); + } + + adapter->ctxs.emplace_back(lora_ctx); + int created_count = 0; + + for (const auto & tensor_pair : model->tensors_by_name) { + const std::string & tensor_name = tensor_pair.first; + struct ggml_tensor * base_tensor = tensor_pair.second; + + if (!base_tensor) { + continue; + } + + bool should_create_lora = false; + if (tensor_name.find("blk.") != std::string::npos) { + if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_Q) && tensor_name.find("attn_q") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_K) && tensor_name.find("attn_k") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_V) && tensor_name.find("attn_v") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_ATTN_O) && tensor_name.find("attn_output") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_GATE) && tensor_name.find("ffn_gate") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_UP) && tensor_name.find("ffn_up") != std::string::npos) { + should_create_lora = true; + } else if ((params->target_modules & LLAMA_LORA_TARGET_FFN_DOWN) && tensor_name.find("ffn_down") != std::string::npos) { + should_create_lora = true; + } + } else if ((params->target_modules & LLAMA_LORA_TARGET_OUTPUT) && tensor_name.find("output") != std::string::npos) { + should_create_lora = true; + } + + if (should_create_lora && base_tensor->ne[1] > 0) { + struct ggml_tensor * lora_a = nullptr; + struct ggml_tensor * lora_b = nullptr; + + if (llama_lora_create_tensor_pair(lora_ctx, tensor_name.c_str(), base_tensor, params->rank, &lora_a, &lora_b)) { + if (!lora_a || !lora_b) { + throw std::runtime_error("Created null LoRA tensors for " + tensor_name); + } + created_count++; + adapter->ab_map[tensor_name] = llama_adapter_lora_weight(lora_a, lora_b); + } else { + throw std::runtime_error("Failed to create LoRA tensor pair for " + tensor_name); + } + } + } + + if (created_count == 0) { + throw std::runtime_error("No suitable tensors found for LoRA adaptation"); + } + + if (!llama_lora_allocate_buffers(adapter, model)) { + throw std::runtime_error("Failed to allocate LoRA buffers"); + } + + for (const auto & ab_pair : adapter->ab_map) { + const std::string & tensor_name = ab_pair.first; + const llama_adapter_lora_weight & weight = ab_pair.second; + + if (weight.a && weight.b && weight.a->data && weight.b->data) { + llama_lora_init_tensor_weights(weight.a, weight.b, params->init_std); + } else { + throw std::runtime_error("LoRA tensor initialization failed for " + tensor_name); + } + } + return adapter; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("Failed to create LoRA adapter: %s\n", err.what()); + delete adapter; + return nullptr; + } +} + +struct llama_adapter_lora * llama_lora_training_init( + struct llama_context * ctx, + struct llama_model * model, + const struct llama_lora_training_params * params) { + + if (!ctx || !model || !params) { + LLAMA_LOG_ERROR("LoRA training init: invalid parameters\n"); + return nullptr; + } + + if (!llama_lora_validate_training_params(params)) { + return nullptr; + } + + struct llama_adapter_lora * adapter = llama_lora_create_adapter(model, params); + if (!adapter) { + return nullptr; + } + + llama_clear_adapter_lora(ctx); + + if (llama_set_adapter_lora(ctx, adapter, 1.0f) < 0) { + LLAMA_LOG_ERROR("Failed to apply LoRA adapter to context\n"); + delete adapter; + return nullptr; + } + + LLAMA_LOG_INFO("LoRA adapter contains %zu tensor pairs and is now registered with context\n", adapter->ab_map.size()); + + return adapter; +} + +bool llama_opt_param_filter_lora(const struct ggml_tensor * tensor, void * userdata) { + (void) userdata; // Unused param + + if (!tensor || !tensor->name) { + return false; + } + + const char * name = tensor->name; + + // Check if tensor is LoRA A or B + // LoRA tensor naming convention: blk.{layer}.{module}.lora_a or .lora_b + if (strstr(name, ".lora_a") || strstr(name, ".lora_b")) { + LLAMA_LOG_DEBUG("LoRA filter: including trainable params '%s'\n", name); + return true; + } + + return false; +} + +bool llama_lora_save_adapter( + const struct llama_adapter_lora * adapter, + const char * filename, + const struct llama_model * model) { + + if (!adapter || !filename || !model) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: invalid parameters\n"); + return false; + } + + struct gguf_context * gguf_ctx = gguf_init_empty(); + if (!gguf_ctx) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: failed to create GGUF context\n"); + return false; + } + + std::string arch_name = model->arch_name(); + if (arch_name.empty()) { + LLAMA_LOG_ERROR("llama_lora_save_adapter: failed to get model architecture\n"); + gguf_free(gguf_ctx); + return false; + } + + gguf_set_val_str(gguf_ctx, "general.architecture", arch_name.c_str()); + gguf_set_val_str(gguf_ctx, "general.type", "adapter"); + gguf_set_val_str(gguf_ctx, "general.name", "LoRA Adapter"); + gguf_set_val_str(gguf_ctx, "adapter.type", "lora"); + gguf_set_val_f32(gguf_ctx, "adapter.lora.alpha", adapter->alpha); + + int tensor_count = 0; + for (const auto & kv : adapter->ab_map) { + const auto & lora_weight = kv.second; + + if (lora_weight.a && lora_weight.b) { + gguf_add_tensor(gguf_ctx, lora_weight.a); + gguf_add_tensor(gguf_ctx, lora_weight.b); + tensor_count += 2; + } + } + + bool success = gguf_write_to_file(gguf_ctx, filename, false); + if (success) { + LLAMA_LOG_INFO("Successfully saved LoRA adapter with %d tensors to: %s\n", + tensor_count, filename); + } else { + LLAMA_LOG_ERROR("Failed to write LoRA adapter to: %s\n", filename); + } + + gguf_free(gguf_ctx); + return success; +} diff --git a/src/llama-lora-training.h b/src/llama-lora-training.h new file mode 100644 index 0000000000000..ed777be7b36f7 --- /dev/null +++ b/src/llama-lora-training.h @@ -0,0 +1,34 @@ +#pragma once + +#include "llama.h" +#include "llama-model.h" +#include "llama-adapter.h" +#include "llama-impl.h" +#include "ggml.h" + + +bool llama_lora_validate_training_params(const struct llama_lora_training_params * params); + +ggml_context * llama_lora_create_context(size_t mem_size); + +bool llama_lora_create_tensor_pair( + struct ggml_context * lora_ctx, + const char * base_name, + const struct ggml_tensor * base_tensor, + int32_t rank, + struct ggml_tensor ** lora_a, + struct ggml_tensor ** lora_b); + +void llama_lora_init_tensor_weights( + struct ggml_tensor * lora_a, + struct ggml_tensor * lora_b, + float init_std); + +struct llama_adapter_lora * llama_lora_create_adapter( + struct llama_model * model, + const struct llama_lora_training_params * params); + +bool llama_lora_allocate_buffers( + struct llama_adapter_lora * adapter, + struct llama_model * model); +