
Commit db27625

dev-nidgianni-cor authored and committed
Fixed ibatch Mismatch in llama_opt_epoch Resume
1 parent 10fd931 commit db27625

File tree

1 file changed: +7, -1


src/llama-context.cpp

Lines changed: 7 additions & 1 deletion
@@ -2318,7 +2318,13 @@ void llama_context::opt_epoch(
     }
 
     constexpr bool train = true;
-    const int64_t idata_in_loop = idata*ubatch_per_ctx;
+    // When resuming, adjust idata_in_loop to account for skipped batches.
+    // The callback expects ibatch to be relative to the start of the epoch (batch 0),
+    // not relative to the resume point. So if we resume from batch 2, the first
+    // callback should receive ibatch for batch 2, not batch 3.
+    // Since idata starts at resume_from_batch+1 when resuming, we subtract 1 to get
+    // the correct batch number. When not resuming, idata starts at 0, so we use idata directly.
+    const int64_t idata_in_loop = (resume_from_batch > 0) ? (idata - 1) * ubatch_per_ctx : idata * ubatch_per_ctx;
 
     if (opt_loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY_MASKED && ggml_opt_dataset_masks(dataset)) {
         ggml_opt_dataset_get_batch_host_with_masks(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), masks_sparse.data(), idata);
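
For context, below is a minimal standalone sketch (not llama.cpp code) of the indexing this commit fixes. Only idata, ubatch_per_ctx, resume_from_batch, and idata_in_loop come from the diff above; the loop bounds, progress_cb, and run_epoch are hypothetical stand-ins for the batch loop in llama_context::opt_epoch and the callback it drives.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the progress callback that opt_epoch invokes;
// it only prints the epoch-relative batch index it is handed.
static void progress_cb(int64_t ibatch) {
    std::printf("callback sees ibatch = %lld\n", (long long) ibatch);
}

// Hypothetical stand-in for the batch loop in llama_context::opt_epoch.
static void run_epoch(int64_t ndata, int64_t ubatch_per_ctx, int64_t resume_from_batch) {
    // Per the diff's comment, idata starts at resume_from_batch + 1 when
    // resuming and at 0 otherwise.
    const int64_t idata_start = resume_from_batch > 0 ? resume_from_batch + 1 : 0;
    for (int64_t idata = idata_start; idata < ndata; ++idata) {
        // The fixed expression from the commit: shift idata back by one
        // context window when resuming so ibatch stays relative to batch 0.
        const int64_t idata_in_loop = (resume_from_batch > 0)
            ? (idata - 1) * ubatch_per_ctx
            : idata * ubatch_per_ctx;
        progress_cb(idata_in_loop); // ibatch of the first ubatch in this window
    }
}

int main() {
    std::printf("fresh epoch:\n");
    run_epoch(/*ndata=*/4, /*ubatch_per_ctx=*/1, /*resume_from_batch=*/0);
    std::printf("resume from batch 2:\n");
    run_epoch(/*ndata=*/4, /*ubatch_per_ctx=*/1, /*resume_from_batch=*/2);
    return 0;
}

With ubatch_per_ctx = 1, the fresh epoch reports ibatch 0 through 3, while resuming from batch 2 now reports ibatch starting at 2 rather than 3 — the off-by-one named in the commit title.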
