
Commit 82a43ff

[refactor] step zero definition (PrimeIntellect-ai#528)
* start step at 0 which makes async level 0 sync
* fix: seed shouldnt have -1
* train step 0 start for trainer
* keep async level around
* DEBUG
* fix: +1
* +1
* Revert "DEBUG"
  This reverts commit 6ba48b4b3f3958a741b8ad13b5dd2279e0ea31aa.
* docs
1 parent 2ffcb6f commit 82a43ff

File tree

4 files changed: +24 / -13 lines changed
README.md

Lines changed: 11 additions & 0 deletions
@@ -220,6 +220,17 @@ To run fast tests, use the inverse of the `slow` marker:
 uv run pytest -v -m "not slow"
 ```
 
+### Checkpoint, rollout, step numbering and async level
+At each step `n`, all artifacts (e.g., checkpoint, rollout, gradient) are tagged with the same step number.
+- Step 0:
+  - Uses checkpoint 0 on rollout 0 to compute gradient 0.
+  - Then computes checkpoint 1 as: `ckpt 1 = ckpt 0 - grad 0`
+
+In general, the model used for generating rollouts at step `n` is from `ckpt[n - async_level]`.
+
+- When async_level = 0, the rollout and gradient are based on the same model version.
+  This is equivalent to synchronous on-policy training.
+
 ## Citation
 
 *TBD*
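The numbering rule added to the README is easiest to see in a tiny sketch. The code below is not repo code: `ckpt_for_rollout` is a hypothetical helper that returns which checkpoint serves the rollouts at a given step, assuming (as described above and in the orchestrator) that the checkpoint index is clamped at 0 during the first steps.

```python
# Hypothetical helper illustrating the step/checkpoint relationship described above;
# not part of the zeroband codebase.

def ckpt_for_rollout(step: int, async_level: int) -> int:
    """Checkpoint version used to generate rollouts at `step`.

    The orchestrator never waits for a checkpoint newer than step - async_level,
    and checkpoint 0 (the initial weights) covers the warm-up steps.
    """
    return max(step - async_level, 0)


if __name__ == "__main__":
    for async_level in (0, 1, 2):
        used = [ckpt_for_rollout(step, async_level) for step in range(5)]
        print(f"async_level={async_level}: steps 0-4 roll out from checkpoints {used}")
    # async_level=0 -> [0, 1, 2, 3, 4]: rollout and gradient share one model version (on-policy)
    # async_level=1 -> [0, 0, 1, 2, 3]: rollouts may lag the trainer by one checkpoint
```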

src/zeroband/training/ckpt.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 @dataclass
 class TrainingProgress:
-    step: int = 1
+    step: int = 0
     total_tokens: int = 0
     total_samples: int = 0
 
src/zeroband/training/orchestrator/orchestrator.py

Lines changed: 8 additions & 9 deletions
@@ -106,19 +106,19 @@ async def orchestrate(config: OrchestratorConfig, setup_queue: Queue | None = No
     total_tokens, total_samples = 0, 0
     ckpt_step = 0
     last_eval_step = -1
-    epoch = 0
+    epoch = -1
 
-    for step in range(1, int(max_steps) + 1):
+    for step in range(int(max_steps)):
         # Check if we need to start a new epoch
-        epoch_step = (step - 1) % steps_per_epoch
+        epoch_step = step % steps_per_epoch
         if epoch_step == 0:
             epoch += 1
             logger.info(f"Starting epoch {epoch}")
             # Reshuffle dataset at the beginning of each epoch
-            dataset = dataset.shuffle(seed=(config.seed or 0) + epoch - 1)
+            dataset = dataset.shuffle(seed=(config.seed or 0) + epoch)
 
         logger.debug(
-            f"Orchestrator step {step} (epoch: {epoch}, epoch_step: {epoch_step + 1}/{steps_per_epoch}, checkpoint step: {ckpt_step})"
+            f"Orchestrator step {step} (epoch: {epoch}, epoch_step: {epoch_step}/{steps_per_epoch}, checkpoint step: {ckpt_step})"
         )
         step_start_time = time.time()
 

@@ -131,15 +131,14 @@ async def orchestrate(config: OrchestratorConfig, setup_queue: Queue | None = No
         batch_messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
 
         # Optionally, wait for the next checkpoint to be available
-        async_level = step - 1 - ckpt_step  # How many steps training ahead
         wait_for_weight_ckpt_time, reload_weights_time = 0, 0
-        if async_level > config.async_level:
-            ckpt_step = step - 1 - config.async_level
+        if step - ckpt_step > config.async_level:
             logger.debug(
-                f"Hit async barrier because step {step} is {async_level} (>{config.async_level}) steps ahead of checkpoint step {ckpt_step}."
+                f"Hit async barrier because step {step} is {step - ckpt_step} (>{config.async_level}) steps ahead of checkpoint step {ckpt_step}."
             )
 
             # Wait for the checkpoint to be available
+            ckpt_step = step - config.async_level
             logger.debug(f"Waiting for weight checkpoint for step {ckpt_step}")
             wait_for_weight_ckpt_start_time = time.time()
             wait_for_weight_checkpoint(config.weights.path, ckpt_step)
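To make the new barrier condition concrete, here is a small standalone simulation (a hypothetical `simulate_orchestrator` helper, not repo code) of how `ckpt_step` advances under the 0-based loop: the orchestrator only blocks when `step - ckpt_step > async_level`, and then waits for checkpoint `step - async_level`.

```python
# Standalone simulation of the async barrier shown in the diff above.
# Assumed/simplified: the real orchestrator also reshuffles data and reloads weights.

def simulate_orchestrator(max_steps: int, async_level: int) -> list[int]:
    """Return the checkpoint step used for rollouts at each orchestrator step."""
    ckpt_step = 0
    used = []
    for step in range(max_steps):
        if step - ckpt_step > async_level:
            # Hit the async barrier: wait until checkpoint `step - async_level` exists.
            ckpt_step = step - async_level
        used.append(ckpt_step)
    return used


if __name__ == "__main__":
    print(simulate_orchestrator(5, async_level=0))  # [0, 1, 2, 3, 4] -> fully synchronous, on-policy
    print(simulate_orchestrator(5, async_level=2))  # [0, 0, 0, 1, 2] -> trainer may run ahead by 2
```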

src/zeroband/training/train.py

Lines changed: 4 additions & 3 deletions
@@ -101,6 +101,7 @@ def train(config: TrainingConfig):
         logger.info(f"Initializing shardcast from {envs.SHARDCAST_OUTPUT_DIR}")
         shardcast.initialize(
             envs.SHARDCAST_OUTPUT_DIR,
+            # +1 to ensure to not delete current checkpoint when async_level=0
             max_distribution_folders=config.async_level + 1,
         )
 

@@ -171,7 +172,7 @@ def train(config: TrainingConfig):
         if config.recompute_logprobs:
             logger.debug("Recomputing logprobs")
             compute_logprobs_start_time = time.time()
-            og_infer_step = progress.step - 1 - config.async_level  # -1 because we haven't updated the model yet
+            og_infer_step = progress.step - config.async_level
             infer_step = max(og_infer_step, 0)
 
             # Wake up the logprob model from CPU

@@ -268,7 +269,7 @@ def train(config: TrainingConfig):
         optimizer.zero_grad()
 
         # Save the weight checkpoint
-        step_path = Path(config.weights.path) / f"step_{progress.step}"
+        step_path = Path(config.weights.path) / f"step_{progress.step + 1}"
         save_weights_start_time = time.time()
         model_path = save_weight_checkpoint(model, tokenizer, step_path, async_save=config.weights.save_async)
         active_weight_checkpoint_paths.append(step_path)

@@ -310,7 +311,7 @@ def train(config: TrainingConfig):
         if config.recompute_logprobs:
             logger.debug("Offloading updated model to CPU")
             reshard_module(logprob_model)
-            tensor_offloaded_repository[progress.step] = copy_model_to_cpu(model)
+            tensor_offloaded_repository[progress.step + 1] = copy_model_to_cpu(model)
 
         # Compute step metrics
         num_local_tokens = micro_batch_size * seq_len * num_micro_batches
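A rough sketch of how the trainer-side renumbering fits together, using hypothetical helpers (`weight_ckpt_path`, `retained_ckpt_steps`) and an assumed directory layout: trainer step `n` consumes rollouts tagged `n` and writes its updated weights as `step_{n + 1}`, and shardcast keeps `async_level + 1` folders so the checkpoint still being served is not deleted when `async_level = 0`.

```python
from pathlib import Path

# Sketch of the trainer-side numbering after this commit (assumed layout, not repo code):
# step n trains on checkpoint n and then writes checkpoint n + 1.

def weight_ckpt_path(weights_dir: str, trainer_step: int) -> Path:
    """Directory where the weights produced by `trainer_step` are saved."""
    return Path(weights_dir) / f"step_{trainer_step + 1}"


def retained_ckpt_steps(latest_saved: int, async_level: int) -> list[int]:
    """Checkpoint steps kept when max_distribution_folders = async_level + 1."""
    keep = async_level + 1
    return list(range(max(latest_saved - keep + 1, 0), latest_saved + 1))


if __name__ == "__main__":
    print(weight_ckpt_path("/ckpts", trainer_step=0))           # /ckpts/step_1
    print(retained_ckpt_steps(latest_saved=3, async_level=0))   # [3]: the +1 keeps the checkpoint in use
    print(retained_ckpt_steps(latest_saved=3, async_level=2))   # [1, 2, 3]
```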
