forked from agentscope-ai/Trinity-RFT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
156 lines (129 loc) · 5.19 KB
/
trainer.py
File metadata and controls
156 lines (129 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
"""
Trainer Class
This file is modified from verl.trainer.main_ppo.py
And is a reproduction code of Jiayi-Pan/TinyZero.
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
from abc import ABC, abstractmethod
from typing import Tuple
import ray
from trinity.buffer import get_buffer_reader
from trinity.common.config import Config, TrainerConfig
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experiences
from trinity.utils.log import get_logger
@ray.remote(name="trainer")
class Trainer:
    """Consume the experience and train the model.

    Runs as a named Ray actor. Reads experiences from the configured
    buffers and delegates the actual optimization to a
    ``TrainEngineWrapper`` obtained via ``get_trainer_wrapper``.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.logger = get_logger(__name__)
        # Reader for the main RFT/DPO experience stream.
        self.train_buffer = get_buffer_reader(
            self.config.buffer.train_dataset,  # type: ignore
            self.config.buffer,
        )
        # Optional SFT warmup stream; only constructed when warmup is enabled,
        # otherwise left as None.
        self.sft_warmup_buffer = (
            get_buffer_reader(
                self.config.buffer.sft_warmup_dataset,  # type: ignore
                self.config.buffer,
            )
            if self.config.trainer.sft_warmup_iteration > 0
            else None
        )
        self.engine = get_trainer_wrapper(config.trainer)

    def prepare(self) -> None:
        """Prepare the trainer (delegated to the underlying engine)."""
        self.engine.prepare()

    def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
        """Train the model until the engine signals completion."""
        while True:
            train_status, _ = self.train_iteration(algo_type)
            if not train_status:
                self.logger.info("Trainer finished.")
                break

    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
        """Train one step. Each step contains `sync_iteration_interval` iterations.

        Returns:
            train_status: Whether to continue training.
            train_iter_num: The number of training iterations.
        """
        # Initialize so a zero-length interval returns (True, 0) instead of
        # raising NameError on the unbound loop variable.
        train_iter_num = 0
        for _ in range(self.config.synchronizer.sync_iteration_interval):
            train_status, train_iter_num = self.train_iteration(algo_type)
            if not train_status:
                # Log once, at the moment training actually stops — the
                # original logged "finished" on every successful step.
                self.logger.info("Trainer finished.")
                return False, train_iter_num
        return True, train_iter_num

    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
        """Train one iteration.

        Args:
            algo_type (AlgorithmType): The type of data to be used for training.
                Defaults to AlgorithmType.PPO.

        Returns:
            Tuple[bool, int]: Whether to continue training, and the current
                iteration number (as reported by the engine).

        Raises:
            ValueError: If SFT is requested without a warmup buffer, or the
                algorithm type is unsupported.
        """
        self.engine.set_mode(algo_type)
        pad_token_id = self.config.buffer.pad_token_id  # type: ignore
        if algo_type.is_sft():
            # Fail loudly instead of AttributeError'ing on a None buffer.
            if self.sft_warmup_buffer is None:
                raise ValueError(
                    "SFT training requested but no SFT warmup buffer is configured "
                    "(requires trainer.sft_warmup_iteration > 0)."
                )
            exps = self.sft_warmup_buffer.read()
            return self.engine.train_sft_iteration(
                Experiences.gather_experiences(exps, pad_token_id=pad_token_id)
            )
        exps = self.train_buffer.read()
        if algo_type.is_rft():
            return self.engine.train_rft_iteration(
                Experiences.gather_experiences(exps, pad_token_id=pad_token_id)
            )
        if algo_type.is_dpo():
            # DPO pairs require their own gather that keeps chosen/rejected
            # experiences aligned.
            return self.engine.train_dpo_iteration(
                Experiences.gather_dpo_experiences(exps, pad_token_id=pad_token_id)
            )
        raise ValueError(f"Unsupported algorithm type: {algo_type}")

    def sync_weight(self) -> None:
        """Sync the model weight (only when the sync method is "online")."""
        if self.config.synchronizer.sync_method == "online":
            self.engine.sync_weight()

    def log_finalize(self, step: int) -> None:
        """Commit the logging results to wandb."""
        self.engine.logger.log({"dummy_log_trainer": step}, step=step, commit=True)
class TrainEngineWrapper(ABC):
    """A wrapper class to wrap various training engines.

    Concrete subclasses adapt a specific training backend (e.g. the
    ``verl`` backend selected in ``get_trainer_wrapper``) to the
    interface consumed by ``Trainer``. The ``train_*_iteration``
    methods are expected to return ``(continue_training, iteration)``,
    mirroring ``Trainer.train_iteration`` — confirm against the
    concrete implementation.
    """

    @abstractmethod
    def prepare(self) -> None:
        """Do some preparation before training started."""

    @abstractmethod
    def train_rft_iteration(self, experiences) -> Tuple[bool, int]:
        """Train one iteration on RFT data.

        Args:
            experiences: a gathered batch (see ``Experiences.gather_experiences``).
        """

    @abstractmethod
    def train_sft_iteration(self, experiences) -> Tuple[bool, int]:
        """Train one iteration on SFT (warmup) data."""

    @abstractmethod
    def train_dpo_iteration(self, experiences) -> Tuple[bool, int]:
        """Train one iteration on DPO data (paired experiences)."""

    @abstractmethod
    def save_checkpoint(self) -> None:
        """Save the checkpoint."""

    @abstractmethod
    def sync_weight(self) -> None:
        """Sync the model weight."""

    @abstractmethod
    def set_mode(self, algo_type: AlgorithmType) -> None:
        """Set training mode for the given algorithm type."""

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the engine."""
def get_trainer_wrapper(config: TrainerConfig) -> TrainEngineWrapper:
    """Get a trainer wrapper for the configured backend.

    Args:
        config (TrainerConfig): Trainer configuration; ``trainer_type``
            selects the backend implementation.

    Returns:
        TrainEngineWrapper: The engine wrapper for ``config.trainer_type``.

    Raises:
        NotImplementedError: If ``config.trainer_type`` is not a supported
            backend (only ``"verl"`` is currently implemented).
    """
    if config.trainer_type == "verl":
        # Imported lazily so unsupported backends never pull in verl deps.
        from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper

        return VerlPPOTrainerWrapper(config)
    # Name the offending value — a bare NotImplementedError gives the
    # operator no hint about which config field is wrong.
    raise NotImplementedError(f"Unsupported trainer type: {config.trainer_type!r}")