
Commit 0d5bad2

Split adv fn into separate files, and other update (TODO: update yaml configs and config manager)
1 parent 3742f06 commit 0d5bad2

File tree

13 files changed (+313, -172 lines)

tests/template/config.yaml

Lines changed: 5 additions & 0 deletions

@@ -8,6 +8,11 @@ algorithm:
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
+  advantage_fn_type: ppo_adv_fn
+  advantage_fn_args:
+    gamma: 1.0
+    lam: 1.0
+
 model:
   model_path: ''
   max_prompt_tokens: 2048
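
For context, here is a minimal sketch of how these new fields might be consumed once the config manager is updated (that wiring is marked TODO in the commit message). The build_advantage_fn helper below is a hypothetical illustration, not code from this commit; it constructs PPOAdvantageFn directly rather than resolving advantage_fn_type through whatever lookup the config manager will actually use.

import yaml

from trinity.algorithm.advantage_fn import PPOAdvantageFn


def build_advantage_fn(config_path: str) -> PPOAdvantageFn:
    # Read the algorithm section of the template config.
    with open(config_path) as f:
        algo_cfg = yaml.safe_load(f)["algorithm"]
    # advantage_fn_type is "ppo_adv_fn" in the template; a real config manager
    # would likely resolve it via the ADVANTAGE_FN registry instead of
    # hard-coding the class as done here.
    fn_args = algo_cfg.get("advantage_fn_args") or {}
    return PPOAdvantageFn(**fn_args)


advantage_fn = build_advantage_fn("tests/template/config.yaml")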

trinity/algorithm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 
 __all__ = [

trinity/algorithm/advantage_fn/__init__.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn
+from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn
+from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
+from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
+    REINFORCEPLUSPLUSAdvantageFn,
+)
+from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
+from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn
+
+__all__ = [
+    "ADVANTAGE_FN",
+    "AdvantageFn",
+    "PPOAdvantageFn",
+    "GRPOAdvantageFn",
+    "REINFORCEPLUSPLUSAdvantageFn",
+    "REMAXAdvantageFn",
+    "RLOOAdvantageFn",
+    "OPMDAdvantageFn",
+]

trinity/algorithm/advantage_fn/advantage_fn.py

Lines changed: 8 additions & 159 deletions

@@ -1,9 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple
 
-from verl import DataProto
-
-from trinity.trainer.verl import core_algos
 from trinity.utils.registry import Registry
 
 ADVANTAGE_FN = Registry("advantage_fn")
@@ -19,162 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
             kwargs (`Dict`): The step-level parameters for calculating advantages.
 
         Returns:
-            `Any`: The experiences with advantages.
+            `DataProto`: The experiences with advantages.
             `Dict`: The metrics for logging.
         """
 
-
-@ADVANTAGE_FN.register("ppo_adv_fn")
-class PPOAdvantageFn(AdvantageFn):
-    """PPO's GAE advantage computation"""
-
-    def __init__(self, **kwargs):
-        self.gamma = kwargs.get("gamma")
-        self.lam = kwargs.get("lam")
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_gae_advantage_return(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            values=exps.batch["values"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-            lam=self.lam,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("grpo_adv_fn")
-class GRPOAdvantageFn(AdvantageFn):
-    """GRPO advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_grpo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("reinforceplusplus_adv_fn")
-class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
-    """REINFORCE++ advantage computation"""
-
-    def __init__(self, **kwargs):
-        self.gamma = kwargs.get("gamma")
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("remax_adv_fn")
-class REMAXAdvantageFn(AdvantageFn):
-    """REMAX advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_remax_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            reward_baselines=exps.batch["reward_baselines"],
-            eos_mask=exps.batch["response_mask"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("rloo_adv_fn")
-class RLOOAdvantageFn(AdvantageFn):
-    """RLOO advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_rloo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("opmd_adv_fn")
-class OPMDAdvantageFn(AdvantageFn):
-    """OPMD advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_opmd in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_opmd_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            # TODO: check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
-            index=exps.non_tensor_batch["uid"],
-            opmd_baseline="mean",
-            tau=1.0,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the advantage function.
+        """

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+"""GRPO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("grpo_adv_fn")
+class GRPOAdvantageFn(AdvantageFn):
+    """GRPO advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_grpo_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            index=exps.non_tensor_batch["uid"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}

trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+"""OPMD advantage computation
+
+Adapted from compute_advantage_opmd in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("opmd_adv_fn")
+class OPMDAdvantageFn(AdvantageFn):
+    """OPMD advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_opmd_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+            index=exps.non_tensor_batch["uid"],
+            opmd_baseline="mean",
+            tau=1.0,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}

trinity/algorithm/advantage_fn/ppo_advantage.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+"""PPO's GAE advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("ppo_adv_fn")
+class PPOAdvantageFn(AdvantageFn):
+    def __init__(
+        self,
+        gamma: float = 1.0,
+        lam: float = 1.0,
+    ) -> None:
+        self.gamma = gamma
+        self.lam = lam
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_gae_advantage_return(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            values=exps.batch["values"],
+            eos_mask=exps.batch["response_mask"],
+            gamma=self.gamma,
+            lam=self.lam,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+            "lam": 1.0,
+        }
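
For orientation, here is a minimal usage sketch of the refactored interface. The exps object is assumed to be a verl DataProto already carrying the tensors PPOAdvantageFn reads (token_level_rewards, values, response_mask); producing it is not shown here.

from trinity.algorithm.advantage_fn import PPOAdvantageFn

# Instantiate with the defaults the class itself declares ({"gamma": 1.0, "lam": 1.0}).
advantage_fn = PPOAdvantageFn(**PPOAdvantageFn.default_args())

# The call writes "advantages" and "returns" into exps.batch and returns logging metrics.
exps, metrics = advantage_fn(exps)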
