
Commit ecddeae

[Feat] Important updates to the experimental unified trainer (#398)

* feat: support lr scheduling for Tinker backend and initiate on-policy self-distillation
* feat: add OPSD Tinker backend & utils
* fix: rename and add math OPSD script
* deprecate the 'per-step' stepwise mode with warning
* feat: support pre-computation of advantages
* refactor OPSD: remove backend and add decorator
* some cleanups
* fix lr schedule and add docs

1 parent 6dba5c1

21 files changed (+1242, -173 lines)

Lines changed: 386 additions & 0 deletions
# Backend Protocol

> **Module**: `rllm.experimental.protocol`

The `BackendProtocol` is the abstract interface that decouples the
[Unified Trainer](unified-trainer.md) from any specific training infrastructure.
By implementing this protocol, you can plug in any model-serving, optimization, and
checkpointing system while reusing the trainer's episode generation, data
transformation, rejection sampling, and logging machinery.
---

## Class Signature

```python
class BackendProtocol(ABC, Generic[TDataset, TBatch]):
    name: str = "base_backend"
    requires_loop: bool = False

    def __init__(self, config: DictConfig, **kwargs): ...
```
The two type parameters let you declare your backend-specific types:

- `TDataset` -- the iterable type returned by `get_dataloader` (e.g. `torch.utils.data.DataLoader`)
- `TBatch` -- the batch type consumed by your pipeline methods (e.g. `list[tinker.Datum]`)
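For instance, a hypothetical backend pairing a torch `DataLoader` with a list-of-dicts batch type would declare:

```python
from torch.utils.data import DataLoader

from rllm.experimental.protocol import BackendProtocol

# Hypothetical parameterization: TDataset = torch DataLoader,
# TBatch = list of feature dicts. The class name is illustrative.
class DictBatchBackend(BackendProtocol[DataLoader, list[dict]]):
    name = "dict_batch_backend"
```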
---

## What a Backend Provides

A backend implementation is responsible for four categories of functionality:
```
BackendProtocol
|
|-- Setup & teardown
|     init_rollout_engine()         -- create the RolloutEngine for inference
|     validate_config()             -- check backend-specific config
|     get_dataloader()              -- wrap a Dataset into an iterable
|     shutdown()                    -- release resources
|
|-- Pipeline methods (called per batch, in order)
|     generate_episodes()           -- stage 1: run workflows
|     transform_to_backend_batch()  -- stage 4: convert to native format
|     process_backend_batch()       -- stage 5: forward/backward pass
|     compute_advantages()          -- stage 6: advantage computation
|     update_policy()               -- stage 7: optimizer step
|
|-- Lifecycle hooks (optional overrides)
      on_train_start / on_train_end
      on_epoch_start / on_epoch_end
      on_batch_start / on_batch_end
      on_validation_start / on_validation_end
```

The gaps in the stage numbering are intentional: stages 2-3 (conversion to
`TrajectoryGroup`s and rejection sampling/filtering) and stage 8 (visualization and
metrics collection) are handled by the trainer itself; see the
[Data Flow Summary](#data-flow-summary) below.
---

## Setup Methods

### `init_rollout_engine(**kwargs) -> RolloutEngine`

Called once during trainer initialization. The backend must create and return a
`RolloutEngine` that the workflow engine will use for model inference. The trainer
passes the parsed config objects as keyword arguments:
```python
def init_rollout_engine(self, **kwargs) -> RolloutEngine:
    cf_config = kwargs.get("cf_config")                # CompactFilteringConfig
    transform_config = kwargs.get("transform_config")  # TransformConfig
    rs_config = kwargs.get("rs_config")                # RejectionSamplingConfig
    algorithm_config = kwargs.get("algorithm_config")  # AlgorithmConfig
    # ... create and return your engine
```
### `validate_config() -> None`

Called during trainer initialization to validate backend-specific configuration.
Raise or warn on invalid settings.
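A minimal sketch (the checks and config paths are illustrative, loosely modeled on the `TinkerBackend` temperature warning described below):

```python
import warnings

def validate_config(self) -> None:
    # Illustrative checks; the config paths are assumptions, not the
    # library's actual schema.
    if self.config.sampling.temperature != 1.0:
        warnings.warn("temperature != 1.0 may bias on-policy training")
    if self.config.training.batch_size <= 0:
        raise ValueError("training.batch_size must be positive")
```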
### `get_dataloader(dataset, trainer_state) -> TDataset`

Called at the start of each epoch (training) and at each validation round. Use
`trainer_state.is_training` to distinguish between training and validation and
return the appropriate dataloader.
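A minimal sketch, assuming a map-style dataset and illustrative batch-size config fields:

```python
from torch.utils.data import DataLoader

def get_dataloader(self, dataset, trainer_state) -> DataLoader:
    # Training and validation typically use different batch sizes; the
    # config field names here are illustrative.
    if trainer_state.is_training:
        return DataLoader(dataset, batch_size=self.config.data.train_batch_size, shuffle=True)
    return DataLoader(dataset, batch_size=self.config.data.val_batch_size, shuffle=False)
```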
### `shutdown() -> None`

Called when the trainer is torn down. Release GPU memory, close connections, etc.

---
## Pipeline Methods

These are called by the trainer in a fixed order during each training batch.
The `TrainerState` object is the shared mutable context throughout a batch.

### Stage 1: `generate_episodes(batch, agent_workflow_engine, is_validation) -> list[Episode]`

Produce episodes by running workflows on the input batch. A typical implementation (sketched below):

1. Prepares the batch (e.g. repeat each task `group_size` times for GRPO)
2. Sets the current model on the rollout engine
3. Delegates to `agent_workflow_engine.execute_tasks(...)`
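A sketch of that shape (only `execute_tasks` comes from this page; the `group_size` config path, `set_model`, and `current_model` are illustrative stand-ins for your infrastructure):

```python
async def generate_episodes(self, batch, agent_workflow_engine, is_validation=False, **kwargs):
    # 1. Repeat each task so GRPO can form a group of rollouts per task.
    group_size = 1 if is_validation else self.config.algorithm.group_size  # illustrative
    tasks = [task for task in batch for _ in range(group_size)]
    # 2. Make sure the rollout engine samples from the current policy.
    self.rollout_engine.set_model(self.current_model)  # hypothetical setter
    # 3. Delegate execution of the workflows.
    return await agent_workflow_engine.execute_tasks(tasks)
```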
### Stage 4: `transform_to_backend_batch(trainer_state) -> TBatch`

Convert the framework's `TrajectoryGroup` objects into your backend-native format.
This is a sync method since it is typically pure data transformation.

Some backends defer transformation to `process_backend_batch` and return a
placeholder here.
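In its simplest form the deferred pattern looks like this (a sketch mirroring the `TinkerBackend` behavior described below):

```python
def transform_to_backend_batch(self, trainer_state, **kwargs) -> list:
    # Deferred pattern: return an empty placeholder; the real conversion to
    # backend-native datums happens inside process_backend_batch (stage 5).
    return []
```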
### Stage 5: `process_backend_batch(trainer_state) -> None`

The main computational stage. Common operations:

- Run a forward pass to compute training logprobs
- Run a backward pass to compute gradients
- Store results in `trainer_state.backend_batch` and `trainer_state.extra_info`

This method updates `trainer_state` in place (no return value).
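A sketch of the shape this stage usually takes (the two helpers are hypothetical stand-ins for your backend's actual forward/backward machinery):

```python
async def process_backend_batch(self, trainer_state, **kwargs) -> None:
    # Convert trajectory groups to native datums, then run forward/backward.
    datums = self._to_datums(trainer_state.trajectory_groups)      # hypothetical helper
    logprobs, loss = await self._forward_backward(datums)          # hypothetical helper
    # Store results on the shared state instead of returning them.
    trainer_state.backend_batch = datums
    trainer_state.extra_info["training_logprobs"] = logprobs
    trainer_state.metrics["loss"] = loss
```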
### Stage 6: `compute_advantages(trainer_state, algorithm_config) -> None`

Compute per-step advantages and store them on the `Step` objects within each
trajectory. The base class provides a default implementation using rLLM-native
advantage estimators (GRPO, REINFORCE):

```python
# Default implementation in BackendProtocol
async def compute_advantages(self, trainer_state, algorithm_config, **kwargs):
    adv_metrics = collect_reward_and_advantage_from_trajectory_groups(
        trainer_state.trajectory_groups, algorithm_config
    )
    trainer_state.metrics.update(adv_metrics)
```
**Pre-computed advantages:** If advantages are already set on the `Step` objects
(e.g. computed during episode generation via a workflow decorator), the default
implementation detects this and skips re-computation.
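The detection can be pictured roughly like this (a sketch, not the library's exact code; the `trajectories`/`steps` attribute names are illustrative, and only the `.advantage` field is documented on this page):

```python
def _advantages_precomputed(trajectory_groups) -> bool:
    # True when every step already carries an advantage, e.g. one set by a
    # workflow decorator during episode generation.
    return all(
        step.advantage is not None
        for group in trajectory_groups
        for trajectory in group.trajectories   # illustrative attribute names
        for step in trajectory.steps
    )
```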
### Stage 7: `update_policy(trainer_state) -> None`

Run the optimizer step to update model weights. Some backends fuse this into
`process_backend_batch` and make `update_policy` a no-op (see
[Flexible Stage Organization](#flexible-stage-organization) below).

---
## Lifecycle Hooks

All hooks are `async def` methods with default no-op implementations. Override only
what you need.

### Training hooks
```
on_train_start(state)           -- called once before the first epoch
|
|   on_epoch_start(state)       -- called at the start of each epoch
|   |
|   |   on_batch_start(state)   -- called before each batch pipeline
|   |   [... 8-stage pipeline ...]
|   |   on_batch_end(state)     -- called after pipeline, before logging
|   |
|   |   (repeat for each batch)
|   |
|   on_epoch_end(state)         -- called at the end of each epoch
|
|   (repeat for each epoch)
|
on_train_end(state)             -- called once after all epochs
```
### Validation hooks

```
on_validation_start(state) -> bool  -- return False to skip validation
  [... validation loop ...]
on_validation_end(state)
```
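For example, a backend might gate validation like this (a sketch; the warm-up condition and config path are illustrative):

```python
async def on_validation_start(self, trainer_state) -> bool:
    # Hypothetical gate: skip validation during warm-up steps.
    if trainer_state.global_step < self.config.trainer.warmup_steps:  # illustrative
        return False  # trainer skips this validation round
    return True
```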
### Common uses for hooks

| Hook | Common use |
|------|------------|
| `on_train_start` | Initialize training client, load checkpoint, set initial `global_step` |
| `on_batch_end` | Save checkpoint, update sampling client, compute derived metrics, print metrics table |
| `on_train_end` | Save final checkpoint |
| `on_validation_start` | Toggle model to eval mode; return `False` to skip |
| `on_validation_end` | Toggle model back to train mode |
**Important:** `on_batch_end` runs **after** the pipeline but **before**
`logger.log(...)`. This makes it the right place to inject derived metrics
(e.g. KL divergence, learning rate) into `trainer_state.metrics`.
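A minimal sketch of that pattern (both helpers are hypothetical; only the `optim/lr` metric name appears elsewhere on this page):

```python
async def on_batch_end(self, trainer_state):
    # Derived metrics injected here are picked up by logger.log(...).
    trainer_state.metrics["optim/lr"] = self._current_lr()                 # hypothetical helper
    trainer_state.metrics["policy/kl"] = self._compute_kl(trainer_state)   # hypothetical helper
```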
---
## Flexible Stage Organization

The protocol defines stages 4-7 as separate methods, but backends are free to
redistribute work across them. The trainer always calls them in the same order --
it is the backend's responsibility to decide what each stage does internally.
### Example: TinkerBackend's fused mode

The `TinkerBackend` demonstrates this flexibility. When
`fuse_forward_backward_and_optim_step` is enabled:

```
                             Default (non-fused)        Fused
                             -------------------        -----
transform_to_backend_batch   returns [] placeholder     same
process_backend_batch        forward + backward         forward + backward + optim step
compute_advantages           stores algorithm_config    same
update_policy                optimizer step             no-op (already done)
```
Both modes produce the same end result, but the fused path reduces round-trips
to the training server. The trainer does not need to know which path is active --
it simply calls all four methods in order.
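Seen from the backend side, the fused no-op can be sketched like this (the config path and the `_optim_step` helper are assumptions; only the flag name comes from this page):

```python
async def update_policy(self, trainer_state, **kwargs) -> None:
    if self.config.training.fuse_forward_backward_and_optim_step:  # illustrative path
        return  # optimizer step already ran inside process_backend_batch
    await self._optim_step(trainer_state)  # hypothetical helper
```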
### Example: Pre-computed advantages (OPSD)

For On-Policy Self-Distillation, advantages are computed during episode generation
(stage 1) via a workflow decorator. By the time `compute_advantages` (stage 6) runs,
every `Step` already has its `.advantage` field set. The default implementation in
`BackendProtocol` detects this and skips re-computation, collecting only metrics.

This means OPSD can work with the standard `TinkerBackend` -- no custom backend
subclass is needed.

---
## Implementing a Custom Backend

### Step 1: Subclass `BackendProtocol`
```python
from rllm.experimental.protocol import BackendProtocol

class MyBackend(BackendProtocol[MyDataLoader, MyBatch]):
    name = "my_backend"

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # ... initialize your infrastructure
```
### Step 2: Implement required methods

At minimum, you must implement these abstract methods:
```python
# Setup
def init_rollout_engine(self, **kwargs) -> RolloutEngine: ...
def validate_config(self) -> None: ...
def get_dataloader(self, dataset, trainer_state) -> MyDataLoader: ...
def shutdown(self) -> None: ...

# Pipeline
async def generate_episodes(self, batch, agent_workflow_engine, is_validation=False, **kwargs) -> list[Episode]: ...
def transform_to_backend_batch(self, trainer_state, **kwargs) -> MyBatch: ...
async def process_backend_batch(self, trainer_state, **kwargs) -> None: ...
async def compute_advantages(self, trainer_state, algorithm_config, **kwargs) -> None: ...
async def update_policy(self, trainer_state, **kwargs) -> None: ...
```
### Step 3: Override lifecycle hooks as needed

```python
async def on_train_start(self, trainer_state):
    # Load checkpoint, initialize model
    ...

async def on_batch_end(self, trainer_state):
    # Save checkpoint, update metrics
    ...
```
### Step 4: Wire it up

```python
trainer = UnifiedTrainer(
    backend_cls=MyBackend,
    config=config,
    workflow_class=MyWorkflow,
    train_dataset=train_ds,
    val_dataset=val_ds,
)
trainer.fit()
```
---

## Reference: TinkerBackend

The `TinkerBackend` (`rllm.experimental.tinker.tinker_backend`) is the primary
production backend. It serves as a comprehensive reference implementation. Below
is a summary of how it implements each part of the protocol.
### Setup

| Method | Implementation |
|---|---|
| `init_rollout_engine` | Creates a `TinkerPolicyTrainer` and a `TinkerEngine` (rollout engine backed by a Tinker sampling server) |
| `validate_config` | Warns if sampling temperature/top_p deviate from 1.0 |
| `get_dataloader` | Returns a `torch.utils.data.DataLoader` with backend-specific batch sizes |
| `shutdown` | Delegates to parent (no-op currently) |
### Pipeline

| Stage | Method | Implementation |
|---|---|---|
| 1 | `generate_episodes` | Builds an interleaved batch (`N` repeats per task for GRPO grouping), sets the sampling client on the rollout engine, and calls `agent_workflow_engine.execute_tasks(...)` |
| 4 | `transform_to_backend_batch` | Returns an empty list (placeholder). The actual datum construction is deferred to stage 5 |
| 5 | `process_backend_batch` | Converts trajectory groups to Tinker `Datum` objects, runs forward-backward, stores training logprobs. Optionally fuses the optimizer step |
| 6 | `compute_advantages` | Stores the `AlgorithmConfig` for use by stage 5's datum construction (advantage computation is embedded in the forward-backward call) |
| 7 | `update_policy` | Runs the optimizer step (or no-op if fused into stage 5) |
### Lifecycle hooks

| Hook | Implementation |
|---|---|
| `on_train_start` | Initializes the training client, loads checkpoint, sets `trainer_state.global_step` from the checkpoint's batch index |
| `on_train_end` | Saves final checkpoint if not already saved |
| `on_batch_end` | Saves sampler checkpoint, updates `self.sampling_client`, injects `optim/lr` and KL/entropy metrics into `trainer_state.metrics`, prints metrics table |
| `on_epoch_start/end` | Logging only |
| `on_validation_start/end` | Toggles `trainer_state.is_training` flag |
### Key patterns to note

1. **Deferred transformation.** `transform_to_backend_batch` returns a placeholder;
   the real work happens in `process_backend_batch`. This is valid because the trainer
   only checks `trainer_state.has_backend_batch` *after* `process_backend_batch` runs.

2. **Checkpoint-driven sampling client.** The `sampling_client` (used by workflows
   for inference) is updated in `on_batch_end` after each checkpoint save. This
   ensures workflows always sample from the latest policy.

3. **Metrics injection in `on_batch_end`.** Since `on_batch_end` runs after the
   pipeline but before `logger.log(...)`, it is the natural place to compute derived
   metrics (KL divergence, entropy, learning rate) and add them to
   `trainer_state.metrics`.

---
## Data Flow Summary

```
Dataset
  |
  v
get_dataloader() --> batch
  |
  v
generate_episodes(batch) --> list[Episode]
  |
  v
[framework] transform to TrajectoryGroups
  |
  v
[framework] rejection sampling & filtering
  |
  v
transform_to_backend_batch() --> TBatch (stored in trainer_state.backend_batch)
  |
  v
process_backend_batch() -- forward/backward, logprobs
  |
  v
compute_advantages() -- advantage computation
  |
  v
update_policy() -- optimizer step
  |
  v
[framework] visualization, metrics collection
  |
  v
on_batch_end() -- checkpoint, derived metrics
  |
  v
logger.log(metrics) -- wandb / tracking
```
