File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed
trinity/algorithm/policy_loss_fn Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change 99
1010from trinity .algorithm .policy_loss_fn .policy_loss_fn import PolicyLossFn
1111from trinity .algorithm .utils import aggregate_loss , masked_mean
12-
1312from trinity .utils .log import get_logger
1413
1514logger = get_logger (__name__ )
1615
16+
1717class GSPOLossFn (PolicyLossFn ):
1818 def __init__ (
1919 self ,
@@ -35,7 +35,9 @@ def __init__(
3535 self .clip_range_high = _clip_range_high
3636
3737 if loss_agg_mode != "seq-mean-token-mean" :
38- logger .warning (f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{ loss_agg_mode } '." )
38+ logger .warning (
39+ f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{ loss_agg_mode } '."
40+ )
3941 # loss_agg_mode = "seq-mean-token-mean"
4042 self .loss_agg_mode = loss_agg_mode
4143
Original file line number Diff line number Diff line change 66import torch
77
88from trinity .algorithm .policy_loss_fn .policy_loss_fn import PolicyLossFn
9- from trinity .algorithm .utils import masked_mean
9+ from trinity .algorithm .utils import aggregate_loss , masked_mean
1010
1111
1212class RECPolicyLossFn (PolicyLossFn ):
@@ -123,7 +123,7 @@ def __call__( # type: ignore
123123
124124 if self .clip_mode == "gspo-one-side" :
125125 # [EXPERIMENTAL] specialized for gspo-style rec variant for now
126- pg_loss = masked_loss (
126+ pg_loss = aggregate_loss (
127127 values = pg_losses ,
128128 mask = action_mask ,
129129 loss_agg_mode = "seq-mean-token-mean" ,
You can’t perform that action at this time.
0 commit comments