Skip to content

Commit e2df335

Browse files
committed
Move logger import
1 parent c389ba6 commit e2df335

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

trinity/algorithm/policy_loss_fn/gspo_policy_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
1111
from trinity.algorithm.utils import aggregate_loss, masked_mean
12-
from trinity.utils.log import get_logger
1312

1413

1514
class GSPOLossFn(PolicyLossFn):
@@ -33,6 +32,8 @@ def __init__(
3332
self.clip_range_high = _clip_range_high
3433

3534
if loss_agg_mode != "seq-mean-token-mean":
35+
from trinity.utils.log import get_logger
36+
3637
logger = get_logger(__name__)
3738
logger.warning(
3839
f"The original GSPO paper requires loss_agg_mode to be 'seq-mean-token-mean', but the current setting is '{loss_agg_mode}'."

0 commit comments

Comments
 (0)