
Commit 2f1e020

Merge branch 'main' into dev/fix_microbatch_loss_scale

2 parents: 25754ca + cd82bfc

127 files changed, +3175 −1380 lines changed


.github/workflows/unittest.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -97,6 +97,13 @@ jobs:
            fi
          fi
+      - name: Clean checkpoint dir
+        working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
+        if: always()
+        run: |
+          docker compose exec trinity-node-1 rm -rf /mnt/checkpoints/*
+        continue-on-error: true
       - name: Upload test results
         if: env.tests_run == 'true' || failure()
         uses: actions/upload-artifact@v4
```

benchmark/config/countdown-template.yaml

Lines changed: 6 additions & 7 deletions

```diff
@@ -35,19 +35,18 @@ buffer:
     rollout_args:
       temperature: 1.0
       logprobs: 0
+    default_workflow_type: math_workflow
+    default_reward_fn_type: countdown_reward
   eval_tasksets: []
-  default_workflow_type: math_workflow
-  default_reward_fn_type: countdown_reward
-  system_prompt: null
-  reply_prefix: null
   trainer_input:
     experience_buffer:
       name: experience_buffer
       storage_type: queue
-      use_priority_queue: true
-      replay_buffer_kwargs:
+      replay_buffer:
+        enable: true
         priority_fn: linear_decay
-        decay: 0.1
+        priority_fn_args:
+          decay: 0.1
 explorer:
   runner_per_model: 8
   max_timeout: 900
```

benchmark/config/gsm8k-template.yaml

Lines changed: 7 additions & 8 deletions

```diff
@@ -40,19 +40,18 @@ buffer:
     rollout_args:
       temperature: 1.0
       logprobs: 0
+    default_workflow_type: math_workflow
+    default_reward_fn_type: math_reward
   eval_tasksets: []
-  default_workflow_type: math_workflow
-  default_reward_fn_type: math_reward
-  system_prompt: null
-  reply_prefix: null
   trainer_input:
     experience_buffer:
       name: experience_buffer
       storage_type: queue
-      use_priority_queue: true
-      replay_buffer_kwargs:
+      replay_buffer:
+        enable: true
         priority_fn: linear_decay
-        decay: 0.1
+        priority_fn_args:
+          decay: 0.1
 explorer:
   runner_per_model: 8
   max_timeout: 900
@@ -79,7 +78,7 @@ trainer:
   enable_preview: true
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 10240
+  max_token_len_per_gpu: 10240
   ulysses_sequence_parallel_size: 1
 monitor:
   monitor_type: wandb
```

docs/sphinx_doc/source/index.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -21,6 +21,7 @@ Welcome to Trinity-RFT's documentation!
    tutorial/develop_algorithm.md
    tutorial/example_mix_algo.md
    tutorial/develop_operator.md
+   tutorial/develop_selector.md
    tutorial/trinity_configs.md
    tutorial/synchronizer.md
```

docs/sphinx_doc/source/tutorial/develop_operator.md

Lines changed: 5 additions & 4 deletions

````diff
@@ -1,13 +1,14 @@
+(Operators)=
 ## Operator Development Guide
 
 ### Step 0: Basic Concepts of Operator Module
 
 In Trinity-RFT, the operator module is responsible for processing experience data in the buffer module. It supports existing data processing capabilities from [Data-Juicer](https://github.com/modelscope/data-juicer) naturally, and allows developers to implement their own operators as well.
 By customizing operators, developers can implement various data processing functionalities, such as data augmentation, filtering, and transformation. You can even implement advantages/returns calculation as operators, as shown in {ref}`Algorithms <Algorithms>` section.
 
-- **DataJuicerOperator** ({class}`trinity.data.operators.DataJuicerOperator`): The operator that wraps the data processing operators from Data-Juicer. It provides a simple interface for developers to list the Data-Juicer operators they want to use. The full list of Data-Juicer operators can be found [here](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html).
-- **ExperienceOperator** ({class}`trinity.data.operators.ExperienceOperator`): The base class for all operators used in experience data processing. It defines the interface and common functionalities that all operators should have. Each operator processes a batch of experience data and returns the processed data with metrics for logging.
-- **ExperiencePipeline** ({class}`trinity.data.pipelines.ExperiencePipeline`): The experience data processing pipeline that manages a sequence of operators. It takes raw experiences from the `Explorer`, passes them through each operator in the pipeline, and writes the final processed experiences into the input buffer of the `Trainer`.
+- **DataJuicerOperator** ({class}`trinity.buffer.operators.DataJuicerOperator`): The operator that wraps the data processing operators from Data-Juicer. It provides a simple interface for developers to list the Data-Juicer operators they want to use. The full list of Data-Juicer operators can be found [here](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html).
+- **ExperienceOperator** ({class}`trinity.buffer.operators.ExperienceOperator`): The base class for all operators used in experience data processing. It defines the interface and common functionalities that all operators should have. Each operator processes a batch of experience data and returns the processed data with metrics for logging.
+- **ExperiencePipeline** ({class}`trinity.buffer.pipelines.ExperiencePipeline`): The experience data processing pipeline that manages a sequence of operators. It takes raw experiences from the `Explorer`, passes them through each operator in the pipeline, and writes the final processed experiences into the input buffer of the `Trainer`.
 
 ```{note}
 Except for `ExperiencePipeline`, Trinity-RFT also provides `TaskPipeline` for task data processing.
@@ -55,7 +56,7 @@ class RewardFilter(ExperienceOperator):
         return filtered_exps, metrics
 ```
 
-After implementation, you need to register this module through {class}`trinity.data.operators.EXPERIENCE_OPERATORS`. Once registered, the module can be configured in the configuration file using the registered name.
+After implementation, you need to register this module through {class}`trinity.buffer.operators.EXPERIENCE_OPERATORS`. Once registered, the module can be configured in the configuration file using the registered name.
 
 ### Step 2: Use Your Operator
 
````

docs/sphinx_doc/source/tutorial/develop_overview.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -11,6 +11,7 @@ The table below lists the main functions of each extension interface, its target
 | `Workflow` | Agent Application Developers | Enhance agent's ability to complete tasks in a specified environment | [🔗](./develop_workflow.md) |
 | `Algorithm` | RL Algorithm Researchers | Design new RL algorithms | [🔗](./develop_algorithm.md) |
 | `Operator` | Data Engineers | Design new data cleaning and augmentation strategies | [🔗](./develop_operator.md) |
+| `Selector` | Data Engineers | Design new task selection strategies | [🔗](./develop_selector.md) |
 
 ```{tip}
 Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
````
docs/sphinx_doc/source/tutorial/develop_selector.md

Lines changed: 263 additions & 0 deletions (new file)
# 🧪 Experimental: Task Selection & Scheduling System

```{note}
This module is currently in **experimental status**. Interfaces may change in future versions.
This document describes the functionality and intended usage of the system.
```

## Overview

This system enables **intelligent, adaptive task sampling** from multiple datasets (called *tasksets*) during exploration. It consists of two core components:

1. **`Selector`** – Controls how individual samples are selected *within* each taskset.
2. **`TasksetScheduler`** – Manages *which* tasksets contribute to each batch and coordinates their sampling.

Together, they support advanced training strategies such as:
- Curriculum learning (easy → hard)
- Multi-task interleaving or mixing
- Difficulty-aware sampling
- Adaptive data selection based on model performance

These capabilities allow you to train models more efficiently by focusing on informative or challenging examples.

## Module 1: Selector – Customizable Data Selection

A `Selector` determines **which tasks (samples) to select** from its associated dataset (`Taskset`). Beyond basic strategies like sequential or random access, it supports **adaptive algorithms** that adjust sampling based on feedback—such as sample difficulty, model confidence, or reward signals.

### Built-in Selectors

| Selector Type | Description |
|---------------|-------------|
| `sequential` | Returns samples in fixed order (0, 1, ..., N). |
| `shuffle` | Shuffles the dataset once per epoch, then iterates sequentially. |
| `random` | Randomly samples without replacement within each batch; independent across batches. |
| `offline_easy2hard` | Sorts samples by pre-defined features (e.g., loss, length), serving easier ones first and progressing to harder ones. |
| `difficulty_based` *(custom example)* | Dynamically selects samples near a target difficulty level using probabilistic modeling. |
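A built-in selector needs no code; it is chosen purely through configuration. A minimal sketch, using the `task_selector` block detailed in Step 3 below (the taskset name and path are placeholders):

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        task_selector:
          selector_type: shuffle  # any built-in name from the table above
```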
You can also **implement your own custom selector** to enable adaptive or curriculum-based learning.
### ✅ Step 1: Implement a Custom Selector

To create a new selector, inherit from `BaseSelector` and implement the following methods:

#### Required Methods

| Method | Purpose |
|--------|---------|
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. |
| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
| `state_dict() -> Dict` | Serialize current state for checkpointing. |
| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. |
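Before walking through the full adaptive example below, here is a minimal sketch of these four methods for a plain sequential selector. It is illustrative only (not one of the built-ins); reading `len(data_source.dataset)` is an assumption modeled on the `DifficultyBasedSelector` example that follows:

```python
from typing import Dict, List

# SELECTORS and BaseSelector are the Trinity-RFT registry and base class used
# throughout this guide (import paths omitted here).

@SELECTORS.register_module("my_sequential")
class MySequentialSelector(BaseSelector):
    def __init__(self, data_source, config) -> None:
        super().__init__(data_source, config)
        self.num_samples = len(data_source.dataset)  # assumed attribute, see example below
        self.current_index = 0  # cursor into the dataset

    def get_indices(self, batch_size: int, return_extra_info: bool = False):
        # Walk the dataset in order, wrapping around at the end.
        indices = [(self.current_index + i) % self.num_samples for i in range(batch_size)]
        self.current_index = (self.current_index + batch_size) % self.num_samples
        if return_extra_info:
            return indices, {"indices": indices}
        return indices

    def update(self, indices: List[int], values: List[float]) -> None:
        pass  # sequential selection uses no feedback

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```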
#### Example: `DifficultyBasedSelector`

This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.

```python
from typing import Dict, List

import torch

# SELECTORS, BaseSelector, TaskSelectorConfig, and get_logger are provided by
# Trinity-RFT (import paths omitted here).

@SELECTORS.register_module("difficulty_based")
class DifficultyBasedSelector(BaseSelector):
    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.logger = get_logger("difficulty_based_selector")

        # Build difficulty estimator using two input features (e.g., correctness, uncertainty)
        self.diff_estimator = self.build_diff_estimator(
            data_source.dataset, config.feature_keys, config.kwargs
        )
        self.current_index = 0
        self.seed = config.seed

        # Configuration parameters
        self.do_sample = config.kwargs.get("do_sample", False)
        self.target_reward = config.kwargs.get("target_reward", 1.0)
        self.tau = config.kwargs.get("tau", 1.0)

        # ... detailed implementation

    def get_indices(self, batch_size, return_extra_info=False):
        # Compute scores based on proximity to target reward
        sampling_scores = self.get_scores()
        sampling_scores = torch.from_numpy(sampling_scores)

        if self.tau == 0:
            # Greedy: take top-k highest scoring samples
            selected_indices = torch.topk(sampling_scores, batch_size).indices
        else:
            # Stochastic: sample via softmax with temperature scaling
            sampling_logits = sampling_scores / self.tau
            sampling_logits -= sampling_logits.max()  # numerical stability
            sampling_probabilities = torch.softmax(sampling_logits, dim=0)
            rng = torch.Generator().manual_seed(self.seed + self.current_index)
            selected_indices = torch.multinomial(
                sampling_probabilities,
                batch_size,
                replacement=False,
                generator=rng,
            )

        self.current_index += batch_size

        if return_extra_info:
            # Optional debugging info
            extra_info = {
                "indices": selected_indices.tolist(),
                "scores": sampling_scores[selected_indices].tolist(),
                # ... other metadata
            }
            return selected_indices.tolist(), extra_info
        else:
            # Convert to a plain list to match the declared List[int] interface
            return selected_indices.tolist()

    def update(self, indices: List[int], values: List[float]) -> None:
        # Update difficulty model with observed rewards
        self.diff_estimator.update(indices, values)

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```
> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.

### ✅ Step 2: Implement a Feedback Operator

For adaptive selectors like `DifficultyBasedSelector`, you need to provide runtime feedback (e.g., task rewards). This is done via an **Experience Operator** that processes rollouts and computes metrics.

> 📚 See the {ref}`Operator Development Guide<Operators>` for more on building custom experience processors.

The operator must output a metric under the key `trinity.common.constants.SELECTOR_METRIC`, structured as:

```python
{
    SELECTOR_METRIC: {
        0: {  # taskset_id
            "indices": [10, 25, 43],
            "values": [0.8, 0.6, 0.9]  # e.g., average reward
        },
        1: { ... }
    }
}
```
#### Example: Pass Rate Calculator

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from trinity.common.constants import SELECTOR_METRIC

# EXPERIENCE_OPERATORS and ExperienceOperator come from trinity.buffer.operators;
# Experience is Trinity-RFT's experience data class (import path omitted here).

@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
    def __init__(self, **kwargs):
        pass

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        # Group rewards by (taskset_id, task index).
        raw_metric = defaultdict(lambda: defaultdict(list))

        for exp in exps:
            task_index = exp.info["task_index"]
            assert "taskset_id" in task_index and "index" in task_index
            raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward)

        # Average the rewards per task, keyed by taskset.
        metric = {}
        for taskset_id, task_metrics in raw_metric.items():
            indices = []
            reward_means = []
            for idx, rewards in task_metrics.items():
                indices.append(idx)
                reward_means.append(float(np.mean(rewards)))
            metric[taskset_id] = {
                "indices": indices,
                "values": reward_means,
            }

        return exps, {SELECTOR_METRIC: metric}
```
This operator calculates the average reward per task and passes it back to the corresponding selector for updating difficulty estimates.
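Conceptually, the framework then routes each entry of this metric back to the selector of the matching taskset. The dispatch looks roughly like the following (an illustrative sketch, not actual framework code; `selectors` is a hypothetical mapping from taskset id to `Selector`):

```python
# Hypothetical glue: feed operator output back into each taskset's selector.
for taskset_id, feedback in metrics[SELECTOR_METRIC].items():
    selectors[taskset_id].update(feedback["indices"], feedback["values"])
```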
### ✅ Step 3: Update Configuration

After implementing your selector and operator, register them in the config file.

#### Add the Operator to the Pipeline

```yaml
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator  # Must match @register_module name
```

#### Configure the Taskset with Your Selector

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        storage_type: file
        path: ./path/to/tasks
        task_selector:
          selector_type: difficulty_based  # Matches @register_module name
          feature_keys: ["correct", "uncertainty"]
          kwargs:
            m: 16
            lamb: 0.2
            rho: 0.2
            target_reward: 0.9
            tau: 0.5
            do_sample: true
```
> 💡 You can define multiple tasksets, each with its own selector type and configuration.

## Module 2: TasksetScheduler – Multi-Taskset Orchestration

The `TasksetScheduler` manages **how different tasksets are interleaved or mixed** during training.

### Key Features

- Supports **multiple tasksets** simultaneously.
- Balances sampling proportionally to dataset sizes.
- **Shuffles taskset access order** at the start of each epoch.
- Enables **curriculum-style** or **interleaved multi-task training**.
- Fully **checkpointable**: resumes exactly where it left off.
- Integrates with any registered `Selector`.

### How It Works

At each training step, the scheduler:
1. Determines which tasksets should contribute to the current batch.
2. Queries each taskset's selector to get specific sample indices.
3. Reads the actual data asynchronously.
4. Tags each task with `"taskset_id"` for downstream routing or analysis.

Epochs are defined based on total data volume and batch size:
```python
steps_per_epoch = total_samples // batch_size
```

At the beginning of each epoch, the scheduler reshuffles the sequence of taskset accesses to introduce variability.
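Putting the pieces together, one epoch of this loop can be sketched as below. This is illustrative only: the `tasksets`/`selectors` structures and the size-proportional access order are assumptions based on the behavior described above, not the actual `TasksetScheduler` implementation (which also reads data asynchronously and supports checkpointing):

```python
import random

def run_epoch(tasksets, selectors, batch_size, seed=0):
    """Illustrative sketch: interleave multiple tasksets for one epoch."""
    total_samples = sum(len(ts) for ts in tasksets)
    steps_per_epoch = total_samples // batch_size

    # Access tasksets proportionally to their sizes, in shuffled order.
    order = []
    for taskset_id, ts in enumerate(tasksets):
        order += [taskset_id] * round(steps_per_epoch * len(ts) / total_samples)
    random.Random(seed).shuffle(order)

    for taskset_id in order:
        # Ask this taskset's selector which samples to read next.
        indices = selectors[taskset_id].get_indices(batch_size)
        batch = [tasksets[taskset_id][i] for i in indices]
        for task in batch:
            task["taskset_id"] = taskset_id  # tag for downstream routing
        yield batch
```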
## Summary

With these components, you can:
- Use simple strategies like random or sequential sampling.
- Design **adaptive curricula** using custom selectors.
- Combine multiple datasets intelligently.
- Optimize training efficiency by focusing on high-value samples.

By combining smart `Selectors` with the flexible `TasksetScheduler`, you gain fine-grained control over what your model sees—and when.
