
Commit 107d16e

Updating code to enable mid-epoch cancellation
1 parent e432569 commit 107d16e

3 files changed: +71, -0 lines changed

include/llama.h

Lines changed: 8 additions & 0 deletions

@@ -1412,6 +1412,14 @@ extern "C" {
     // Call this before calling llama_opt_init() again on the same context
     LLAMA_API void llama_opt_cleanup(struct llama_context * lctx);
 
+    // Request early exit from training epoch (thread-safe)
+    // Call this from a callback or another thread to stop training after the current batch
+    LLAMA_API void llama_opt_request_stop(struct llama_context * lctx);
+
+    // Reset the stop flag to allow training to continue
+    // Call this before resuming training after a pause
+    LLAMA_API void llama_opt_reset_stop(struct llama_context * lctx);
+
     // LoRA training parameters
     enum llama_lora_target_module {
         LLAMA_LORA_TARGET_ATTN_Q = 1 << 0,
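
As a usage sketch (not part of this commit): one way to drive the new public API is to run the training epoch on one thread and request the stop from a watcher thread. Only llama_opt_request_stop() and llama_opt_reset_stop() come from this change; the surrounding scaffolding below is illustrative.

// Illustrative only: stop a running epoch from a second thread.
#include <chrono>
#include <thread>

#include "llama.h"

void train_with_watchdog(struct llama_context * ctx) {
    // Watcher thread: request a stop after some condition (here, a fixed delay).
    std::thread watcher([ctx] {
        std::this_thread::sleep_for(std::chrono::minutes(5));
        llama_opt_request_stop(ctx);   // thread-safe; the epoch exits after the current batch
    });

    // ... llama_opt_init(...) and llama_opt_epoch(...) run here on this thread;
    // once the flag is set, the epoch returns early instead of finishing.

    watcher.join();

    // Before resuming training later in the same context, clear the flag explicitly
    // (llama_opt_epoch() also resets it at the start of each epoch).
    llama_opt_reset_stop(ctx);
}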

src/llama-context.cpp

Lines changed: 53 additions & 0 deletions

@@ -2138,6 +2138,10 @@ void llama_context::opt_epoch_iter(
     memory->clear(true);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        // Check for early exit request before processing context batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            return;
+        }
         batch.n_tokens = n_batch;
         for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
             batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
@@ -2174,6 +2178,11 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
+            // Check for early exit request before processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
+
             const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
@@ -2263,6 +2272,11 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
+
+            // Check for early exit request after processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
         } while (mctx->next());
     }
 }
@@ -2275,6 +2289,9 @@ void llama_context::opt_epoch(
         ggml_opt_epoch_callback callback_train,
         ggml_opt_epoch_callback callback_eval,
         int64_t resume_from_batch) {
+    // Reset stop flag at the start of each epoch to ensure clean state
+    training_should_stop.store(false, std::memory_order_release);
+
     const uint32_t n_ctx = this->n_ctx();
     const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
     const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
@@ -2295,6 +2312,11 @@ void llama_context::opt_epoch(
     int64_t t_loop_start = ggml_time_us();
     int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
     for (; idata < idata_split; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = true;
         const int64_t idata_in_loop = idata*ubatch_per_ctx;
 
@@ -2306,11 +2328,21 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_train, tokens, labels_sparse, masks_sparse, batch,
                        callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     t_loop_start = ggml_time_us();
     ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
     for (; idata < ndata; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = false;
         const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
 
@@ -2321,6 +2353,11 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, masks_sparse, batch,
                        callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     llama_batch_free(batch);
@@ -2357,6 +2394,14 @@ void llama_context::opt_cleanup() {
     }
 }
 
+void llama_context::opt_request_stop() {
+    training_should_stop.store(true, std::memory_order_release);
+}
+
+void llama_context::opt_reset_stop() {
+    training_should_stop.store(false, std::memory_order_release);
+}
+
 //
 // interface implementation
 //
@@ -2928,3 +2973,11 @@ bool llama_opt_load_state(struct llama_context * ctx, const char* filename) {
 void llama_opt_cleanup(struct llama_context * ctx) {
     ctx->opt_cleanup();
 }
+
+void llama_opt_request_stop(struct llama_context * ctx) {
+    ctx->opt_request_stop();
+}
+
+void llama_opt_reset_stop(struct llama_context * ctx) {
+    ctx->opt_reset_stop();
+}
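
The implementation is a standard cooperative-cancellation pattern: llama_opt_request_stop() publishes the request with a release store, and the epoch loops poll the flag with acquire loads at batch and ubatch boundaries, so a stop only takes effect once the work in flight has finished. A standalone sketch of the same pattern (illustrative, not llama.cpp code) is below.

// One thread sets an atomic flag with memory_order_release; the worker polls it with
// memory_order_acquire at the top of each batch, so cancellation is observed only at
// batch boundaries.
#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

static std::atomic<bool> should_stop{false};

static void run_epoch(int n_batches) {
    for (int i = 0; i < n_batches; ++i) {
        if (should_stop.load(std::memory_order_acquire)) {
            std::printf("stopping early before batch %d\n", i);
            return;                                                     // mirrors the early return/break in opt_epoch_iter()
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(50));     // stand-in for real batch work
    }
}

int main() {
    should_stop.store(false, std::memory_order_release);                // like the reset at epoch start
    std::thread trainer(run_epoch, 1000);

    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    should_stop.store(true, std::memory_order_release);                 // like llama_opt_request_stop()

    trainer.join();
    return 0;
}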

src/llama-context.h

Lines changed: 10 additions & 0 deletions

@@ -8,6 +8,7 @@
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
 
+#include <atomic>
 #include <map>
 #include <vector>
 
@@ -169,6 +170,12 @@ struct llama_context {
 
     // Clean up optimizer context to free memory and allow reinitialization
     void opt_cleanup();
+
+    // Request early exit from training epoch (thread-safe)
+    void opt_request_stop();
+
+    // Reset the stop flag to allow training to continue
+    void opt_reset_stop();
 
     void opt_epoch_iter(
             ggml_opt_dataset_t dataset,
@@ -280,6 +287,9 @@ struct llama_context {
     bool should_load_optimizer_tensors = false;
     bool optimizer_tensors_loaded = false;
     ggml_opt_loss_type opt_loss_type = GGML_OPT_LOSS_TYPE_CROSS_ENTROPY;
+
+    // early exit flag for training epochs (thread-safe)
+    std::atomic<bool> training_should_stop{false};
 
     ggml_threadpool_t threadpool = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;
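
Because training_should_stop is a member of llama_context rather than a global, each context can be cancelled independently. A hypothetical fragment (context creation elided; only the two new API calls come from this commit):

// Hypothetical: two independently initialized training contexts; stopping one leaves
// the other running, since each llama_context carries its own stop flag.
#include "llama.h"

void stop_only_first(struct llama_context * ctx_a, struct llama_context * ctx_b) {
    llama_opt_request_stop(ctx_a);   // only ctx_a's training_should_stop flag is set
    (void) ctx_b;                    // ctx_b keeps training; its flag is untouched
}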
