Skip to content

Commit 8c94140

Browse files
committed
fix precommit
1 parent a967040 commit 8c94140

File tree

9 files changed

+28
-35
lines changed

9 files changed

+28
-35
lines changed

rllm/experimental/fully_async/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,4 @@ async def chat_completion(
165165
return message, output
166166

167167
async def close(self):
168-
await self.client.aclose()
168+
await self.client.aclose()

rllm/experimental/fully_async/fully_async_trainer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,24 @@
2020
import ray
2121
from omegaconf import OmegaConf
2222
from tqdm import tqdm
23-
24-
from rllm.experimental.fully_async.message_queue import MessageQueueClient
25-
from rllm.experimental.fully_async.metric_utils import MetricsAggregator, ValidateMetrics
26-
from rllm.experimental.fully_async.utils import (
27-
assemble_batch_from_trajectory_group_ls,
28-
compute_grpo_outcome_advantage,
29-
reduce_metrics_with_flatten,
30-
)
3123
from verl import DataProto
3224
from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
3325
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
3426
from verl.trainer.ppo import core_algos
3527
from verl.trainer.ppo.core_algos import agg_loss
3628
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, apply_kl_penalty, compute_response_mask
37-
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
29+
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy
3830
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
3931
from verl.utils.debug import marked_timer
4032

33+
from rllm.experimental.fully_async.message_queue import MessageQueueClient
34+
from rllm.experimental.fully_async.metric_utils import MetricsAggregator, ValidateMetrics
35+
from rllm.experimental.fully_async.utils import (
36+
assemble_batch_from_trajectory_group_ls,
37+
compute_grpo_outcome_advantage,
38+
reduce_metrics_with_flatten,
39+
)
40+
4141

4242
@ray.remote(num_cpus=10)
4343
class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
@@ -637,4 +637,4 @@ def compute_old_log_prob(batch):
637637

638638
actor_output_metrics = reduce_metrics_with_flatten(actor_output.meta_info["metrics"])
639639
metrics.update(actor_output_metrics)
640-
return batch, {}
640+
return batch, {}

rllm/experimental/fully_async/inference_manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import asyncio
1615
import subprocess
1716

1817
import ray
19-
2018
from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
2119
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
2220
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
@@ -163,4 +161,4 @@ def launch_router(self, port: int = 30000):
163161
return self.router_url
164162

165163
async def clear_kv_cache(self):
166-
await self.async_rollout_manager.clear_kv_cache()
164+
await self.async_rollout_manager.clear_kv_cache()

rllm/experimental/fully_async/message_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,4 @@ def put_sample_sync(self, sample: Any) -> bool:
215215

216216
def get_sample_sync(self) -> Any | None:
217217
"""Get single sample from queue (sync - deprecated, use get_sample instead)"""
218-
return ray.get(self.queue_actor.get_sample.remote())
218+
return ray.get(self.queue_actor.get_sample.remote())

rllm/experimental/fully_async/message_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import json
66
import re
77

8-
98
# Regex for thinking content: <think>...</think>
109
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
1110

@@ -73,4 +72,4 @@ def build_tool_message(tool_name: str, tool_output: str, tool_call_id: str | Non
7372
message = {"role": "tool", "name": tool_name, "content": tool_output}
7473
if tool_call_id:
7574
message["tool_call_id"] = tool_call_id
76-
return message
75+
return message

rllm/experimental/fully_async/metric_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import time
2323
from collections import defaultdict
2424
from dataclasses import dataclass
25-
from typing import Any, Optional
25+
from typing import Any
2626

2727
import numpy as np
2828
import torch
@@ -33,9 +33,9 @@ class ValidateMetrics:
3333
"""Metrics for validation"""
3434

3535
timing_raw: dict[str, Any]
36-
metrics: Optional[dict[str, Any]] = None
37-
global_steps: Optional[int] = None
38-
param_version: Optional[int] = None
36+
metrics: dict[str, Any] | None = None
37+
global_steps: int | None = None
38+
param_version: int | None = None
3939

4040

4141
class MetricsAggregator:
@@ -225,4 +225,4 @@ def get_current_stats(self) -> dict[str, Any]:
225225
"metric_count": len(self.metric_values),
226226
"total_samples": sum(self.sample_counts),
227227
"metric_names": list(self.metric_values.keys()),
228-
}
228+
}

