We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c389ba6 commit e2df335Copy full SHA for e2df335
trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -9,7 +9,6 @@
9
10
from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
11
from trinity.algorithm.utils import aggregate_loss, masked_mean
12
-from trinity.utils.log import get_logger
13
14
15
class GSPOLossFn(PolicyLossFn):
@@ -33,6 +32,8 @@ def __init__(
33
32
self.clip_range_high = _clip_range_high
34
35
if loss_agg_mode != "seq-mean-token-mean":
+ from trinity.utils.log import get_logger
36
+
37
logger = get_logger(__name__)
38
logger.warning(
39
f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
0 commit comments