16 changes: 16 additions & 0 deletions include/llama.h
@@ -1404,6 +1404,22 @@ extern "C" {
ggml_opt_epoch_callback callback_eval,
int64_t resume_from_batch);

// Optimizer state persistence
LLAMA_API bool llama_opt_save_state(struct llama_context * lctx, const char* filename);
LLAMA_API bool llama_opt_load_state(struct llama_context * lctx, const char* filename);

// Clean up optimizer context to free memory and allow reinitialization
// Call this before calling llama_opt_init() again on the same context
LLAMA_API void llama_opt_cleanup(struct llama_context * lctx);

// Request early exit from training epoch (thread-safe)
// Call this from a callback or another thread to stop training after the current batch
LLAMA_API void llama_opt_request_stop(struct llama_context * lctx);

// Reset the stop flag to allow training to continue
// Call this before resuming training after a pause
LLAMA_API void llama_opt_reset_stop(struct llama_context * lctx);

// LoRA training parameters
enum llama_lora_target_module {
LLAMA_LORA_TARGET_ATTN_Q = 1 << 0,
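Usage sketch (not part of the diff): one way the new optimizer-state and stop-control calls might be combined, assuming `ctx` is a `llama_context *` already set up for training via `llama_opt_init()`; the checkpoint filename is illustrative.

    // Try to resume optimizer state from a previous run (path is hypothetical).
    if (!llama_opt_load_state(ctx, "opt-state.bin")) {
        // No checkpoint found or load failed: training starts with fresh optimizer state.
    }

    // ... run training epochs; from a callback or another thread, a graceful
    // stop after the current batch can be requested at any time:
    llama_opt_request_stop(ctx);

    // Persist optimizer state so a later run can resume where this one stopped.
    llama_opt_save_state(ctx, "opt-state.bin");

    // Before resuming training in the same context, clear the stop flag.
    llama_opt_reset_stop(ctx);

    // To reinitialize training from scratch, free the optimizer context first,
    // then call llama_opt_init() again.
    llama_opt_cleanup(ctx);
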
78 changes: 78 additions & 0 deletions src/llama-context.cpp
@@ -2138,6 +2138,10 @@ void llama_context::opt_epoch_iter(
memory->clear(true);

for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
// Check for early exit request before processing context batch
if (training_should_stop.load(std::memory_order_acquire)) {
return;
}
batch.n_tokens = n_batch;
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
@@ -2174,6 +2178,11 @@

uint32_t pos_batch = 0;
do {
// Check for early exit request before processing ubatch
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}

const auto & ubatch = mctx->get_ubatch();

n_outputs = ubatch.n_tokens;
@@ -2263,6 +2272,11 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt);

pos_batch += ubatch.n_tokens;

// Check for early exit request after processing ubatch
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}
} while (mctx->next());
}
}
@@ -2275,6 +2289,9 @@ void llama_context::opt_epoch(
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval,
int64_t resume_from_batch) {
// Reset stop flag at the start of each epoch to ensure clean state
training_should_stop.store(false, std::memory_order_release);

const uint32_t n_ctx = this->n_ctx();
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
@@ -2295,6 +2312,11 @@
int64_t t_loop_start = ggml_time_us();
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
for (; idata < idata_split; ++idata) {
// Check for early exit request before processing batch
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}

constexpr bool train = true;
const int64_t idata_in_loop = idata*ubatch_per_ctx;

@@ -2306,11 +2328,21 @@
}
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, masks_sparse, batch,
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);

// Check again after iteration in case it was set during processing
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}
}

t_loop_start = ggml_time_us();
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
for (; idata < ndata; ++idata) {
// Check for early exit request before processing batch
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}

constexpr bool train = false;
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;

@@ -2321,12 +2353,20 @@
}
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, masks_sparse, batch,
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);

// Check again after iteration in case it was set during processing
if (training_should_stop.load(std::memory_order_acquire)) {
break;
}
}

llama_batch_free(batch);
}

int64_t llama_context::opt_get_iter() {
if (!opt_ctx) {
return 0; // Return 0 if optimizer not initialized
}
return ggml_opt_get_iter(opt_ctx);
}

@@ -2344,6 +2384,24 @@ bool llama_context::opt_load_state(const char* filename) {
return ggml_opt_load_state(opt_ctx, filename);
}

void llama_context::opt_cleanup() {
if (opt_ctx) {
ggml_opt_free(opt_ctx);
opt_ctx = nullptr;
should_load_optimizer_tensors = false;
optimizer_tensors_loaded = false;
pending_optimizer_checkpoint_path.clear();
}
}

void llama_context::opt_request_stop() {
training_should_stop.store(true, std::memory_order_release);
}

void llama_context::opt_reset_stop() {
training_should_stop.store(false, std::memory_order_release);
}

//
// interface implementation
//
@@ -2903,3 +2961,23 @@ void llama_opt_epoch(
int64_t llama_opt_get_iter(struct llama_context * ctx) {
return ctx->opt_get_iter();
}

bool llama_opt_save_state(struct llama_context * ctx, const char* filename) {
return ctx->opt_save_state(filename);
}

bool llama_opt_load_state(struct llama_context * ctx, const char* filename) {
return ctx->opt_load_state(filename);
}

void llama_opt_cleanup(struct llama_context * ctx) {
ctx->opt_cleanup();
}

void llama_opt_request_stop(struct llama_context * ctx) {
ctx->opt_request_stop();
}

void llama_opt_reset_stop(struct llama_context * ctx) {
ctx->opt_reset_stop();
}
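For reference, a self-contained sketch of the acquire/release stop-flag pattern the patch relies on (names here are illustrative, not part of llama.cpp): a release store in the requesting thread pairs with acquire loads in the training loop, so the loop observes the request at the next batch boundary.

    #include <atomic>
    #include <chrono>
    #include <thread>

    static std::atomic<bool> should_stop{false};   // plays the role of training_should_stop

    static void training_loop() {
        for (int batch = 0; batch < 1000; ++batch) {
            if (should_stop.load(std::memory_order_acquire)) {
                break;                             // exit after the current batch
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(10)); // stand-in for one batch
        }
    }

    int main() {
        std::thread worker(training_loop);
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        should_stop.store(true, std::memory_order_release); // what llama_opt_request_stop() does
        worker.join();
        return 0;
    }
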
13 changes: 13 additions & 0 deletions src/llama-context.h
@@ -8,6 +8,7 @@
#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <atomic>
#include <map>
#include <vector>

@@ -166,6 +167,15 @@ struct llama_context {
// Optimizer state persistence
bool opt_save_state(const char* filename);
bool opt_load_state(const char* filename);

// Clean up optimizer context to free memory and allow reinitialization
void opt_cleanup();

// Request early exit from training epoch (thread-safe)
void opt_request_stop();

// Reset the stop flag to allow training to continue
void opt_reset_stop();

void opt_epoch_iter(
ggml_opt_dataset_t dataset,
@@ -277,6 +287,9 @@ struct llama_context {
bool should_load_optimizer_tensors = false;
bool optimizer_tensors_loaded = false;
ggml_opt_loss_type opt_loss_type = GGML_OPT_LOSS_TYPE_CROSS_ENTROPY;

// early exit flag for training epochs (thread-safe)
std::atomic<bool> training_should_stop{false};

ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;