Commit ac5f8f4
Add taskset scheduler (agentscope-ai#326)
1 parent 9f1719e commit ac5f8f4

File tree

31 files changed: +1758 -76 lines

docs/sphinx_doc/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ Welcome to Trinity-RFT's documentation!
   tutorial/develop_algorithm.md
   tutorial/example_mix_algo.md
   tutorial/develop_operator.md
+  tutorial/develop_selector.md
   tutorial/trinity_configs.md
   tutorial/synchronizer.md

docs/sphinx_doc/source/tutorial/develop_operator.md

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+(Operators)=
 ## Operator Development Guide

 ### Step 0: Basic Concepts of Operator Module

docs/sphinx_doc/source/tutorial/develop_overview.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ The table below lists the main functions of each extension interface, its target
 | `Workflow` | Agent Application Developers | Enhance agent's ability to complete tasks in a specified environment | [🔗](./develop_workflow.md) |
 | `Algorithm` | RL Algorithm Researchers | Design new RL algorithms | [🔗](./develop_algorithm.md) |
 | `Operator` | Data Engineers | Design new data cleaning and augmentation strategies | [🔗](./develop_operator.md) |
+| `Selector` | Data Engineers | Design new task selection strategies | [🔗](./develop_selector.md) |

 ```{tip}
 Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
docs/sphinx_doc/source/tutorial/develop_selector.md

Lines changed: 264 additions & 0 deletions

@@ -0,0 +1,264 @@
# 🧪 Experimental: Task Selection & Scheduling System

```{note}
This module is currently in **experimental status**. Interfaces may change in future versions.
This document describes the functionality and intended usage of the system.
```

## Overview

This system enables **intelligent, adaptive task sampling** from multiple datasets (called *tasksets*) during exploration. It consists of two core components:

1. **`Selector`** – Controls how individual samples are selected *within* each taskset.
2. **`TasksetScheduler`** – Manages *which* tasksets contribute to each batch and coordinates their sampling.

Together, they support advanced training strategies such as:
- Curriculum learning (easy → hard)
- Multi-task interleaving or mixing
- Difficulty-aware sampling
- Adaptive data selection based on model performance

These capabilities allow you to train models more efficiently by focusing on informative or challenging examples.

## Module 1: Selector – Customizable Data Selection

A `Selector` determines **which tasks (samples) to select** from its associated dataset (`Taskset`). Beyond basic strategies like sequential or random access, it supports **adaptive algorithms** that adjust sampling based on feedback, such as sample difficulty, model confidence, or reward signals.

### Built-in Selectors

| Selector Type | Description |
|---------------|-------------|
| `sequential` | Returns samples in fixed order (0, 1, ..., N−1). |
| `shuffle` | Shuffles the dataset once per epoch, then iterates sequentially. |
| `random` | Randomly samples without replacement within each batch; independent across batches. |
| `offline_easy2hard` | Sorts samples by pre-defined features (e.g., loss, length), serving easier ones first and progressing to harder ones. |
| `difficulty_based` *(custom example)* | Dynamically selects samples near a target difficulty level using probabilistic modeling. |

You can also **implement your own custom selector** to enable adaptive or curriculum-based learning.

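Before writing a custom selector, note that any built-in selector from the table above can be enabled by name alone. The following is a minimal sketch, assuming the same `buffer` config layout shown later in Step 3:

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        storage_type: file
        path: ./path/to/tasks
        task_selector:
          selector_type: shuffle  # any built-in selector name from the table above
```
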
### ✅ Step 1: Implement a Custom Selector

To create a new selector, inherit from `BaseSelector` and implement the following methods:

#### Required Methods

| Method | Purpose |
|--------|---------|
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. |
| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
| `state_dict() -> Dict` | Serialize current state for checkpointing. |
| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. |

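To make this contract concrete, here is a minimal sketch of a non-adaptive selector. It is not part of Trinity-RFT: the class name is hypothetical, and the constructor signature, registry decorator, and `data_source.dataset` attribute are assumed to match the `DifficultyBasedSelector` example below.

```python
from typing import Dict, List

# `SELECTORS`, `BaseSelector`, and `TaskSelectorConfig` are Trinity-RFT
# symbols; their import paths are omitted in this sketch.

@SELECTORS.register_module("round_robin_example")  # hypothetical name
class RoundRobinSelector(BaseSelector):
    """Minimal sketch: cycles through the dataset in fixed order."""

    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.size = len(data_source.dataset)  # assumes the dataset has __len__
        self.current_index = 0

    def get_indices(self, batch_size: int, return_extra_info: bool = False):
        indices = [(self.current_index + i) % self.size for i in range(batch_size)]
        self.current_index = (self.current_index + batch_size) % self.size
        if return_extra_info:
            return indices, {"indices": indices}
        return indices

    def update(self, indices: List[int], values: List[float]) -> None:
        pass  # a non-adaptive selector ignores feedback

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```
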
#### Example: `DifficultyBasedSelector`

This selector focuses on samples whose predicted performance is closest to a target (e.g., a 90% success rate), effectively choosing tasks of "just right" difficulty.

```python
from typing import Dict, List

import torch

# `SELECTORS`, `BaseSelector`, `TaskSelectorConfig`, and `get_logger` are
# Trinity-RFT symbols; their import paths are omitted in this excerpt.

@SELECTORS.register_module("difficulty_based")
class DifficultyBasedSelector(BaseSelector):
    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.logger = get_logger("difficulty_based_selector")

        # Build a difficulty estimator using two input features
        # (e.g., correctness, uncertainty)
        self.diff_estimator = self.build_diff_estimator(
            data_source.dataset, config.feature_keys, config.kwargs
        )
        self.current_index = 0
        self.seed = config.seed

        # Configuration parameters
        self.do_sample = config.kwargs.get("do_sample", False)
        self.target_reward = config.kwargs.get("target_reward", 1.0)
        self.tau = config.kwargs.get("tau", 1.0)

        # ... detailed implementation

    def get_indices(self, batch_size, return_extra_info=False):
        # Compute scores based on proximity to the target reward
        sampling_scores = self.get_scores()
        sampling_scores = torch.from_numpy(sampling_scores)

        if self.tau == 0:
            # Greedy: take the top-k highest-scoring samples
            selected_indices = torch.topk(sampling_scores, batch_size).indices
        else:
            # Stochastic: sample via softmax with temperature scaling
            sampling_logits = sampling_scores / self.tau
            sampling_logits -= sampling_logits.max()  # numerical stability
            sampling_probabilities = torch.softmax(sampling_logits, dim=0)
            rng = torch.Generator().manual_seed(self.seed + self.current_index)
            selected_indices = torch.multinomial(
                sampling_probabilities,
                batch_size,
                replacement=False,
                generator=rng,
            )

        self.current_index += batch_size

        if return_extra_info:
            # Optional debugging info
            extra_info = {
                "indices": selected_indices.tolist(),
                "scores": sampling_scores[selected_indices].tolist(),
                # ... other metadata
            }
            return selected_indices.tolist(), extra_info
        else:
            # Convert to plain ints to match the List[int] contract
            return selected_indices.tolist()

    def update(self, indices: List[int], values: List[float]) -> None:
        # Update the difficulty model with observed rewards
        self.diff_estimator.update(indices, values)

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```

> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.

### ✅ Step 2: Implement a Feedback Operator

For adaptive selectors like `DifficultyBasedSelector`, you need to provide runtime feedback (e.g., task rewards). This is done via an **Experience Operator** that processes rollouts and computes metrics.

> 📚 See the {ref}`Operator Development Guide<Operators>` for more on building custom experience processors.

The operator must output a metric under the key `trinity.common.constants.SELECTOR_METRIC`, structured as:

```python
{
    SELECTOR_METRIC: {
        0: {  # taskset_id
            "indices": [10, 25, 43],
            "values": [0.8, 0.6, 0.9],  # e.g., average reward per task
        },
        1: { ... },
    }
}
```

#### Example: Pass Rate Calculator

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from trinity.common.constants import SELECTOR_METRIC
# `EXPERIENCE_OPERATORS`, `ExperienceOperator`, and `Experience` are
# Trinity-RFT symbols; their import paths are omitted in this excerpt.

@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
    def __init__(self, **kwargs):
        pass

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        # Group observed rewards by (taskset_id, task index)
        raw_metric = defaultdict(lambda: defaultdict(list))

        for exp in exps:
            task_index = exp.info["task_index"]
            assert "taskset_id" in task_index and "index" in task_index
            raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward)

        # Average the rewards of each task and emit them in the
        # SELECTOR_METRIC structure shown above
        metric = {}
        for taskset_id, task_metrics in raw_metric.items():
            indices = []
            reward_means = []
            for idx, rewards in task_metrics.items():
                indices.append(idx)
                reward_means.append(float(np.mean(rewards)))
            metric[taskset_id] = {
                "indices": indices,
                "values": reward_means,
            }

        return exps, {SELECTOR_METRIC: metric}
```

This operator calculates the average reward per task and passes it back to the corresponding selector for updating difficulty estimates.

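The routing of this metric back to the selectors is handled by the framework; conceptually it amounts to something like the following sketch, where `selectors` (a hypothetical mapping from taskset id to `Selector`) and `metrics` (the dict returned by the operator's `process()` call) are assumptions for illustration:

```python
# Hypothetical dispatch: forward the operator's SELECTOR_METRIC feedback to
# each taskset's selector so it can refresh its difficulty estimates.
for taskset_id, feedback in metrics[SELECTOR_METRIC].items():
    selectors[taskset_id].update(feedback["indices"], feedback["values"])
```
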
### ✅ Step 3: Update Configuration

After implementing your selector and operator, register them in the config file.

#### Add the Operator to the Pipeline

```yaml
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator  # Must match the @register_module name
```

#### Configure the Taskset with Your Selector

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        storage_type: file
        path: ./path/to/tasks
        task_selector:
          selector_type: difficulty_based  # Matches the @register_module name
          feature_keys: ["correct", "uncertainty"]
          kwargs:
            m: 16
            lamb: 0.2
            rho: 0.2
            target_reward: 0.9
            tau: 0.5
            do_sample: true
```

> 💡 You can define multiple tasksets, each with its own selector type and configuration, as sketched below.

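For example, a hypothetical configuration mixing a curriculum-ordered taskset with a randomly sampled one might look like the following; the names and paths are illustrative, and selector-specific options (such as `feature_keys` for `offline_easy2hard`) are omitted:

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: math_tasks        # illustrative
        storage_type: file
        path: ./data/math
        task_selector:
          selector_type: offline_easy2hard
      - name: coding_tasks      # illustrative
        storage_type: file
        path: ./data/coding
        task_selector:
          selector_type: random
```
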
## Module 2: TasksetScheduler – Multi-Taskset Orchestration

The `TasksetScheduler` manages **how different tasksets are interleaved or mixed** during training.

### Key Features

- Supports **multiple tasksets** simultaneously.
- Balances sampling proportionally to dataset sizes.
- **Shuffles taskset access order** at the start of each epoch.
- Enables **curriculum-style** or **interleaved multi-task training**.
- Fully **checkpointable**: resumes exactly where it left off.
- Integrates with any registered `Selector`.

### How It Works

At each training step, the scheduler:
1. Determines which tasksets should contribute to the current batch.
2. Queries each taskset's selector to get specific sample indices.
3. Reads the actual data asynchronously.
4. Tags each task with `"taskset_id"` for downstream routing or analysis.

Epochs are defined based on total data volume and batch size:
```python
steps_per_epoch = total_samples // batch_size
```

At the beginning of each epoch, the scheduler reshuffles the sequence of taskset accesses to introduce variability.

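The scheduler itself ships with the framework, but its behavior can be illustrated with a short sketch. The following is a simplified, hypothetical model of proportional interleaving with per-epoch shuffling; `selectors` (a mapping from taskset id to `Selector`) is assumed, and the real implementation additionally reads data asynchronously and handles checkpointing:

```python
import random
from collections import Counter
from typing import List

def build_epoch_schedule(taskset_sizes: List[int], seed: int) -> List[int]:
    """One slot per sample, so each taskset contributes proportionally to its
    size; the slot order is reshuffled at the start of every epoch."""
    slots = [tid for tid, size in enumerate(taskset_sizes) for _ in range(size)]
    random.Random(seed).shuffle(slots)
    return slots

batch_size = 8
schedule = build_epoch_schedule([600, 200], seed=42)  # two hypothetical tasksets
steps_per_epoch = len(schedule) // batch_size  # i.e., total_samples // batch_size

for step in range(steps_per_epoch):
    batch_slots = schedule[step * batch_size : (step + 1) * batch_size]
    contributions = Counter(batch_slots)  # taskset_id -> sample count in this batch
    # Each taskset's selector is then queried for that many indices, e.g.:
    #   for tid, count in contributions.items():
    #       indices = selectors[tid].get_indices(count)
    # and every fetched task is tagged with its "taskset_id".
```
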
## Summary

With these components, you can:
- Use simple strategies like random or sequential sampling.
- Design **adaptive curricula** using custom selectors.
- Combine multiple datasets intelligently.
- Optimize training efficiency by focusing on high-value samples.

By combining smart `Selector`s with the flexible `TasksetScheduler`, you gain fine-grained control over what your model sees, and when.

docs/sphinx_doc/source_zh/tutorial/develop_overview.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ Trinity-RFT splits the RL training process into three modules: **Explorer**, **Train
 | `Workflow` | Agent application developers | Improve the agent's ability to complete tasks in a specified environment | [🔗](./develop_workflow.md) |
 | `Algorithm` | RL algorithm researchers | Design new RL algorithms | [🔗](./develop_algorithm.md) |
 | `Operator` | Data engineers | Design new data cleaning and augmentation strategies | [🔗](./develop_operator.md) |
+| `Selector` | Data engineers | Design new data selection strategies | [🔗](./develop_selector.md) |

 ```{tip}
 Trinity-RFT provides a plugin-style development approach that lets you flexibly add custom modules without modifying the framework code.
