@@ -1,9 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple
 
-from verl import DataProto
-
-from trinity.trainer.verl import core_algos
 from trinity.utils.registry import Registry
 
 ADVANTAGE_FN = Registry("advantage_fn")
@@ -19,162 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
             kwargs (`Dict`): The step-level parameters for calculating advantages.
 
         Returns:
-            `Any`: The experiences with advantages.
+            `DataProto`: The experiences with advantages.
             `Dict`: The metrics for logging.
         """
 
-
-@ADVANTAGE_FN.register("ppo_adv_fn")
-class PPOAdvantageFn(AdvantageFn):
-    """PPO's GAE advantage computation"""
-
-    def __init__(self, **kwargs):
-        self.gamma = kwargs.get("gamma")
-        self.lam = kwargs.get("lam")
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_gae_advantage_return(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            values=exps.batch["values"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-            lam=self.lam,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("grpo_adv_fn")
-class GRPOAdvantageFn(AdvantageFn):
-    """GRPO advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_grpo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("reinforceplusplus_adv_fn")
-class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
-    """REINFORCE++ advantage computation"""
-
-    def __init__(self, **kwargs):
-        self.gamma = kwargs.get("gamma")
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            gamma=self.gamma,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("remax_adv_fn")
-class REMAXAdvantageFn(AdvantageFn):
-    """REMAX advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_remax_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            reward_baselines=exps.batch["reward_baselines"],
-            eos_mask=exps.batch["response_mask"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("rloo_adv_fn")
-class RLOOAdvantageFn(AdvantageFn):
-    """RLOO advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_ppo in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_rloo_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            index=exps.non_tensor_batch["uid"],
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
-
-
-@ADVANTAGE_FN.register("opmd_adv_fn")
-class OPMDAdvantageFn(AdvantageFn):
-    """OPMD advantage computation"""
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
-        """Adapted from compute_advantage_opmd in ray_trainer.py"""
-
-        advantages, returns = core_algos.compute_opmd_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            # TODO: check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
-            index=exps.non_tensor_batch["uid"],
-            opmd_baseline="mean",
-            tau=1.0,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
-
-        metrics = {
-            "abc": "xyz",  # TODO: add meaningful metrics
-        }
-
-        return exps, metrics
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the advantage function.
+        """
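With the concrete PPO/GRPO/REINFORCE++/REMAX/RLOO/OPMD implementations removed from this module, `AdvantageFn` becomes a pure interface: subclasses must provide both `__call__` and `default_args`. The following is a minimal sketch, not part of this diff, of what a registered implementation could look like. The class name, registry key, and the `token_level_rewards`/`response_mask` fields are assumptions carried over from the removed implementations, and the code is assumed to sit in the same module so `AdvantageFn`, `ADVANTAGE_FN`, and the typing imports are in scope.

```python
import torch

# Hypothetical subclass for illustration only; assumes it lives in the same
# module as AdvantageFn / ADVANTAGE_FN above and that `exps.batch` holds
# tensors with the same field names used by the removed implementations.
@ADVANTAGE_FN.register("mean_baseline_adv_fn")  # assumed registry key
class MeanBaselineAdvantageFn(AdvantageFn):
    """Outcome advantage: sequence reward minus the batch-mean reward."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, exps: Any, **kwargs) -> Tuple[Any, Dict]:
        rewards = exps.batch["token_level_rewards"]  # (batch, seq_len)
        mask = exps.batch["response_mask"]           # (batch, seq_len)
        seq_reward = (rewards * mask).sum(dim=-1)    # per-sequence scalar reward
        advantage = seq_reward - seq_reward.mean()   # subtract batch-mean baseline
        # Broadcast the per-sequence advantage over the response tokens.
        advantages = advantage.unsqueeze(-1) * mask
        exps.batch["advantages"] = advantages
        exps.batch["returns"] = advantages
        metrics = {"advantage/mean_abs": advantage.abs().mean().item()}
        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        # This sketch takes no constructor arguments.
        return {}
```

Exposing `default_args` as an abstract classmethod lets a configuration layer query an algorithm's default init arguments from the registry without instantiating the class, which is presumably the motivation for adding it in this change.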