We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 542e4ac commit be4e78eCopy full SHA for be4e78e
trinity/algorithm/policy_loss_fn/gspo_policy_loss.py
@@ -11,8 +11,6 @@
11
from trinity.algorithm.utils import aggregate_loss, masked_mean
12
from trinity.utils.log import get_logger
13
14
-logger = get_logger(__name__)
15
-
16
17
class GSPOLossFn(PolicyLossFn):
18
def __init__(
@@ -35,6 +33,7 @@ def __init__(
35
33
self.clip_range_high = _clip_range_high
36
34
37
if loss_agg_mode != "seq-mean-token-mean":
+ logger = get_logger(__name__)
38
logger.warning(
39
f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."
40
)
0 commit comments