Skip to content

Commit 542e4ac

Browse files
committed
Fix pre-commit and rec
1 parent 5b77c08 commit 542e4ac

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

trinity/algorithm/policy_loss_fn/gspo_policy_loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
1111
from trinity.algorithm.utils import aggregate_loss, masked_mean
12-
1312
from trinity.utils.log import get_logger
1413

1514
logger = get_logger(__name__)
1615

16+
1717
class GSPOLossFn(PolicyLossFn):
1818
def __init__(
1919
self,
@@ -35,7 +35,9 @@ def __init__(
3535
self.clip_range_high = _clip_range_high
3636

3737
if loss_agg_mode != "seq-mean-token-mean":
38-
logger.warning(f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'.")
38+
logger.warning(
39+
f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
40+
)
3941
# loss_agg_mode = "seq-mean-token-mean"
4042
self.loss_agg_mode = loss_agg_mode
4143

trinity/algorithm/policy_loss_fn/rec_policy_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77

88
from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
9-
from trinity.algorithm.utils import masked_mean
9+
from trinity.algorithm.utils import aggregate_loss, masked_mean
1010

1111

1212
class RECPolicyLossFn(PolicyLossFn):
@@ -123,7 +123,7 @@ def __call__( # type: ignore
123123

124124
if self.clip_mode == "gspo-one-side":
125125
# [EXPERIMENTAL] specialized for gspo-style rec variant for now
126-
pg_loss = masked_loss(
126+
pg_loss = aggregate_loss(
127127
values=pg_losses,
128128
mask=action_mask,
129129
loss_agg_mode="seq-mean-token-mean",

0 commit comments

Comments
 (0)