Skip to content

Commit 0afd252

Browse files
committed
Merge branch 'main' into dev/fix_microbatch_loss_scale
2 parents 64bc002 + 9efaec8 commit 0afd252

File tree

24 files changed

+414
-81
lines changed

24 files changed

+414
-81
lines changed
184 KB
Loading
184 KB
Loading
138 KB
Loading
485 KB
Loading

docs/sphinx_doc/source_zh/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
tutorial/develop_algorithm.md
2121
tutorial/example_mix_algo.md
2222
tutorial/develop_operator.md
23+
tutorial/develop_selector.md
2324
tutorial/trinity_configs.md
2425
tutorial/synchronizer.md
2526

examples/grpo_vlm/vlm.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ buffer:
2121
taskset:
2222
name: geometry3k
2323
storage_type: file
24-
path: hiyouga/geometry3k
24+
path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
2525
subset_name: 'default'
2626
split: 'train'
2727
format:

examples/mix_chord/mix_chord.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ buffer:
6262
name: SFT_data
6363
storage_type: file
6464
schema_type: sft
65-
path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
65+
path: ${oc.env:TRINITY_SFT_DATASET_PATH}
6666
split: 'train'
6767
format:
6868
prompt_type: messages

examples/mix_math/mix_math.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ buffer:
6161
name: math_sft
6262
storage_type: file
6363
schema_type: sft
64-
path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
64+
path: ${oc.env:TRINITY_SFT_DATASET_PATH}
6565
split: 'train'
6666
format:
6767
prompt_type: messages

examples/mix_vlm/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# MIX algorithm with VLM
2+
3+
This is an example of using the [MIX](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) algorithm with the Qwen2.5-VL-3B-Instruct model.
4+
5+
> [!NOTE]
6+
> This feature is experimental and will be subject to change in future releases.
7+
8+
The specific requirements are:
9+
10+
```yaml
11+
vllm>=0.9.1,<0.10.0
12+
transformers<4.53.0
13+
qwen_vl_utils
14+
```
15+
16+
## Prepare the SFT Dataset
17+
We use the [geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) dataset for training; we generate the [SFT dataset](https://huggingface.co/datasets/datajuicer/geometry_sft) by prompting the Qwen2.5-VL-32B-Instruct model on the validation set. Note that this dataset only showcases the format of SFT data in this example, as shown below:
18+
```json
19+
{
20+
"problem": "<image>Find $x$ so that $m || n$.",
21+
"response": "To determine the value of $ x $ ... Answer:\n\\[\n\\boxed{63}\n\\]",
22+
"images": [<image>]
23+
}
24+
```
25+
26+
The config file is located in [`mix_vlm.yaml`](mix_vlm.yaml). To get better performance, feel free to try out different algorithm hyperparameters!
27+
28+
## Run the Example
29+
30+
Run the following command to start the training:
31+
```bash
32+
trinity run --config examples/mix_vlm/mix_vlm.yaml
33+
```
34+
35+
The reward curve is shown below:
36+
![](../../docs/sphinx_doc/assets/mix_vlm_reward.png)

examples/mix_vlm/mix_vlm.yaml

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
project: "Trinity-RFT"
2+
name: "mix_vlm"
3+
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
4+
algorithm:
5+
algorithm_type: mix_chord
6+
repeat_times: 8
7+
optimizer:
8+
lr: 1e-6
9+
kl_loss_fn_args:
10+
kl_coef: 0.0
11+
entropy_loss_fn: mix
12+
sample_strategy_args:
13+
expert_data_ratio: 0.20
14+
policy_loss_fn_args:
15+
mu_warmup_steps: 200
16+
mu_decay_steps: 400
17+
mu_peak: 0.1
18+
mu_valley: 0.1
19+
enable_phi_function: false
20+
clip_range: 0.2
21+
sft_loss_agg_mode: "token-mean"
22+
use_dynamic_bsz: true
23+
ppo_mini_batch_size: 320 # 320 = 256 + 64
24+
ppo_micro_batch_size_per_gpu: 4
25+
ngpus_trainer: 4
26+
train_batch_size_expert: 64
27+
train_batch_size_usual: 256 # 32 batchsize * 8 repeat times
28+
model:
29+
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-VL-3B-Instruct}
30+
max_response_tokens: 10240
31+
max_model_len: 11264
32+
cluster:
33+
node_num: 1
34+
gpu_per_node: 8
35+
buffer:
36+
total_epochs: 4
37+
batch_size: 32
38+
train_batch_size: 320
39+
explorer_input:
40+
taskset:
41+
name: geometry3k
42+
storage_type: file
43+
path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
44+
subset_name: 'default'
45+
split: 'train'
46+
format:
47+
prompt_key: 'problem'
48+
response_key: 'answer'
49+
image_key: 'images'
50+
rollout_args:
51+
temperature: 1.0
52+
logprobs: 0
53+
workflow_args:
54+
with_think: true
55+
eval_tasksets: [] # you can add your own eval tasksets here
56+
default_workflow_type: 'simple_mm_workflow'
57+
default_reward_fn_type: 'math_boxed_reward'
58+
trainer_input:
59+
experience_buffer:
60+
name: experience_buffer
61+
storage_type: queue
62+
auxiliary_buffers:
63+
sft_dataset:
64+
total_epochs: 25
65+
name: geometry_sft
66+
storage_type: file
67+
schema_type: sft
68+
path: datajuicer/geometry_sft
69+
split: 'train'
70+
format:
71+
prompt_type: plaintext
72+
prompt_key: 'problem'
73+
response_key: 'response'
74+
image_key: 'images'
75+
explorer:
76+
eval_interval: 10
77+
runner_per_model: 8
78+
rollout_model:
79+
engine_num: 4
80+
tensor_parallel_size: 1
81+
enable_prefix_caching: false
82+
enforce_eager: true
83+
dtype: bfloat16
84+
seed: 42
85+
synchronizer:
86+
sync_method: 'nccl'
87+
sync_interval: 1
88+
sync_timeout: 1200
89+
trainer:
90+
save_interval: 50
91+
grad_clip: 1.0
92+
use_dynamic_bsz: true
93+
max_token_len_per_gpu: 11264
94+
ulysses_sequence_parallel_size: 2

0 commit comments

Comments
 (0)