# Backend Protocol

> **Module**: `rllm.experimental.protocol`

The `BackendProtocol` is the abstract interface that decouples the
[Unified Trainer](unified-trainer.md) from any specific training infrastructure.
By implementing this protocol, you can plug in any model-serving, optimization, and
checkpointing system while reusing the trainer's episode generation, data
transformation, rejection sampling, and logging machinery.

---

## Class Signature

```python
class BackendProtocol(ABC, Generic[TDataset, TBatch]):
    name: str = "base_backend"
    requires_loop: bool = False

    def __init__(self, config: DictConfig, **kwargs): ...
```

The two type parameters let you declare your backend-specific types:

- `TDataset` -- the iterable type returned by `get_dataloader` (e.g. `torch.utils.data.DataLoader`)
- `TBatch` -- the batch type consumed by your pipeline methods (e.g. `list[tinker.Datum]`)

---

## What a Backend Provides

A backend implementation is responsible for three categories of functionality:

```
BackendProtocol
 |
 |-- Setup & teardown
 |     init_rollout_engine()        -- create the RolloutEngine for inference
 |     validate_config()            -- check backend-specific config
 |     get_dataloader()             -- wrap a Dataset into an iterable
 |     shutdown()                   -- release resources
 |
 |-- Pipeline methods (called per batch, in order)
 |     generate_episodes()          -- stage 1: run workflows
 |     transform_to_backend_batch() -- stage 4: convert to native format
 |     process_backend_batch()      -- stage 5: forward/backward pass
 |     compute_advantages()         -- stage 6: advantage computation
 |     update_policy()              -- stage 7: optimizer step
 |
 |-- Lifecycle hooks (optional overrides)
       on_train_start / on_train_end
       on_epoch_start / on_epoch_end
       on_batch_start / on_batch_end
       on_validation_start / on_validation_end
```

---

## Setup Methods

### `init_rollout_engine(**kwargs) -> RolloutEngine`

Called once during trainer initialization. The backend must create and return a
`RolloutEngine` that the workflow engine will use for model inference. The trainer
passes the parsed config objects as keyword arguments:

```python
def init_rollout_engine(self, **kwargs) -> RolloutEngine:
    cf_config = kwargs.get("cf_config")                # CompactFilteringConfig
    transform_config = kwargs.get("transform_config")  # TransformConfig
    rs_config = kwargs.get("rs_config")                # RejectionSamplingConfig
    algorithm_config = kwargs.get("algorithm_config")  # AlgorithmConfig
    # ... create and return your engine
```

### `validate_config() -> None`

Called during trainer initialization to validate backend-specific configuration.
Raise or warn on invalid settings.
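
A minimal sketch, assuming an OmegaConf-style `DictConfig`; the specific config keys checked here are illustrative, not part of the protocol:

```python
import warnings

def validate_config(self) -> None:
    # Hypothetical keys -- adapt to your backend's config schema.
    if self.config.training.train_batch_size <= 0:
        raise ValueError("training.train_batch_size must be positive")
    if self.config.sampling.temperature != 1.0:
        warnings.warn("sampling.temperature != 1.0 may bias on-policy training")
```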

### `get_dataloader(dataset, trainer_state) -> TDataset`

Called at the start of each epoch (training) and at each validation round. Use
`trainer_state.is_training` to distinguish between training and validation and
return the appropriate dataloader.
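
For example, a backend whose `TDataset` is a `torch.utils.data.DataLoader` might do something along these lines (the batch-size config keys are illustrative):

```python
from torch.utils.data import DataLoader

def get_dataloader(self, dataset, trainer_state) -> DataLoader:
    if trainer_state.is_training:
        # Shuffle during training; use the training batch size.
        return DataLoader(dataset, batch_size=self.config.data.train_batch_size, shuffle=True)
    # Validation: deterministic order, typically a larger batch size.
    return DataLoader(dataset, batch_size=self.config.data.val_batch_size, shuffle=False)
```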

### `shutdown() -> None`

Called when the trainer is torn down. Release GPU memory, close connections, etc.

---

## Pipeline Methods

These are called by the trainer in a fixed order during each training batch.
The `TrainerState` object is the shared mutable context throughout a batch.

### Stage 1: `generate_episodes(batch, agent_workflow_engine, is_validation) -> list[Episode]`

Produce episodes by running workflows on the input batch. A typical implementation
(see the sketch below):

1. Prepares the batch (e.g. repeat each task `group_size` times for GRPO)
2. Sets the current model on the rollout engine
3. Delegates to `agent_workflow_engine.execute_tasks(...)`
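
A sketch of that shape; the config key and the calls on the rollout engine are illustrative, not the actual API:

```python
async def generate_episodes(self, batch, agent_workflow_engine, is_validation=False, **kwargs):
    # 1. Repeat each task so the algorithm can form groups of rollouts per task (GRPO).
    group_size = 1 if is_validation else self.config.algorithm.group_size  # illustrative key
    expanded = [task for task in batch for _ in range(group_size)]
    # 2. Point the rollout engine at the current policy (backend-specific call).
    self.rollout_engine.set_model(self.current_model)  # illustrative
    # 3. Delegate episode generation to the workflow engine.
    return await agent_workflow_engine.execute_tasks(expanded)
```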

### Stage 4: `transform_to_backend_batch(trainer_state) -> TBatch`

Convert the framework's `TrajectoryGroup` objects into your backend-native format.
This is a synchronous method, since it is typically pure data transformation.

Some backends defer transformation to `process_backend_batch` and return a
placeholder here.
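
A sketch of a straightforward, non-deferred implementation, assuming a hypothetical backend-native `MyDatum` record and illustrative field names on `Step`:

```python
def transform_to_backend_batch(self, trainer_state, **kwargs) -> list["MyDatum"]:
    batch = []
    for group in trainer_state.trajectory_groups:
        for trajectory in group.trajectories:
            for step in trajectory.steps:
                # MyDatum and the step fields below are illustrative.
                batch.append(MyDatum(tokens=step.tokens, logprobs=step.logprobs, reward=step.reward))
    return batch
```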

### Stage 5: `process_backend_batch(trainer_state) -> None`

The main computational stage. Common operations:

- Run a forward pass to compute training logprobs
- Run a backward pass to compute gradients
- Store results in `trainer_state.backend_batch` and `trainer_state.extra_info`

This method updates `trainer_state` in place (no return value).
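
A sketch under the assumption of a local PyTorch-style model and loss; the model, loss function, and `extra_info` key are illustrative:

```python
async def process_backend_batch(self, trainer_state, **kwargs) -> None:
    logprobs = []
    losses = []
    for datum in trainer_state.backend_batch:
        # Forward pass: recompute logprobs under the current training policy.
        output = self.model(datum)                  # illustrative call
        logprobs.append(output.logprobs)
        losses.append(self.loss_fn(output, datum))  # illustrative call
    # Backward pass: accumulate gradients; the optimizer step happens in update_policy.
    sum(losses).backward()
    trainer_state.extra_info["training_logprobs"] = logprobs  # illustrative key
```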

### Stage 6: `compute_advantages(trainer_state, algorithm_config) -> None`

Compute per-step advantages and store them on the `Step` objects within each
trajectory. The base class provides a default implementation using rLLM-native
advantage estimators (GRPO, REINFORCE):

```python
# Default implementation in BackendProtocol
async def compute_advantages(self, trainer_state, algorithm_config, **kwargs):
    adv_metrics = collect_reward_and_advantage_from_trajectory_groups(
        trainer_state.trajectory_groups, algorithm_config
    )
    trainer_state.metrics.update(adv_metrics)
```

**Pre-computed advantages:** If advantages are already set on the `Step` objects
(e.g. computed during episode generation via a workflow decorator), the default
implementation detects this and skips re-computation.

### Stage 7: `update_policy(trainer_state) -> None`

Run the optimizer step to update model weights. Some backends fuse this into
`process_backend_batch` and make `update_policy` a no-op (see
[Flexible Stage Organization](#flexible-stage-organization) below).
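
A minimal sketch of the non-fused case; the fuse flag and `optimizer` attribute are illustrative:

```python
async def update_policy(self, trainer_state, **kwargs) -> None:
    if self.fuse_forward_backward_and_optim_step:  # illustrative flag
        return  # optimizer step was already taken inside process_backend_batch
    self.optimizer.step()       # apply the gradients accumulated in stage 5
    self.optimizer.zero_grad()  # reset for the next batch
```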

---

## Lifecycle Hooks

All hooks are `async def` methods with default no-op implementations. Override only
what you need.

### Training hooks

```
on_train_start(state)            -- called once before the first epoch
 |
 |  on_epoch_start(state)        -- called at the start of each epoch
 |   |
 |   |  on_batch_start(state)    -- called before each batch pipeline
 |   |  [... 8-stage pipeline ...]
 |   |  on_batch_end(state)      -- called after pipeline, before logging
 |   |
 |   |  (repeat for each batch)
 |   |
 |  on_epoch_end(state)          -- called at the end of each epoch
 |
 |  (repeat for each epoch)
 |
on_train_end(state)              -- called once after all epochs
```

### Validation hooks

```
on_validation_start(state) -> bool   -- return False to skip validation
  [... validation loop ...]
on_validation_end(state)
```
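
For example, a backend might toggle eval mode here and optionally skip validation altogether; the `skip_validation` config key and `model` attribute are illustrative:

```python
async def on_validation_start(self, trainer_state) -> bool:
    if self.config.get("skip_validation", False):  # illustrative key
        return False  # trainer skips this validation round
    self.model.eval()
    return True

async def on_validation_end(self, trainer_state):
    self.model.train()
```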

### Common uses for hooks

| Hook | Common use |
|------|------------|
| `on_train_start` | Initialize training client, load checkpoint, set initial `global_step` |
| `on_batch_end` | Save checkpoint, update sampling client, compute derived metrics, print metrics table |
| `on_train_end` | Save final checkpoint |
| `on_validation_start` | Toggle model to eval mode; return `False` to skip |
| `on_validation_end` | Toggle model back to train mode |

**Important:** `on_batch_end` runs **after** the pipeline but **before**
`logger.log(...)`. This makes it the right place to inject derived metrics
(e.g. KL divergence, learning rate) into `trainer_state.metrics`.
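
For instance, a backend might do something like this; the checkpoint helper, scheduler attribute, and metric names are illustrative:

```python
async def on_batch_end(self, trainer_state):
    # Persist the policy so workflows can sample from the latest weights.
    self.save_checkpoint(trainer_state.global_step)  # illustrative helper
    # Inject derived metrics before the trainer calls logger.log(...).
    trainer_state.metrics["optim/lr"] = self.scheduler.get_last_lr()[0]
    trainer_state.metrics["policy/kl"] = self.compute_kl(trainer_state)  # illustrative helper
```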

---

## Flexible Stage Organization

The protocol defines stages 4-7 as separate methods, but backends are free to
redistribute work across them. The trainer always calls them in the same order --
it is the backend's responsibility to decide what each stage does internally.

### Example: TinkerBackend's fused mode

The `TinkerBackend` demonstrates this flexibility. When
`fuse_forward_backward_and_optim_step` is enabled:

| Method | Default (non-fused) | Fused |
|---|---|---|
| `transform_to_backend_batch` | returns `[]` placeholder | same |
| `process_backend_batch` | forward + backward | forward + backward + optim step |
| `compute_advantages` | stores `algorithm_config` | same |
| `update_policy` | optimizer step | no-op (already done) |

Both modes produce the same end result, but the fused path reduces round-trips
to the training server. The trainer does not need to know which path is active --
it simply calls all four methods in order.
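
Schematically, a backend can implement this by having stages 5 and 7 consult the same flag; this is a sketch of the pattern, not the actual TinkerBackend code, and the helper names are illustrative:

```python
async def process_backend_batch(self, trainer_state, **kwargs) -> None:
    self._run_forward_backward(trainer_state)   # illustrative helper
    if self.fuse_forward_backward_and_optim_step:
        self._run_optim_step()                  # fused: step now, saving a server round-trip

async def update_policy(self, trainer_state, **kwargs) -> None:
    if not self.fuse_forward_backward_and_optim_step:
        self._run_optim_step()                  # non-fused: the optimizer step happens here
```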

### Example: Pre-computed advantages (OPSD)

For On-Policy Self-Distillation, advantages are computed during episode generation
(stage 1) via a workflow decorator. By the time `compute_advantages` (stage 6) runs,
every `Step` already has its `.advantage` field set. The default implementation in
`BackendProtocol` detects this and skips re-computation, collecting only metrics.
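
The detection amounts to checking whether every step already carries an advantage before running the built-in estimators; schematically (a sketch of the idea, not the exact library code):

```python
def _advantages_precomputed(trajectory_groups) -> bool:
    # True if every Step.advantage was already populated, e.g. by an OPSD
    # workflow decorator during episode generation (stage 1).
    return all(
        step.advantage is not None
        for group in trajectory_groups
        for trajectory in group.trajectories
        for step in trajectory.steps
    )
```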

This means OPSD can work with the standard `TinkerBackend` -- no custom backend
subclass is needed.

---

## Implementing a Custom Backend

### Step 1: Subclass `BackendProtocol`

```python
from rllm.experimental.protocol import BackendProtocol

class MyBackend(BackendProtocol[MyDataLoader, MyBatch]):
    name = "my_backend"

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # ... initialize your infrastructure
```

### Step 2: Implement required methods

At minimum, you must implement these abstract methods:

```python
# Setup
def init_rollout_engine(self, **kwargs) -> RolloutEngine: ...
def validate_config(self) -> None: ...
def get_dataloader(self, dataset, trainer_state) -> MyDataLoader: ...
def shutdown(self) -> None: ...

# Pipeline
async def generate_episodes(self, batch, agent_workflow_engine, is_validation=False, **kwargs) -> list[Episode]: ...
def transform_to_backend_batch(self, trainer_state, **kwargs) -> MyBatch: ...
async def process_backend_batch(self, trainer_state, **kwargs) -> None: ...
async def compute_advantages(self, trainer_state, algorithm_config, **kwargs) -> None: ...
async def update_policy(self, trainer_state, **kwargs) -> None: ...
```

### Step 3: Override lifecycle hooks as needed

```python
async def on_train_start(self, trainer_state):
    # Load checkpoint, initialize model
    ...

async def on_batch_end(self, trainer_state):
    # Save checkpoint, update metrics
    ...
```

### Step 4: Wire it up

```python
trainer = UnifiedTrainer(
    backend_cls=MyBackend,
    config=config,
    workflow_class=MyWorkflow,
    train_dataset=train_ds,
    val_dataset=val_ds,
)
trainer.fit()
```

---

## Reference: TinkerBackend

The `TinkerBackend` (`rllm.experimental.tinker.tinker_backend`) is the primary
production backend. It serves as a comprehensive reference implementation. Below
is a summary of how it implements each part of the protocol.

### Setup

| Method | Implementation |
|---|---|
| `init_rollout_engine` | Creates a `TinkerPolicyTrainer` and a `TinkerEngine` (rollout engine backed by a Tinker sampling server) |
| `validate_config` | Warns if sampling temperature/top_p deviate from 1.0 |
| `get_dataloader` | Returns a `torch.utils.data.DataLoader` with backend-specific batch sizes |
| `shutdown` | Delegates to parent (no-op currently) |

### Pipeline

| Stage | Method | Implementation |
|---|---|---|
| 1 | `generate_episodes` | Builds an interleaved batch (`N` repeats per task for GRPO grouping), sets the sampling client on the rollout engine, and calls `agent_workflow_engine.execute_tasks(...)` |
| 4 | `transform_to_backend_batch` | Returns an empty list (placeholder). The actual datum construction is deferred to stage 5 |
| 5 | `process_backend_batch` | Converts trajectory groups to Tinker `Datum` objects, runs forward-backward, stores training logprobs. Optionally fuses the optimizer step |
| 6 | `compute_advantages` | Stores the `AlgorithmConfig` for use by stage 5's datum construction (advantage computation is embedded in the forward-backward call) |
| 7 | `update_policy` | Runs the optimizer step (or no-op if fused into stage 5) |

### Lifecycle hooks

| Hook | Implementation |
|---|---|
| `on_train_start` | Initializes the training client, loads checkpoint, sets `trainer_state.global_step` from the checkpoint's batch index |
| `on_train_end` | Saves final checkpoint if not already saved |
| `on_batch_end` | Saves sampler checkpoint, updates `self.sampling_client`, injects `optim/lr` and KL/entropy metrics into `trainer_state.metrics`, prints metrics table |
| `on_epoch_start/end` | Logging only |
| `on_validation_start/end` | Toggles `trainer_state.is_training` flag |

### Key patterns to note

1. **Deferred transformation.** `transform_to_backend_batch` returns a placeholder;
   the real work happens in `process_backend_batch`. This is valid because the trainer
   only checks `trainer_state.has_backend_batch` *after* `process_backend_batch` runs.

2. **Checkpoint-driven sampling client.** The `sampling_client` (used by workflows
   for inference) is updated in `on_batch_end` after each checkpoint save. This
   ensures workflows always sample from the latest policy.

3. **Metrics injection in `on_batch_end`.** Since `on_batch_end` runs after the
   pipeline but before `logger.log(...)`, it is the natural place to compute derived
   metrics (KL divergence, entropy, learning rate) and add them to
   `trainer_state.metrics`.
---

## Data Flow Summary

```
Dataset
  |
  v
get_dataloader() --> batch
  |
  v
generate_episodes(batch) --> list[Episode]
  |
  v
[framework] transform to TrajectoryGroups
  |
  v
[framework] rejection sampling & filtering
  |
  v
transform_to_backend_batch() --> TBatch (stored in trainer_state.backend_batch)
  |
  v
process_backend_batch() -- forward/backward, logprobs
  |
  v
compute_advantages() -- advantage computation
  |
  v
update_policy() -- optimizer step
  |
  v
[framework] visualization, metrics collection
  |
  v
on_batch_end() -- checkpoint, derived metrics
  |
  v
logger.log(metrics) -- wandb / tracking
```