Commit 4526cf1

feat(cli): add lightning (#411)
* add lightning * add license
1 parent 39d51b3 commit 4526cf1

8 files changed: 292 additions and 30 deletions

docs/docs/deploy/optimization.md

Lines changed: 22 additions & 27 deletions
@@ -117,51 +117,46 @@ veadk rl init --platform lightning --workspace veadk_rl_lightning_project
117117

118118
```bash
119119
cd veadk_rl_lightning_project
120-
veadk rl run --platform lightning --client
120+
python veadk_agent.py
121121
```
122122

123-
Then run the following command in terminal 2 to start the server:
123+
Then run the following commands in terminal 2:
124+
125+
- First, restart the Ray cluster:
126+
127+
```bash
128+
cd veadk_rl_lightning_project
129+
bash restart_ray.sh
130+
```
131+
132+
- Then start the server:
124133

125134
```bash
126135
cd veadk_rl_lightning_project
127-
veadk rl run --platform lightning --server
136+
bash train.sh
128137
```
129138

130139
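#### How it works
131140

132141
The generated project structure is shown below. The core files are:
133142

134-
- agent_client: `examples/*/*_agent.py` defines the agent's rollout logic and reward rules
135-
- training_server: `examples/*/train.py` sets the training parameters and starts the training server
143+
- agent_client: `*_agent.py` defines the agent's rollout logic and reward rules
144+
- training_server: `train.sh` sets the training parameters and starts the training server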
136145

137146
```shell
138147
veadk_rl_lightning_project
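139-
├── agentligtning
140-
│   ├── runner # Runner: handles task execution, scheduling, and main-flow management
141-
│   ├── tracer # Tracing module: records logs, traces, and debugging information
142-
│   ├── trainer # Training module: supports model training, fine-tuning, and evaluation logic
143-
│   ├── verl # VERL reinforcement learning component
144-
│   └── server.py # Training server
145-
└── examples # Example projects, containing several examples
146-
    ├── spider # Example 1: Spider database question-answering task
147-
    │   ├── sql_agent.py # Rollout logic and reward settings for the SQL agent
148-
    │   ├── train.sh # Training server launch script; sets the training parameters
149-
    │   └── data # Datasets
150-
    │       ├── train.parquet # Training dataset; must be in parquet format
151-
    │       └── eval.parquet # Evaluation dataset; must be in parquet format
152-
    ├── rag # Example 2: RAG application
153-
    │   ├── rag_agent.py # Rollout logic and reward settings for the RAG agent
154-
    │   └── train.sh # Training server launch script; sets the training parameters
155-
    └── calc_x # Example 3: calculation agent application
156-
        ├── calc_agent.py # Rollout logic and reward settings for the calculation agent
157-
        └── train.sh # Training server launch script; sets the training parameters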
158-
148+
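├── data
149+
│   ├── demo_train.parquet # Training data; must be in parquet format
150+
│   └── demo_test.parquet # Test data; must be in parquet format
151+
├── demo_calculate_agent.py # Rollout logic and reward settings for the agent
152+
├── train.sh # Training server launch script; sets the training parameters
153+
└── restart_ray.sh # Script that restarts the Ray cluster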
159154
```
160155

161156
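#### Best practice example
162157

163-
1. The scaffold uses reinforcement learning to optimize a VeADK-based weather query agent
164-
2. Start the client (`veadk rl run --platform lightning --client`) in terminal 1 and the server (`veadk rl run --platform lightning --server`) in terminal 2
158+
1. The scaffold uses reinforcement learning to optimize a VeADK-based arithmetic agent
159+
2. Start the client (`python demo_calculate_agent.py`) in terminal 1; then, in terminal 2, restart the Ray cluster (`bash restart_ray.sh`) and finally start the training server (`bash train.sh`)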
165160

166161
![Start the client](../assets/images/optimization/lightning_client.png)
167162

veadk/cli/cli_rl.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def rl_group():
3838
"--platform",
3939
"-p",
4040
required=True,
41-
type=click.Choice(["ark"], case_sensitive=False),
42-
help="Scaffold platform type (only support for now: ark)",
41+
type=click.Choice(["ark", "lightning"], case_sensitive=False),
42+
help="Scaffold platform type (supported: ark, lightning)",
4343
)
4444
@click.option(
4545
"--workspace", "-w", required=True, type=str, help="Target workspace directory name"
@@ -52,8 +52,9 @@ def rl_group():
5252
)
5353
def rl_init(platform: str, workspace: str, overwrite: bool):
5454
"""
55-
Initialize RL scaffold project for ark platform
55+
Initialize RL scaffold project for ark or lightning platform
5656
Example: veadk rl init --platform ark --workspace veadk_rl_ark_project
57+
Example: veadk rl init --platform lightning --workspace veadk_rl_lightning_project
5758
"""
5859
# Locate template directory
5960
rl_template_root = get_rl_template_root()
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
1+
# Agent Lightning
2+
3+
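Agent Lightning provides a flexible, extensible framework that fully decouples the agent (client) from training (server).
4+
VeADK integrates with Agent Lightning: using the scaffold that VeADK provides, you can develop a VeADK agent and then run the client and server to optimize it with reinforcement learning.
5+
6+
## Preparation
7+
8+
Run the following command in your terminal to initialize an Agent Lightning project: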
9+
10+
```bash
11+
veadk rl init --platform lightning --workspace veadk_rl_lightning_project
12+
```
13+
14+
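This command creates a folder named `veadk_rl_lightning_project` in the current directory, containing a basic reinforcement learning project structure based on VeADK and Agent Lightning.
15+
Then run the following command in terminal 1 to start the client: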
16+
17+
```bash
18+
cd veadk_rl_lightning_project
19+
python veadk_agent.py
20+
```
21+
22+
Then run the following commands in terminal 2:
23+
24+
- First, restart the Ray cluster:
25+
26+
```bash
27+
cd veadk_rl_lightning_project
28+
bash restart_ray.sh
29+
```
30+
31+
- Then start the server:
32+
33+
```bash
34+
cd veadk_rl_lightning_project
35+
bash train.sh
36+
```
37+
38+
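## How it works
39+
40+
The generated project structure is shown below. The core files are:
41+
42+
- agent_client: `*_agent.py` defines the agent's rollout logic and reward rules
43+
- training_server: `train.sh` sets the training parameters and starts the training server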
44+
45+
```shell
46+
veadk_rl_lightning_project
47+
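├── data
48+
│   ├── demo_train.parquet # Training data; must be in parquet format
49+
│   └── demo_test.parquet # Test data; must be in parquet format
50+
├── demo_calculate_agent.py # Rollout logic and reward settings for the agent
51+
├── train.sh # Training server launch script; sets the training parameters
52+
└── restart_ray.sh # Script that restarts the Ray cluster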
53+
```
54+
55+
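## Best practice example
56+
57+
1. The scaffold uses reinforcement learning to optimize a VeADK-based arithmetic agent
58+
2. Start the client (`python demo_calculate_agent.py`) in terminal 1; then, in terminal 2, restart the Ray cluster (`bash restart_ray.sh`) and finally start the training server (`bash train.sh`)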
8.36 KB
Binary file not shown.
109 KB
Binary file not shown.
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import re
17+
import string
18+
import sympy
19+
from typing import Any, cast
20+
from veadk.agent import Agent
21+
from veadk.runner import Runner
22+
from veadk.memory.short_term_memory import ShortTermMemory
23+
from agentlightning import (
24+
LLM,
25+
LitAgent,
26+
NamedResources,
27+
Trainer,
28+
reward,
29+
)
30+
31+
32+
def normalize_option(option: str) -> str:
33+
"""
34+
>>> normalize_option(" (A) \n")
35+
'A'
36+
"""
37+
return re.sub(r"(\s+|\(|\))", "", option)
38+
39+
40+
def is_option_result(result: str) -> bool:
41+
"""
42+
>>> is_option_result(" A) \n")
43+
True
44+
>>> is_option_result(" 23/7 ")
45+
False
46+
"""
47+
return normalize_option(result) in list(string.ascii_letters)
48+
49+
50+
def float_eval(input_str: str) -> float:
51+
if " = around " in input_str:
52+
input_str = input_str.split(" = around ")[0]
53+
expr = sympy.parse_expr(input_str, evaluate=True)
54+
return float(expr.evalf())
55+
56+
57+
def scalar_are_results_same(pred_result: str, true_result: str, rel_tol: float) -> bool:
58+
pred_result = str(pred_result) if pred_result is not None else "" # type: ignore
59+
true_result = str(true_result) if true_result is not None else "" # type: ignore
60+
61+
if pred_result.strip() == true_result.strip():
62+
return True
63+
64+
if is_option_result(true_result):
65+
# The task is to select correct option
66+
true_result = normalize_option(true_result)
67+
pred_result = normalize_option(pred_result)
68+
return pred_result == true_result
69+
70+
# The task is to calculate the result as a number
71+
try:
72+
pred_float = float_eval(pred_result)
73+
true_float = float_eval(true_result)
74+
return math.isclose(pred_float, true_float, rel_tol=rel_tol)
75+
except Exception:
76+
pass
77+
78+
return False
79+
80+
81+
@reward
82+
async def eval(prediction: str, ground_truth: str) -> float:
83+
return float(scalar_are_results_same(prediction, ground_truth, 1e-2))
84+
85+
86+
class CalcAgent(LitAgent[Any]):
87+
async def training_rollout_async(
88+
self, task: Any, rollout_id: str, resources: NamedResources
89+
) -> Any: # type: ignore
90+
llm: LLM = cast(LLM, resources.get("main_llm"))
91+
calc_agent = Agent(
92+
name="CalcAgent",
93+
description="An agent that can perform calculations to answer questions.",
94+
instruction="You are a helpful assistant that can perform mathematical calculations to answer questions accurately.",
95+
model_provider="openai",
96+
model=llm.model,
97+
api_base=llm.endpoint,
98+
api_key="",
99+
)
100+
runner = Runner(
101+
agent=calc_agent,
102+
short_term_memory=ShortTermMemory(),
103+
app_name="calc_agent",
104+
user_id="veadk_default_user",
105+
)
106+
try:
107+
output_format = "Output the answer when you are ready. The answer should be surrounded by three sharps (`###`), in the form of ### ANSWER: <answer> ###."
108+
prompt = task["question"] + " " + output_format
109+
result = await runner.run(
110+
session_id=rollout_id,
111+
messages=prompt,
112+
)
113+
# evaluate
114+
answer = re.search(
115+
r"###\s*ANSWER:\s*(.+?)(\s*###|$)", result.messages[-1].content
116+
) # type: ignore
117+
if answer:
118+
answer = answer.group(1)
119+
else:
120+
answer = result.messages[-1].content # type: ignore
121+
except Exception as e:
122+
print("Failure:", str(e))
123+
answer = "None"
124+
reward = await eval(
125+
answer, str(task["result"])
126+
) # reward is tracked with the decorator # type: ignore
127+
print(
128+
"answer: {} ground_truth: {} reward: {}".format(
129+
answer, task["result"], reward
130+
)
131+
) # type: ignore
132+
133+
async def validation_rollout_async(
134+
self, task: Any, rollout_id: str, resources: NamedResources
135+
) -> Any: # type: ignore
136+
llm: LLM = cast(LLM, resources.get("main_llm"))
137+
resources = {
138+
"main_llm": LLM(
139+
endpoint=llm.endpoint,
140+
model=llm.model,
141+
sampling_parameters={"temperature": 0},
142+
)
143+
}
144+
return await self.training_rollout_async(task, rollout_id, resources)
145+
146+
147+
if __name__ == "__main__":
148+
Trainer(n_workers=10).fit(CalcAgent(), "http://localhost:9999/")
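
The answer-extraction and reward steps in `training_rollout_async` can be tried without a model or a training server. The sketch below is an illustration only and is not part of the commit: `extract_answer` and `numeric_reward` are hypothetical helper names, and plain `float()` parsing stands in for the `sympy.parse_expr` call used in the file above, so only simple decimal answers are handled. The regular expression and the `rel_tol=1e-2` tolerance are copied from the committed code.

```python
import math
import re

# Same pattern as in training_rollout_async: capture the text after
# "### ANSWER:" up to the closing "###" or the end of the string.
ANSWER_PATTERN = re.compile(r"###\s*ANSWER:\s*(.+?)(\s*###|$)")


def extract_answer(text: str) -> str:
    """Return the tagged answer, or the whole text if no tag is found."""
    match = ANSWER_PATTERN.search(text)
    return match.group(1) if match else text


def numeric_reward(prediction: str, ground_truth: str, rel_tol: float = 1e-2) -> float:
    """Return 1.0 if both strings parse to floats within rel_tol, else 0.0.

    Simplification (assumption): float() replaces sympy.parse_expr,
    so expressions such as "23/7" are not supported here.
    """
    if prediction.strip() == ground_truth.strip():
        return 1.0
    try:
        close = math.isclose(float(prediction), float(ground_truth), rel_tol=rel_tol)
        return float(close)
    except ValueError:
        return 0.0


if __name__ == "__main__":
    reply = "Adding the two values gives ### ANSWER: 3.14 ###"
    answer = extract_answer(reply)
    print(answer, numeric_reward(answer, "3.1416"))
```

Because the fallback returns the whole reply when no `### ANSWER:` tag is present, an untagged reply is scored against the ground truth as a single string, which mirrors the `else` branch in the committed rollout.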
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/bin/bash
2+
3+
set -ex
4+
5+
ray stop -v --force --grace-period 60
6+
ps aux
7+
env RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0 --port 6380 --dashboard-port 8266
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
export N_GPUS=1
6+
export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct
7+
export DATA_DIR=data
8+
export ROLLOUT_TP_SIZE=1
9+
export EXPERIMENT_NAME=calc_x
10+
export PROJECT_NAME=AgentLightning
11+
12+
echo "Starting training script..."
13+
14+
python -m agentlightning.verl \
15+
algorithm.adv_estimator=grpo \
16+
data.train_files=${DATA_DIR}/train.parquet \
17+
data.val_files=${DATA_DIR}/test.parquet \
18+
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
19+
trainer.n_gpus_per_node=${N_GPUS} \
20+
data.train_batch_size=32 \
21+
actor_rollout_ref.rollout.n=4 \
22+
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
23+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
24+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
25+
actor_rollout_ref.rollout.multi_turn.format=hermes \
26+
actor_rollout_ref.model.path=${BASE_MODEL} \
27+
data.max_prompt_length=4096 \
28+
data.max_response_length=2048 \
29+
data.truncation='error' \
30+
trainer.val_before_train=True \
31+
actor_rollout_ref.actor.optim.lr=1e-6 \
32+
actor_rollout_ref.model.use_remove_padding=True \
33+
actor_rollout_ref.actor.use_kl_loss=False \
34+
actor_rollout_ref.actor.kl_loss_coef=0.000 \
35+
actor_rollout_ref.actor.entropy_coeff=0 \
36+
actor_rollout_ref.actor.clip_ratio_low=0.2 \
37+
actor_rollout_ref.actor.clip_ratio_high=0.3 \
38+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
39+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
40+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
41+
actor_rollout_ref.rollout.name=vllm \
42+
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
43+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
44+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
45+
algorithm.use_kl_in_reward=False \
46+
trainer.critic_warmup=0 \
47+
trainer.logger=['console','wandb'] \
48+
trainer.project_name=${PROJECT_NAME} \
49+
trainer.experiment_name=${EXPERIMENT_NAME} \
50+
trainer.nnodes=1 \
51+
trainer.save_freq=256 \
52+
trainer.test_freq=32 \
53+
trainer.total_epochs=2 $@
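
`train.sh` points `data.train_files` and `data.val_files` at `${DATA_DIR}/train.parquet` and `${DATA_DIR}/test.parquet`, while the documentation in this commit lists the scaffold data as `demo_train.parquet` and `demo_test.parquet`. This page does not show the names of the two binary files, so which naming the scaffold actually ships is not confirmed here. A small pre-flight check could catch such a mismatch before the training job starts. The sketch below assumes a hypothetical helper, `missing_data_files`, that is not part of the commit.

```python
from pathlib import Path


def missing_data_files(data_dir: str, expected: list[str]) -> list[str]:
    """Return the expected file names that do not exist under data_dir."""
    root = Path(data_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    # File names as referenced by train.sh; adjust if the scaffold differs.
    missing = missing_data_files("data", ["train.parquet", "test.parquet"])
    if missing:
        print("Missing data files:", ", ".join(missing))
    else:
        print("All data files referenced by train.sh are present.")
```

Running this from the project directory before `bash train.sh` would report any file that the training command expects but cannot find.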