rllm/experimental/fully_async/param_sync.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import ray
1919
from ray.util.collective import collective
20-
2120
from verl.utils.device import get_nccl_backend
2221

2322
logger = logging.getLogger(__name__)
@@ -152,7 +151,7 @@ def sync_weights(self, version, validate=False, global_steps=0):
152151
# Update staleness tracking - subtracts consumed samples from enqueued count
153152
# This must be called AFTER resume so continue_event can be set if there's capacity
154153
ray.get(self.rollout_executor.update_staleness_tracking.remote())
155-
print(f"[ParameterSynchronizer] update_staleness_tracking completed", flush=True)
154+
print("[ParameterSynchronizer] update_staleness_tracking completed", flush=True)
156155

157156
pause_time = time.time()
158157

@@ -197,4 +196,4 @@ def rollout_executor_save_checkpoint(self, local_global_step_folder: str):
197196
if not hasattr(self, "rollout_executor") or self.rollout_executor is None:
198197
raise RuntimeError("rollout_executor is not set; call set_rollout_executor() before saving checkpoint")
199198
print(f"[ParameterSynchronizer] Triggering RolloutExecutor checkpoint save at {local_global_step_folder} ...")
200-
return ray.get(self.rollout_executor.save_checkpoint.remote(local_global_step_folder))
199+
return ray.get(self.rollout_executor.save_checkpoint.remote(local_global_step_folder))

rllm/experimental/fully_async/protocol.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from dataclasses import dataclass, field
2-
from typing import Optional
1+
from dataclasses import dataclass
32

43

54
@dataclass
@@ -175,4 +174,4 @@ def merge(self):
175174

176175
@dataclass
177176
class TrajectoryGroup:
178-
trajectories: list[Trajectory]
177+
trajectories: list[Trajectory]

rllm/experimental/fully_async/runner.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,16 @@
2222

2323
import ray
2424
from omegaconf import OmegaConf
25+
from verl.experimental.fully_async_policy.fully_async_main import create_resource_pool_manager, create_role_worker_mapping
26+
from verl.trainer.ppo.utils import Role
27+
from verl.utils.fs import copy_to_local
2528

26-
from rllm.experimental.fully_async.inference_manager import InferenceManager
2729
from rllm.experimental.fully_async.fully_async_trainer import FullyAsyncTrainer
30+
from rllm.experimental.fully_async.inference_manager import InferenceManager
2831
from rllm.experimental.fully_async.message_queue import MessageQueue, MessageQueueClient
2932
from rllm.experimental.fully_async.param_sync import ParameterSynchronizer
30-
from rllm.experimental.fully_async.protocol import Trajectory
3133
from rllm.experimental.fully_async.rollout_executor import RolloutExecutor
3234
from rllm.experimental.fully_async.utils import calculate_max_concurrency
33-
from verl.experimental.fully_async_policy.fully_async_main import create_resource_pool_manager, create_role_worker_mapping
34-
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
35-
from verl.trainer.ppo.utils import Role, need_reference_policy
36-
from verl.utils.fs import copy_to_local
3735

3836

3937
def create_task_runner_with_rollout_fn(rollout_fn, val_rollout_fn=None):
@@ -294,4 +292,4 @@ def train(self):
294292
task_runner_class = create_task_runner_with_rollout_fn(self.rollout_fn, self.val_rollout_fn)
295293
run_ppo(self.config, task_runner_class=task_runner_class)
296294

297-
print(f"total time: {time.time() - start_time:.2f} seconds")
295+
print(f"total time: {time.time() - start_time:.2f} seconds")

0 commit comments

Comments
 (0)