@@ -2138,6 +2138,10 @@ void llama_context::opt_epoch_iter(
     memory->clear(true);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        // Check for early exit request before processing context batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            return;
+        }
         batch.n_tokens = n_batch;
         for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
             batch.token[pos_batch] = tokens[pos_ctx + pos_batch];
@@ -2174,6 +2178,11 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
+            // Check for early exit request before processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
+
             const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
@@ -2263,6 +2272,11 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
+
+            // Check for early exit request after processing ubatch
+            if (training_should_stop.load(std::memory_order_acquire)) {
+                break;
+            }
         } while (mctx->next());
     }
 }
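Note: the checks above rely on a `training_should_stop` member of `llama_context` that is not visible in this hunk. A minimal sketch of what that member would presumably look like in the context header (an assumption, since the header change is not part of this section):

    // assumed member declaration (requires <atomic>): another thread publishes a
    // stop request with a release store, and the training loops above observe it
    // with acquire loads
    std::atomic<bool> training_should_stop{false};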
@@ -2275,6 +2289,9 @@ void llama_context::opt_epoch(
         ggml_opt_epoch_callback callback_train,
         ggml_opt_epoch_callback callback_eval,
         int64_t resume_from_batch) {
+    // Reset stop flag at the start of each epoch to ensure clean state
+    training_should_stop.store(false, std::memory_order_release);
+
     const uint32_t n_ctx    = this->n_ctx();
     const uint32_t n_batch  = std::min(cparams.n_batch, n_ctx);
     const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
@@ -2295,6 +2312,11 @@ void llama_context::opt_epoch(
     int64_t t_loop_start  = ggml_time_us();
     int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
     for (; idata < idata_split; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = true;
         const int64_t idata_in_loop = idata*ubatch_per_ctx;
 
@@ -2306,11 +2328,21 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_train, tokens, labels_sparse, masks_sparse, batch,
             callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     t_loop_start  = ggml_time_us();
     ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
     for (; idata < ndata; ++idata) {
+        // Check for early exit request before processing batch
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
+
         constexpr bool train = false;
         const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
 
@@ -2321,12 +2353,20 @@ void llama_context::opt_epoch(
         }
         opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, masks_sparse, batch,
             callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+
+        // Check again after iteration in case it was set during processing
+        if (training_should_stop.load(std::memory_order_acquire)) {
+            break;
+        }
     }
 
     llama_batch_free(batch);
 }
 
 int64_t llama_context::opt_get_iter() {
+    if (!opt_ctx) {
+        return 0; // Return 0 if optimizer not initialized
+    }
     return ggml_opt_get_iter(opt_ctx);
 }
 
@@ -2344,6 +2384,24 @@ bool llama_context::opt_load_state(const char* filename) {
     return ggml_opt_load_state(opt_ctx, filename);
 }
 
+void llama_context::opt_cleanup() {
+    if (opt_ctx) {
+        ggml_opt_free(opt_ctx);
+        opt_ctx = nullptr;
+        should_load_optimizer_tensors = false;
+        optimizer_tensors_loaded = false;
+        pending_optimizer_checkpoint_path.clear();
+    }
+}
+
+void llama_context::opt_request_stop() {
+    training_should_stop.store(true, std::memory_order_release);
+}
+
+void llama_context::opt_reset_stop() {
+    training_should_stop.store(false, std::memory_order_release);
+}
+
 //
 // interface implementation
 //
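For these member functions to be reachable from the C API, the wrapper functions added in the next hunk would also need declarations in llama.h. A hedged sketch of what those declarations might look like, reusing the LLAMA_API export macro the header already uses (an assumption, since the header diff is not shown in this section):

    // hypothetical llama.h declarations matching the wrappers added below:
    LLAMA_API bool llama_opt_save_state  (struct llama_context * ctx, const char * filename);
    LLAMA_API bool llama_opt_load_state  (struct llama_context * ctx, const char * filename);
    LLAMA_API void llama_opt_cleanup     (struct llama_context * ctx);
    LLAMA_API void llama_opt_request_stop(struct llama_context * ctx);
    LLAMA_API void llama_opt_reset_stop  (struct llama_context * ctx);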
@@ -2903,3 +2961,23 @@ void llama_opt_epoch(
 int64_t llama_opt_get_iter(struct llama_context * ctx) {
     return ctx->opt_get_iter();
 }
+
+bool llama_opt_save_state(struct llama_context * ctx, const char * filename) {
+    return ctx->opt_save_state(filename);
+}
+
+bool llama_opt_load_state(struct llama_context * ctx, const char * filename) {
+    return ctx->opt_load_state(filename);
+}
+
+void llama_opt_cleanup(struct llama_context * ctx) {
+    ctx->opt_cleanup();
+}
+
+void llama_opt_request_stop(struct llama_context * ctx) {
+    ctx->opt_request_stop();
+}
+
+void llama_opt_reset_stop(struct llama_context * ctx) {
+    ctx->opt_reset_stop();
+}
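As an illustration of how a caller might drive the new stop/cleanup API (not part of the diff): the sketch below assumes llama.h is included and that the optimizer has already been set up for ctx via llama_opt_init; the SIGINT wiring and the checkpoint filename are hypothetical choices, not something this change prescribes.

    #include <csignal>

    static llama_context * g_ctx = nullptr;

    // hypothetical SIGINT handler: ask the training loops to exit at their next stop check
    static void on_sigint(int) {
        if (g_ctx) {
            llama_opt_request_stop(g_ctx);
        }
    }

    // after optimizer setup for ctx (elided):
    //     g_ctx = ctx;
    //     std::signal(SIGINT, on_sigint);
    //     llama_opt_epoch(...);                          // returns early once the flag is observed
    //     llama_opt_save_state(ctx, "checkpoint.gguf");  // hypothetical filename
    //     llama_opt_reset_stop(ctx);                     // optional: the flag is also reset at the start of each epoch
    //     llama_opt_cleanup(ctx);                        // free the optimizer state when done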