
Commit 024597c

add lightning
1 parent 384588f commit 024597c

File tree: 8 files changed (+280, -30 lines)


docs/docs/deploy/optimization.md

Lines changed: 22 additions & 27 deletions

@@ -117,51 +117,46 @@ veadk rl init --platform lightning --workspace veadk_rl_lightning_project
 
 ```bash
 cd veadk_rl_lightning_project
-veadk rl run --platform lightning --client
+python veadk_agent.py
 ```
 
-Then run the following command in terminal 2 to start the server:
+Then run the following commands in terminal 2:
+
+- First, restart the ray cluster:
+
+```bash
+cd veadk_rl_lightning_project
+bash restart_ray.sh
+```
+
+- Start the server:
 
 ```bash
 cd veadk_rl_lightning_project
-veadk rl run --platform lightning --server
+bash train.sh
 ```
 
 #### How it works
 
 The generated project structure is shown below; the core files are:
 
-- agent_client: `examples/*/*_agent.py` defines the agent's rollout logic and reward rules
-- training_server: `examples/*/train.py` defines the training parameters and is used to start the training server
+- agent_client: `*_agent.py` defines the agent's rollout logic and reward rules
+- training_server: `train.sh` defines the training parameters and is used to start the training server
 
 ```shell
 veadk_rl_lightning_project
-├── agentligtning
-│   ├── runner             # Runner: task execution, scheduling, main-flow management
-│   ├── tracer             # Tracing module: logging, trace collection, debugging info
-│   ├── trainer            # Training module: model training, fine-tuning, and evaluation logic
-│   ├── verl               # VERL reinforcement-learning component
-│   └── server.py          # Training server
-└── examples               # Example projects, several examples included
-    ├── spider             # Example 1: Spider database QA task
-    │   ├── sql_agent.py       # Rollout logic and reward definition for the SQL agent
-    │   ├── train.sh           # Training-server launch script; sets training parameters
-    │   └── data               # Datasets
-    │       ├── train.parquet  # Training set; must be in parquet format
-    │       └── eval.parquet   # Evaluation set; must be in parquet format
-    ├── rag                # Example 2: RAG application
-    │   ├── rag_agent.py       # Rollout logic and reward definition for the RAG agent
-    │   └── train.sh           # Training-server launch script; sets training parameters
-    └── calc_x             # Example 3: calculator agent application
-        ├── calc_agent.py      # Rollout logic and reward definition for the calculator agent
-        └── train.sh           # Training-server launch script; sets training parameters
-
+├── data
+│   ├── demo_train.parquet       # Training data; must be in parquet format
+│   └── demo_test.parquet        # Test data; must be in parquet format
+├── demo_calculate_agent.py      # Agent rollout logic and reward definition
+├── train.sh                     # Training-server launch script; sets training parameters
+└── restart_ray.sh               # Script to restart the ray cluster
 ```
 
 #### Best-practice example
 
-1. In the scaffold, apply reinforcement-learning optimization to the VeADK-based weather-query Agent
-2. Start the client (veadk rl run --platform lightning --client) and the server (veadk rl run --platform lightning --server), running the two commands in terminal 1 and terminal 2 respectively
+1. In the scaffold, apply reinforcement-learning optimization to the VeADK-based arithmetic Agent
+2. Start the client (python demo_calculate_agent.py), restart the ray cluster (bash restart_ray.sh), then start the training server (bash train.sh), running these commands in terminal 1 and terminal 2 respectively
 
 ![Start the client](../assets/images/optimization/lightning_client.png)

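Both versions of the doc require the datasets to be parquet. As a minimal sketch of producing compatible files: the `question`/`result` column names are an assumption inferred from demo_calculate_agent.py later in this commit, which reads task["question"] and task["result"].

```python
# Hypothetical data-prep helper, not part of this commit. Writes train/test
# splits in the parquet format the scaffold expects; requires pandas plus a
# parquet engine (pyarrow or fastparquet). Column names `question` and
# `result` are inferred from demo_calculate_agent.py.
import pandas as pd

rows = [
    {"question": "What is 17 * 24?", "result": "408"},
    {"question": "What is (3 + 5) / 2?", "result": "4"},
]
pd.DataFrame(rows).to_parquet("data/demo_train.parquet", index=False)
pd.DataFrame(rows).to_parquet("data/demo_test.parquet", index=False)
```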
veadk/cli/cli_rl.py

Lines changed: 4 additions & 3 deletions

@@ -38,8 +38,8 @@ def rl_group():
     "--platform",
     "-p",
     required=True,
-    type=click.Choice(["ark"], case_sensitive=False),
-    help="Scaffold platform type (only support for now: ark)",
+    type=click.Choice(["ark", "lightning"], case_sensitive=False),
+    help="Scaffold platform type (supported: ark, lightning)",
 )
 @click.option(
     "--workspace", "-w", required=True, type=str, help="Target workspace directory name"
@@ -52,8 +52,9 @@ def rl_group():
 )
 def rl_init(platform: str, workspace: str, overwrite: bool):
     """
-    Initialize RL scaffold project for ark platform
+    Initialize RL scaffold project for ark or lightning platform
     Example: veadk rl init --platform ark --workspace veadk_rl_ark_project
+    Example: veadk rl init --platform lightning --workspace veadk_rl_lightning_project
     """
     # Locate template directory
     rl_template_root = get_rl_template_root()
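The change above only widens the `click.Choice`. For readers unfamiliar with click, a standalone sketch of how that option validates (hypothetical command, not VeADK's actual CLI):

```python
# Standalone illustration, not part of this commit. click.Choice rejects
# values outside the list; with case_sensitive=False it also accepts any
# casing of "ark" or "lightning" (e.g. `-p Lightning`).
import click


@click.command()
@click.option(
    "--platform",
    "-p",
    required=True,
    type=click.Choice(["ark", "lightning"], case_sensitive=False),
)
def init(platform: str):
    click.echo(f"Initializing scaffold for platform: {platform}")


if __name__ == "__main__":
    init()  # `-p foo` fails with an "invalid choice" error
```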
Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+# Agent Lightning
+
+Agent Lightning provides a flexible, extensible framework that fully decouples the agent (client) from training (server).
+VeADK integrates with Agent Lightning: using the scaffold that VeADK provides, you can develop a VeADK Agent and then run the client and server to perform reinforcement-learning optimization.
+
+## Preparation
+
+Run the following command in your terminal to initialize an Agent Lightning project:
+
+```bash
+veadk rl init --platform lightning --workspace veadk_rl_lightning_project
+```
+
+This command creates a folder named `veadk_rl_lightning_project` in the current directory, containing a basic reinforcement-learning project structure built on VeADK and Agent Lightning.
+Then run the following command in terminal 1 to start the client:
+
+```bash
+cd veadk_rl_lightning_project
+python veadk_agent.py
+```
+
+Then run the following commands in terminal 2:
+
+- First, restart the ray cluster:
+
+```bash
+cd veadk_rl_lightning_project
+bash restart_ray.sh
+```
+
+- Start the server:
+
+```bash
+cd veadk_rl_lightning_project
+bash train.sh
+```
+
+## How it works
+
+The generated project structure is shown below; the core files are:
+
+- agent_client: `*_agent.py` defines the agent's rollout logic and reward rules
+- training_server: `train.sh` defines the training parameters and is used to start the training server
+
+```shell
+veadk_rl_lightning_project
+├── data
+│   ├── demo_train.parquet       # Training data; must be in parquet format
+│   └── demo_test.parquet        # Test data; must be in parquet format
+├── demo_calculate_agent.py      # Agent rollout logic and reward definition
+├── train.sh                     # Training-server launch script; sets training parameters
+└── restart_ray.sh               # Script to restart the ray cluster
+```
+
+## Best-practice example
+
+1. In the scaffold, apply reinforcement-learning optimization to the VeADK-based arithmetic Agent
+2. Start the client (python demo_calculate_agent.py), restart the ray cluster (bash restart_ray.sh), then start the training server (bash train.sh), running these commands in terminal 1 and terminal 2 respectively
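Before the full demo agent below, a condensed sketch of the client side of this decoupling may help: a LitAgent subclass supplies the rollout, and Trainer.fit attaches workers to the training server. This is distilled from demo_calculate_agent.py later in this commit, with the rollout body elided; it is not a separate API.

```python
# Condensed client-side sketch, distilled from demo_calculate_agent.py
# below (rollout and reward logic elided). The server URL matches the
# scaffold's default.
from typing import Any, cast

from agentlightning import LLM, LitAgent, NamedResources, Trainer


class SketchAgent(LitAgent[Any]):
    async def training_rollout_async(
        self, task: Any, rollout_id: str, resources: NamedResources
    ) -> Any:
        # The server hands each rollout a tunable LLM endpoint; the client
        # runs its agent against it and records a reward (via the @reward
        # decorator in the full demo).
        llm = cast(LLM, resources.get("main_llm"))
        _ = (task, rollout_id, llm)  # rollout body elided in this sketch


if __name__ == "__main__":
    # Workers poll the training server for tasks.
    Trainer(n_workers=10).fit(SketchAgent(), "http://localhost:9999/")
```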
Binary file (8.36 KB) not shown.

Binary file (109 KB) not shown.
Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import math
+import re
+import string
+import sympy
+from typing import Any, cast
+from veadk.agent import Agent
+from veadk.runner import Runner
+from veadk.memory.short_term_memory import ShortTermMemory
+from agentlightning import (
+    LLM,
+    LitAgent,
+    NamedResources,
+    Trainer,
+    reward,
+)
+
+
+def normalize_option(option: str) -> str:
+    """
+    >>> normalize_option(" (A) \n")
+    'A'
+    """
+    return re.sub(r"(\s+|\(|\))", "", option)
+
+
+def is_option_result(result: str) -> bool:
+    """
+    >>> is_option_result(" A) \n")
+    True
+    >>> is_option_result(" 23/7 ")
+    False
+    """
+    return normalize_option(result) in list(string.ascii_letters)
+
+
+def float_eval(input_str: str) -> float:
+    if " = around " in input_str:
+        input_str = input_str.split(" = around ")[0]
+    expr = sympy.parse_expr(input_str, evaluate=True)
+    return float(expr.evalf())
+
+
+def scalar_are_results_same(pred_result: str, true_result: str, rel_tol: float) -> bool:
+    pred_result = str(pred_result) if pred_result is not None else ""  # type: ignore
+    true_result = str(true_result) if true_result is not None else ""  # type: ignore
+
+    if pred_result.strip() == true_result.strip():
+        return True
+
+    if is_option_result(true_result):
+        # The task is to select correct option
+        true_result = normalize_option(true_result)
+        pred_result = normalize_option(pred_result)
+        return pred_result == true_result
+
+    # The task is to calculate the result as a number
+    try:
+        pred_float = float_eval(pred_result)
+        true_float = float_eval(true_result)
+        return math.isclose(pred_float, true_float, rel_tol=rel_tol)
+    except Exception:
+        pass
+
+    return False
+
+
+@reward
+async def eval(prediction: str, ground_truth: str) -> float:
+    return float(scalar_are_results_same(prediction, ground_truth, 1e-2))
+
+
+class CalcAgent(LitAgent[Any]):
+    async def training_rollout_async(
+        self, task: Any, rollout_id: str, resources: NamedResources
+    ) -> Any:  # type: ignore
+        llm: LLM = cast(LLM, resources.get("main_llm"))
+        calc_agent = Agent(
+            name="CalcAgent",
+            description="An agent that can perform calculations to answer questions.",
+            instruction="You are a helpful assistant that can perform mathematical calculations to answer questions accurately.",
+            model_provider="openai",
+            model=llm.model,
+            api_base=llm.endpoint,
+            api_key="",
+        )
+        runner = Runner(
+            agent=calc_agent,
+            short_term_memory=ShortTermMemory(),
+            app_name="calc_agent",
+            user_id="veadk_default_user",
+        )
+        try:
+            output_format = "Output the answer when you are ready. The answer should be surrounded by three sharps (`###`), in the form of ### ANSWER: <answer> ###."
+            prompt = task["question"] + " " + output_format
+            result = await runner.run(
+                session_id=rollout_id,
+                messages=prompt,
+            )
+            # evaluate
+            answer = re.search(
+                r"###\s*ANSWER:\s*(.+?)(\s*###|$)", result.messages[-1].content
+            )  # type: ignore
+            if answer:
+                answer = answer.group(1)
+            else:
+                answer = result.messages[-1].content  # type: ignore
+        except Exception as e:
+            print("Failure:", str(e))
+            answer = "None"
+        reward = await eval(
+            answer, str(task["result"])
+        )  # reward is tracked with the decorator # type: ignore
+        print(
+            "answer: {} ground_truth: {} reward: {}".format(
+                answer, task["result"], reward
+            )
+        )  # type: ignore
+
+    async def validation_rollout_async(
+        self, task: Any, rollout_id: str, resources: NamedResources
+    ) -> Any:  # type: ignore
+        llm: LLM = cast(LLM, resources.get("main_llm"))
+        resources = {
+            "main_llm": LLM(
+                endpoint=llm.endpoint,
+                model=llm.model,
+                sampling_parameters={"temperature": 0},
+            )
+        }
+        return await self.training_rollout_async(task, rollout_id, resources)
+
+
+if __name__ == "__main__":
+    Trainer(n_workers=10).fit(CalcAgent(), "http://localhost:9999/")
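A quick standalone check of the `### ANSWER: <answer> ###` extraction convention used in `training_rollout_async` above (illustrative strings, not from the commit):

```python
import re

# Same pattern as in training_rollout_async: the lazy group stops at the
# closing marker, and the `$` alternative tolerates a missing one.
pattern = r"###\s*ANSWER:\s*(.+?)(\s*###|$)"

m = re.search(pattern, "Let me compute... ### ANSWER: 23/7 ###")
assert m and m.group(1) == "23/7"

m = re.search(pattern, "### ANSWER: 42")  # unterminated marker
assert m and m.group(1) == "42"
```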
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#!/bin/bash
+
+set -ex
+
+ray stop -v --force --grace-period 60
+ps aux
+env RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0 --port 6380 --dashboard-port 8266
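Before launching train.sh, one way to confirm the head node came back up is to attach from Python. This check is an assumption, not part of the commit; note the script binds the head to port 6380 rather than ray's default 6379.

```python
# Sanity check (not part of the commit): attach to the cluster that
# restart_ray.sh started and print the resources it sees.
import ray

ray.init(address="127.0.0.1:6380")
print(ray.cluster_resources())  # e.g. CPU/GPU counts visible to the cluster
ray.shutdown()
```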
Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+#!/bin/bash
+
+set -e
+
+export N_GPUS=1
+export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct
+export DATA_DIR=data
+export ROLLOUT_TP_SIZE=1
+export EXPERIMENT_NAME=calc_x
+export PROJECT_NAME=AgentLightning
+
+echo "Starting training script..."
+
+python -m agentlightning.verl \
+    algorithm.adv_estimator=grpo \
+    data.train_files=${DATA_DIR}/train.parquet \
+    data.val_files=${DATA_DIR}/test.parquet \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
+    trainer.n_gpus_per_node=${N_GPUS} \
+    data.train_batch_size=32 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.multi_turn.format=hermes \
+    actor_rollout_ref.model.path=${BASE_MODEL} \
+    data.max_prompt_length=4096 \
+    data.max_response_length=2048 \
+    data.truncation='error' \
+    trainer.val_before_train=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.000 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.clip_ratio_low=0.2 \
+    actor_rollout_ref.actor.clip_ratio_high=0.3 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=True \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name=${PROJECT_NAME} \
+    trainer.experiment_name=${EXPERIMENT_NAME} \
+    trainer.nnodes=1 \
+    trainer.save_freq=256 \
+    trainer.test_freq=32 \
+    trainer.total_epochs=2 $@
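Because the script forwards `$@` to `python -m agentlightning.verl`, any of the dotted settings above can be overridden per run without editing the file, e.g. `bash train.sh trainer.total_epochs=1` (an illustrative invocation, not from the commit). A quoted `"$@"` would be the more robust form should an override ever contain spaces.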
