Commit 9809238

Merge pull request #61 from dev-nid/temp-latest-finetuning
QVAC-7948: Added opt cleanup function and mid-epoch cancellation.
2 parents 0cf1ef7 + 8a2d710 commit 9809238

File tree

3 files changed (+112, -6 lines)


include/llama.h

Lines changed: 17 additions & 1 deletion

@@ -1403,7 +1403,23 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval,
             int64_t resume_from_batch);
-
+
+    // Optimizer state persistence
+    LLAMA_API bool llama_opt_save_state(struct llama_context * lctx, const char * filename);
+    LLAMA_API bool llama_opt_load_state(struct llama_context * lctx, const char * filename);
+
+    // Clean up optimizer context to free memory and allow reinitialization
+    // Call this before calling llama_opt_init() again on the same context
+    LLAMA_API void llama_opt_cleanup(struct llama_context * lctx);
+
+    // Request early exit from training epoch (thread-safe)
+    // Call this from a callback or another thread to stop training after the current batch
+    LLAMA_API void llama_opt_request_stop(struct llama_context * lctx);
+
+    // Reset the stop flag to allow training to continue
+    // Call this before resuming training after a pause
+    LLAMA_API void llama_opt_reset_stop(struct llama_context * lctx);
+
     // LoRA training parameters
     enum llama_lora_target_module {
         LLAMA_LORA_TARGET_ATTN_Q = 1 << 0,
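
Usage note: the sketch below is illustrative only and is not part of the commit. The helper names (pause_training, checkpoint_after_stop, teardown_training) are hypothetical, and it assumes a llama_context on which llama_opt_init() has already been called and on which llama_opt_epoch() is being driven elsewhere, e.g. on a worker thread.

#include "llama.h"

// Hypothetical helper: ask the running epoch to stop after the current batch.
// Safe to call from a training callback or from another thread.
static void pause_training(struct llama_context * lctx) {
    llama_opt_request_stop(lctx);
}

// Hypothetical helper: once llama_opt_epoch() has returned, persist the
// optimizer state and clear the stop flag so training can be resumed later.
static bool checkpoint_after_stop(struct llama_context * lctx, const char * path) {
    if (!llama_opt_save_state(lctx, path)) {
        return false;
    }
    llama_opt_reset_stop(lctx);
    return true;
}

// Hypothetical helper: free the optimizer context entirely. llama_opt_init()
// may be called again on the same context afterwards, optionally followed by
// llama_opt_load_state() to restore a saved checkpoint.
static void teardown_training(struct llama_context * lctx) {
    llama_opt_cleanup(lctx);
}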

src/llama-context.cpp

Lines changed: 78 additions & 0 deletions

@@ -2138,6 +2138,10 @@ void llama_context::opt_epoch_iter(
     memory->clear(true);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        // Check for early exit request before processing context batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            return;
+        }
         batch.n_tokens = n_batch;
         for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
             batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
@@ -2174,6 +2178,11 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
+            // Check for early exit request before processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
+
             const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
@@ -2263,6 +2272,11 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
+
+            // Check for early exit request after processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
         } while (mctx->next());
     }
 }
@@ -2275,6 +2289,9 @@ void llama_context::opt_epoch(
         ggml_opt_epoch_callback callback_train,
         ggml_opt_epoch_callback callback_eval,
         int64_t resume_from_batch) {
+    // Reset stop flag at the start of each epoch to ensure clean state
+    training_should_stop.store(false, std::memory_order_release);
+
     const uint32_t n_ctx = this->n_ctx();
     const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
     const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
@@ -2295,6 +2312,11 @@ void llama_context::opt_epoch(
     int64_t t_loop_start = ggml_time_us();
     int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
     for (; idata < idata_split; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = true;
         const int64_t idata_in_loop = idata*ubatch_per_ctx;
 
@@ -2306,11 +2328,21 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_train, tokens, labels_sparse, masks_sparse, batch,
             callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     t_loop_start = ggml_time_us();
     ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
     for (; idata < ndata; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = false;
         const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
 
@@ -2321,12 +2353,20 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, masks_sparse, batch,
             callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     llama_batch_free(batch);
 }
 
 int64_t llama_context::opt_get_iter() {
+    if (!opt_ctx) {
+        return 0; // Return 0 if optimizer not initialized
+    }
     return ggml_opt_get_iter(opt_ctx);
 }
 
@@ -2344,6 +2384,24 @@ bool llama_context::opt_load_state(const char* filename) {
     return ggml_opt_load_state(opt_ctx, filename);
 }
 
+void llama_context::opt_cleanup() {
+    if (opt_ctx) {
+        ggml_opt_free(opt_ctx);
+        opt_ctx = nullptr;
+        should_load_optimizer_tensors = false;
+        optimizer_tensors_loaded = false;
+        pending_optimizer_checkpoint_path.clear();
+    }
+}
+
+void llama_context::opt_request_stop() {
+    training_should_stop.store(true, std::memory_order_release);
+}
+
+void llama_context::opt_reset_stop() {
+    training_should_stop.store(false, std::memory_order_release);
+}
+
 //
 // interface implementation
 //
@@ -2903,3 +2961,23 @@ void llama_opt_epoch(
 int64_t llama_opt_get_iter(struct llama_context * ctx) {
     return ctx->opt_get_iter();
 }
+
+bool llama_opt_save_state(struct llama_context * ctx, const char * filename) {
+    return ctx->opt_save_state(filename);
+}
+
+bool llama_opt_load_state(struct llama_context * ctx, const char * filename) {
+    return ctx->opt_load_state(filename);
+}
+
+void llama_opt_cleanup(struct llama_context * ctx) {
+    ctx->opt_cleanup();
+}
+
+void llama_opt_request_stop(struct llama_context * ctx) {
+    ctx->opt_request_stop();
+}
+
+void llama_opt_reset_stop(struct llama_context * ctx) {
+    ctx->opt_reset_stop();
+}
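
Note that cancellation is cooperative: the flag is only polled at context-batch, ubatch, and epoch-loop boundaries, so llama_opt_request_stop() returns immediately and the epoch winds down after the batch currently in flight. Below is a minimal sketch of driving this from a second thread, assuming the epoch itself runs on the calling thread; the time budget and the helper name run_with_deadline are illustrative assumptions, not part of the patch.

#include <chrono>
#include <thread>

#include "llama.h"

// Hypothetical watchdog: request a stop if the epoch runs past a time budget.
static void run_with_deadline(struct llama_context * lctx, std::chrono::seconds budget) {
    std::thread watchdog([lctx, budget]() {
        std::this_thread::sleep_for(budget);
        llama_opt_request_stop(lctx);  // takes effect at the next batch boundary
    });

    // ... run llama_opt_epoch(lctx, ...) here on the calling thread ...

    watchdog.join();
    llama_opt_reset_stop(lctx);        // clear the flag before the next epoch
}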

src/llama-context.h

Lines changed: 17 additions & 5 deletions

@@ -1,13 +1,13 @@
 #pragma once
 
-#include "llama.h"
-#include "llama-cparams.h"
-#include "llama-graph.h"
-#include "llama-adapter.h"
-
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
+#include "llama-adapter.h"
+#include "llama-cparams.h"
+#include "llama-graph.h"
+#include "llama.h"
 
+#include <atomic>
 #include <map>
 #include <vector>
 
@@ -167,6 +167,15 @@ struct llama_context {
     bool opt_save_state(const char* filename);
     bool opt_load_state(const char* filename);
 
+    // Clean up optimizer context to free memory and allow reinitialization
+    void opt_cleanup();
+
+    // Request early exit from training epoch (thread-safe)
+    void opt_request_stop();
+
+    // Reset the stop flag to allow training to continue
+    void opt_reset_stop();
+
     void opt_epoch_iter(
             ggml_opt_dataset_t dataset,
             ggml_opt_result_t result,
@@ -278,6 +287,9 @@ struct llama_context {
     bool optimizer_tensors_loaded = false;
     ggml_opt_loss_type opt_loss_type = GGML_OPT_LOSS_TYPE_CROSS_ENTROPY;
 
+    // early exit flag for training epochs (thread-safe)
+    std::atomic<bool> training_should_stop{ false };
+
     ggml_threadpool_t threadpool = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;
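
The flag itself is a single std::atomic<bool>: opt_request_stop() publishes the stop with a release store and the training loops observe it with acquire loads, so a one-way signal like this needs no mutex. The following is a self-contained sketch of the same pattern, independent of llama.cpp, included only to illustrate the design choice.

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    std::atomic<bool> should_stop{false};

    // Worker polls the flag at "batch" boundaries, mirroring opt_epoch_iter().
    std::thread worker([&]() {
        int batches = 0;
        while (!should_stop.load(std::memory_order_acquire)) {
            std::this_thread::sleep_for(std::chrono::milliseconds(10)); // one "batch"
            ++batches;
        }
        std::printf("stopped after %d batches\n", batches);
    });

    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    should_stop.store(true, std::memory_order_release);  // request stop
    worker.join();
    return 0;
}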
