@@ -2138,6 +2138,10 @@ void llama_context::opt_epoch_iter(
21382138 memory->clear (true );
21392139
21402140 for (uint32_t pos_ctx = 0 ; pos_ctx < n_ctx; pos_ctx += n_batch) {
2141+ // Check for early exit request before processing context batch
2142+ if (training_should_stop.load (std::memory_order_acquire)) {
2143+ return ;
2144+ }
21412145 batch.n_tokens = n_batch;
21422146 for (uint32_t pos_batch = 0 ; pos_batch < n_batch; ++pos_batch) {
21432147 batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
@@ -2174,6 +2178,11 @@ void llama_context::opt_epoch_iter(
21742178
21752179 uint32_t pos_batch = 0 ;
21762180 do {
2181+ // Check for early exit request before processing ubatch
2182+ if (training_should_stop.load (std::memory_order_acquire)) {
2183+ break ;
2184+ }
2185+
21772186 const auto & ubatch = mctx->get_ubatch ();
21782187
21792188 n_outputs = ubatch.n_tokens ;
@@ -2263,6 +2272,11 @@ void llama_context::opt_epoch_iter(
22632272 ggml_free (ctx_compute_opt);
22642273
22652274 pos_batch += ubatch.n_tokens ;
2275+
2276+ // Check for early exit request after processing ubatch
2277+ if (training_should_stop.load (std::memory_order_acquire)) {
2278+ break ;
2279+ }
22662280 } while (mctx->next ());
22672281 }
22682282}
@@ -2275,6 +2289,9 @@ void llama_context::opt_epoch(
22752289 ggml_opt_epoch_callback callback_train,
22762290 ggml_opt_epoch_callback callback_eval,
22772291 int64_t resume_from_batch) {
2292+ // Reset stop flag at the start of each epoch to ensure clean state
2293+ training_should_stop.store (false , std::memory_order_release);
2294+
22782295 const uint32_t n_ctx = this ->n_ctx ();
22792296 const uint32_t n_batch = std::min (cparams.n_batch , n_ctx);
22802297 const uint32_t n_ubatch = std::min (cparams.n_ubatch , n_batch);
@@ -2295,6 +2312,11 @@ void llama_context::opt_epoch(
22952312 int64_t t_loop_start = ggml_time_us ();
22962313 int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
22972314 for (; idata < idata_split; ++idata) {
2315+ // Check for early exit request before processing batch
2316+ if (training_should_stop.load (std::memory_order_acquire)) {
2317+ break ;
2318+ }
2319+
22982320 constexpr bool train = true ;
22992321 const int64_t idata_in_loop = idata*ubatch_per_ctx;
23002322
@@ -2306,11 +2328,21 @@ void llama_context::opt_epoch(
23062328 }
23072329 opt_epoch_iter (dataset, result_train, tokens, labels_sparse, masks_sparse, batch,
23082330 callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
2331+
2332+ // Check again after iteration in case it was set during processing
2333+ if (training_should_stop.load (std::memory_order_acquire)) {
2334+ break ;
2335+ }
23092336 }
23102337
23112338 t_loop_start = ggml_time_us ();
23122339 ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
23132340 for (; idata < ndata; ++idata) {
2341+ // Check for early exit request before processing batch
2342+ if (training_should_stop.load (std::memory_order_acquire)) {
2343+ break ;
2344+ }
2345+
23142346 constexpr bool train = false ;
23152347 const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
23162348
@@ -2321,6 +2353,11 @@ void llama_context::opt_epoch(
23212353 }
23222354 opt_epoch_iter (dataset, result_eval, tokens, labels_sparse, masks_sparse, batch,
23232355 callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
2356+
2357+ // Check again after iteration in case it was set during processing
2358+ if (training_should_stop.load (std::memory_order_acquire)) {
2359+ break ;
2360+ }
23242361 }
23252362
23262363 llama_batch_free (batch);
@@ -2357,6 +2394,14 @@ void llama_context::opt_cleanup() {
23572394 }
23582395}
23592396
// Requests that the in-progress optimization (training/eval) loop stop early.
//
// Stores `true` into `training_should_stop` with release ordering. The flag is
// polled with acquire loads in opt_epoch() and opt_epoch_iter(), which leave
// their loops (break / return) once they observe it, so the stop takes effect
// at the next check point rather than immediately.
//
// NOTE(review): opt_epoch() stores `false` into the same flag when it starts,
// so a request issued before an epoch begins is overwritten by that reset.
// NOTE(review): `training_should_stop` is presumably a std::atomic<bool> member
// declared in the class header — confirm; its declaration is not visible here.
2397+ void llama_context::opt_request_stop () {
2398+ training_should_stop.store (true , std::memory_order_release);
2399+ }
2400+
// Clears a previously requested early stop.
//
// Stores `false` into `training_should_stop` with release ordering, the same
// flag that opt_request_stop() sets and that opt_epoch() / opt_epoch_iter()
// poll with acquire loads. opt_epoch() performs the same reset itself on entry.
2401+ void llama_context::opt_reset_stop () {
2402+ training_should_stop.store (false , std::memory_order_release);
2403+ }
2404+
23602405//
23612406// interface implementation
23622407//
@@ -2928,3 +2973,11 @@ bool llama_opt_load_state(struct llama_context * ctx, const char* filename) {
// Free-function entry point that forwards to llama_context::opt_cleanup().
// `ctx` is dereferenced without a null check, so it must be a valid context.
29282973void llama_opt_cleanup (struct llama_context * ctx) {
29292974 ctx->opt_cleanup ();
29302975}
2976+
// Free-function entry point that forwards to llama_context::opt_request_stop(),
// which raises the stop flag polled by the optimization epoch loops.
// `ctx` is dereferenced without a null check, so it must be a valid context.
2977+ void llama_opt_request_stop (struct llama_context * ctx) {
2978+ ctx->opt_request_stop ();
2979+ }
2980+
// Free-function entry point that forwards to llama_context::opt_reset_stop(),
// which clears the stop flag polled by the optimization epoch loops.
// `ctx` is dereferenced without a null check, so it must be a valid context.
2981+ void llama_opt_reset_stop (struct llama_context * ctx) {
2982+ ctx->opt_reset_stop ();
2983+ }
0 commit comments