Skip to content

Commit 896db9b

Browse files
authored
[rollout] fix: mlflow consecutive slashes (verl-project#4446)
### What does this PR do? MLFlow would not work with metrics that have // in its item name, it will yield error like so: ``` File "/usr/local/lib/python3.12/dist-packages/mlflow/tracking/client.py", line 2511, in log_batch return self._tracking_client.log_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/mlflow/telemetry/track.py", line 30, in wrapper result = func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/mlflow/tracking/_tracking_service/client.py", line 581, in log_batch self.store.log_batch(run_id=run_id, metrics=metrics_batch, params=[], tags=[]) File "/usr/local/lib/python3.12/dist-packages/mlflow/store/tracking/rest_store.py", line 906, in log_batch self._call_endpoint(LogBatch, req_body) File "/usr/local/lib/python3.12/dist-packages/mlflow/store/tracking/rest_store.py", line 208, in _call_endpoint return call_endpoint( ^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/mlflow/utils/rest_utils.py", line 596, in call_endpoint response = verify_rest_response(response, endpoint) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.12/dist-packages/mlflow/utils/rest_utils.py", line 315, in verify_rest_response raise RestException(json.loads(response.text)) mlflow.exceptions.RestException: INVALID_PARAMETER_VALUE: Invalid value "val-aux//reward/mean_at_1" for parameter 'metrics[0].name' supplied: Names may be treated as files in certain cases, and must not resolve to other names when treated as such. This name would resolve to 'val-aux/reward/mean_at_1' ``` ### Test Added testing for this behavior into `TestMlflowLoggingAdapter`. ### Design & Code Changes Used regular expression to parse and substituted multiple slashes pattern
1 parent aee5aa8 commit 896db9b

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

tests/utils/test_mlflow_key_sanitization.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,44 @@
2020

2121
class TestMlflowLoggingAdapter(unittest.TestCase):
2222
def test_sanitize_key_and_warning(self):
23+
"""Test key sanitization for invalid characters and consecutive slashes with warnings."""
2324
adapter = _MlflowLoggingAdapter()
24-
data = {"valid_key": 1.0, "invalid@key!": 2.0, "another/valid-key": 3.0, "bad key#": 4.0}
25+
data = {
26+
"valid_key": 1.0,
27+
"invalid@key!": 2.0,
28+
"another/valid-key": 3.0,
29+
"bad key#": 4.0,
30+
"val-aux//reward/mean_at_1": 5.0,
31+
"val-core///acc/best_at_5": 6.0,
32+
"metric////with/many////slashes": 7.0,
33+
}
2534
# Patch mlflow.log_metrics to capture the metrics actually sent
2635
with (
2736
patch("mlflow.log_metrics") as mock_log_metrics,
2837
patch.object(adapter, "logger") as mock_logger,
2938
):
3039
adapter.log(data, step=5)
31-
# Check that keys are sanitized
40+
# Check that invalid characters are sanitized
3241
sent_metrics = mock_log_metrics.call_args[1]["metrics"]
3342
self.assertIn("invalid_at_key_", sent_metrics) # @ becomes _at_, ! becomes _
3443
self.assertIn("bad key_", sent_metrics) # # becomes _, space remains
3544
self.assertNotIn("invalid@key!", sent_metrics)
3645
self.assertNotIn("bad key#", sent_metrics)
37-
# Check that a warning was logged for each sanitized key
46+
# Check that consecutive slashes are collapsed to single slashes
47+
self.assertIn("val-aux/reward/mean_at_1", sent_metrics)
48+
self.assertIn("val-core/acc/best_at_5", sent_metrics)
49+
self.assertIn("metric/with/many/slashes", sent_metrics)
50+
self.assertNotIn("val-aux//reward/mean_at_1", sent_metrics)
51+
self.assertNotIn("val-core///acc/best_at_5", sent_metrics)
52+
# Check that warnings were logged for all sanitized keys
3853
warning_msgs = [str(call) for call in mock_logger.warning.call_args_list]
54+
# Warnings for invalid characters
3955
self.assertTrue(any("invalid@key!" in msg and "invalid_at_key_" in msg for msg in warning_msgs))
4056
self.assertTrue(any("bad key#" in msg and "bad key_" in msg for msg in warning_msgs))
57+
# Warnings for consecutive slashes
58+
self.assertTrue(any("val-aux//reward/mean_at_1" in msg for msg in warning_msgs))
59+
self.assertTrue(any("val-core///acc/best_at_5" in msg for msg in warning_msgs))
60+
self.assertTrue(any("metric////with/many////slashes" in msg for msg in warning_msgs))
4161

4262

4363
if __name__ == "__main__":

verl/utils/tracking.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,16 @@ def __init__(self):
278278
self._invalid_chars_pattern = re.compile(
279279
r"[^/\w.\- :]"
280280
) # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces.
281+
self._consecutive_slashes_pattern = re.compile(r"/+")
281282

282283
def log(self, data, step):
283284
import mlflow
284285

285286
def sanitize_key(key):
286287
# First replace @ with _at_ for backward compatibility
287288
sanitized = key.replace("@", "_at_")
289+
# Replace consecutive slashes with a single slash (MLflow treats them as file paths)
290+
sanitized = self._consecutive_slashes_pattern.sub("/", sanitized)
288291
# Then replace any other invalid characters with _
289292
sanitized = self._invalid_chars_pattern.sub("_", sanitized)
290293
if sanitized != key:

0 commit comments

Comments
 (0)