Skip to content

Commit 643cfdb

Browse files
authored
feat: add templates & cli_rl (#339)
* add templates & cli_rl * add readme * update ark_rk readme
1 parent ed2ff9f commit 643cfdb

File tree

13 files changed

+793
-0
lines changed

13 files changed

+793
-0
lines changed

veadk/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from veadk.cli.cli_uploadevalset import uploadevalset
2727
from veadk.cli.cli_update import update
2828
from veadk.cli.cli_clean import clean
29+
from veadk.cli.cli_rl import rl_group
2930
from veadk.version import VERSION
3031

3132

@@ -53,6 +54,7 @@ def veadk():
5354
veadk.add_command(uploadevalset)
5455
veadk.add_command(update)
5556
veadk.add_command(clean)
57+
veadk.add_command(rl_group)
5658

5759
if __name__ == "__main__":
5860
veadk()

veadk/cli/cli_rl.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import click
16+
import shutil
17+
import sys
18+
from pathlib import Path
19+
20+
21+
def get_rl_template_root() -> Path:
22+
"""Get absolute path of RL scaffold template root (cli/templates/rl/)"""
23+
current_file = Path(__file__).resolve()
24+
cli_dir = current_file.parent
25+
rl_template_root = cli_dir / "templates" / "rl"
26+
return rl_template_root
27+
28+
29+
@click.group(name="rl", help="RL related commands")
30+
def rl_group():
31+
pass
32+
33+
34+
@rl_group.command(
35+
name="init", help="Initialize RL scaffold project (specify platform/workspace)"
36+
)
37+
@click.option(
38+
"--platform",
39+
"-p",
40+
required=True,
41+
type=click.Choice(["ark"], case_sensitive=False),
42+
help="Scaffold platform type (only support for now: ark)",
43+
)
44+
@click.option(
45+
"--workspace", "-w", required=True, type=str, help="Target workspace directory name"
46+
)
47+
@click.option(
48+
"--overwrite",
49+
"-f",
50+
is_flag=True,
51+
help="Force overwrite existing workspace (default: false)",
52+
)
53+
def rl_init(platform: str, workspace: str, overwrite: bool):
54+
"""
55+
Initialize RL scaffold project for ark platform
56+
Example: veadk rl init --platform ark --workspace veadk_rl_ark_project
57+
"""
58+
# Locate template directory
59+
rl_template_root = get_rl_template_root()
60+
platform_template_dir = rl_template_root / platform.lower()
61+
62+
# Validate template directory
63+
if not platform_template_dir.exists():
64+
click.secho(f"Error: Scaffold template for {platform} not found!", fg="red")
65+
click.secho(f" Expected path: {platform_template_dir}", fg="yellow")
66+
click.secho(
67+
f" Supported platforms: {[d.name for d in rl_template_root.glob('*') if d.is_dir()]}",
68+
fg="blue",
69+
)
70+
sys.exit(1)
71+
72+
# Target workspace path
73+
target_workspace = Path.cwd() / workspace
74+
75+
# Handle existing directory
76+
if target_workspace.exists():
77+
if not overwrite:
78+
click.secho(
79+
f"\nWarning: Target directory {target_workspace} already exists!",
80+
fg="yellow",
81+
)
82+
if not click.confirm("Overwrite?"):
83+
click.secho("Operation cancelled", fg="red")
84+
sys.exit(0)
85+
shutil.rmtree(target_workspace)
86+
click.secho(f"Cleared existing directory: {target_workspace}", fg="green")
87+
88+
# Copy scaffold files
89+
try:
90+
shutil.copytree(
91+
src=platform_template_dir,
92+
dst=target_workspace,
93+
ignore=None,
94+
dirs_exist_ok=False,
95+
)
96+
click.secho("\nRL scaffold initialized successfully!", fg="green")
97+
click.secho(f" - Project path: {target_workspace.absolute()}", fg="green")
98+
except PermissionError:
99+
click.secho(
100+
f"Error: Permission denied to write to {target_workspace}", fg="red"
101+
)
102+
sys.exit(1)
103+
except Exception as e:
104+
click.secho(f"Error: Failed to copy scaffold - {str(e)}", fg="red")
105+
sys.exit(1)
106+
107+
108+
if __name__ == "__main__":
109+
rl_group()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.10
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# 基于方舟平台强化学习
2+
方舟 RL 将强化学习过程进行了一定程度的封装,降低了复杂度。用户主要关注 rollout 中的 agent 逻辑、奖励函数的构建、训练样本的选择即可。
3+
VeADK 与方舟平台 Agent RL 集成,用户使用 VeADK 提供的脚手架,可以开发 VeADK Agent,然后提交任务到方舟平台进行强化学习优化。
4+
## 准备工作
5+
在你的终端中运行以下命令,初始化一个强化学习项目:
6+
```shell
7+
veadk rl init --platform ark --workspace veadk_rl_ark_project
8+
```
9+
该命令会在当前目录下创建一个名为 `veadk_rl_ark_project` 的文件夹,其中包含了一个基本的强化学习项目结构。
10+
然后在终端中运行以下命令,提交任务到方舟平台:
11+
```shell
12+
cd veadk_rl_ark_project
13+
veadk rl submit --platform ark
14+
```
15+
## 原理说明
16+
生成后的项目结构如下,其中核心文件包括:
17+
- 数据集: `data/*.jsonl`
18+
- `/plugins`文件夹下的rollout和reward:
19+
- rollout :用以规定agent的工作流,`raw_async_veadk_rollout.py`提供了使用在方舟rl中使用veadk agent的示例,
20+
- reward:给出强化学习所需的奖励值,在`random_reward.py`给出了示例
21+
- `job.py``job.yaml`:用以配置训练参数,并指定需要使用的rollout和reward
22+
```shell
23+
veadk_rl_ark_project
24+
├── data
25+
├── *.jsonl # 训练数据
26+
└── plugins
27+
├── async_weather_rollout.py #
28+
├── config.yaml.example # VeADK agent 配置信息示例
29+
├── random_reward.py # reward规则设定
30+
├── raw_async_veadk_rollout.py # rollout工作流设定
31+
├── raw_rollout.py #
32+
└── test_utils.py #
33+
└── weather_rollout.py #
34+
├── job.py # 任务提交代码
35+
├── job.yaml # 任务配置
36+
├── test_agent.py # VeFaaS 测试脚本
37+
```
38+
## 运行
39+
```bash
40+
ark create mcj -f job.yaml
41+
```
42+
43+
```bash
44+
python job.py
45+
```
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[plugin.package]
2+
include=["*.py"]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{"messages":[{"role":"user","content":"将一个故事然后告诉我北京的天气怎么样"}],"thinking":{"type": "enabled"}}
2+
{"messages":[{"role":"user","content":"上海的天气怎么样"}],"thinking":{"type": "enabled"}}
3+
{"messages":[{"role":"user","content":"上海和北京的天气怎么样"}],"thinking":{"type": "enabled"}}

veadk/cli/templates/rl/ark/job.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ark_sdk.resources.model_customization_job import ModelCustomizationJob
16+
from ark_sdk.resources.pipeline_plugin import GRPOPipeline, PipelinePluginWrapper
17+
from ark_sdk.types.model_customization_job import (
18+
ModelReference,
19+
FoundationModelReference,
20+
TrainingDataset,
21+
Data,
22+
)
23+
24+
from plugins.random_reward import random_reward_fn
25+
from plugins.raw_async_veadk_rollout import demo_veadk_rollout
26+
27+
if __name__ == "__main__":
28+
mcj = ModelCustomizationJob(
29+
name="sdk-job",
30+
model_reference=ModelReference(
31+
foundation_model=FoundationModelReference(
32+
name="doubao-seed-1-6-flash", model_version="250615"
33+
)
34+
),
35+
hyperparameters={
36+
"batch_size": "32",
37+
"clip_ratio_high": "0.2",
38+
"clip_ratio_low": "0.2",
39+
"kl_coefficient": "0.001",
40+
"loss_agg_mode": "seq-mean-token-mean",
41+
"lr": "0.000001",
42+
"lr_warmup_steps": "5",
43+
"max_new_tokens": "1024",
44+
"num_generations": "8",
45+
"num_iterations_per_batch": "2",
46+
"save_every_n_steps": "10",
47+
"temperature": "1.0",
48+
"test_every_n_steps": "5",
49+
"test_num_generations": "1",
50+
"test_top_p": "1",
51+
"top_p": "1",
52+
"num_steps": "10",
53+
},
54+
data=Data(
55+
training_set=TrainingDataset(
56+
local_files=[
57+
"./data/mcj_rollout_test_dataset.jsonl",
58+
]
59+
)
60+
),
61+
custom_rl_pipeline=GRPOPipeline(
62+
graders=[
63+
PipelinePluginWrapper(
64+
plugin=random_reward_fn, envs={"foo": "bar"}, weight=0.5
65+
),
66+
],
67+
rollout=PipelinePluginWrapper(
68+
plugin=demo_veadk_rollout, envs={"foo": "bar"}
69+
),
70+
),
71+
)
72+
73+
mcj.submit()
74+
print(f"Job submitted. view job at {mcj.url}")
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
name: sdk-job
2+
customization_type: GRPO
3+
model_reference:
4+
foundation_model:
5+
name: doubao-seed-1-6-flash
6+
model_version: '250615'
7+
hyperparameters:
8+
batch_size: '128'
9+
clip_ratio_high: '0.2'
10+
clip_ratio_low: '0.2'
11+
kl_coefficient: '0.001'
12+
loss_agg_mode: seq-mean-token-mean
13+
lr: '0.000001'
14+
lr_warmup_steps: '5'
15+
max_new_tokens: '1024'
16+
num_generations: '8'
17+
num_iterations_per_batch: '2'
18+
save_every_n_steps: '10'
19+
temperature: '1.0'
20+
test_every_n_steps: '5'
21+
test_num_generations: '1'
22+
test_top_p: '1'
23+
top_p: '1'
24+
num_steps: '20'
25+
custom_rl_pipeline:
26+
graders:
27+
- plugin:
28+
name: random_reward
29+
python_func: plugins.random_reward:random_reward_fn
30+
envs:
31+
foo: bar
32+
weight: 0.5
33+
rollout:
34+
plugin:
35+
name: demo_veadk_rollout
36+
python_func: plugins.async_weather_rollout:demo_rollout
37+
envs:
38+
foo: bar
39+
40+
data:
41+
training_set:
42+
local_files:
43+
- ./data/mcj_rollout_test_dataset.jsonl
44+
save_model_limit: 1
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import random
16+
from typing import List
17+
from ark_sdk.resources.pipeline_plugin import group_grader
18+
from ark_sdk.types.pipeline_plugin.pipeline_plugin import PluginStatus, PluginContext
19+
from ark_sdk.types.pipeline_plugin.rollout import Trajectory, ChatCompletionSample
20+
from ark_sdk.types.pipeline_plugin import (
21+
Runtime,
22+
PluginInstance,
23+
GroupGraderResult,
24+
)
25+
26+
27+
@group_grader(
28+
name="randaom_reward",
29+
runtime=Runtime(
30+
instance=PluginInstance.CPU1MEM2,
31+
max_concurrency=100,
32+
timeout=300,
33+
),
34+
)
35+
def random_reward_fn(
36+
context: PluginContext,
37+
sample: ChatCompletionSample,
38+
trajectories: List[Trajectory],
39+
) -> GroupGraderResult:
40+
"""
41+
奖励函数:返回随机奖励
42+
43+
参数:
44+
- trajectories: 完整的对话历史
45+
- sample: 样本数据,包含标准答案的字典
46+
47+
返回:
48+
- list[float]: 奖励分数列表,每个分数对应一个候选回复(1.0表示完全匹配,0.0表示不匹配)
49+
50+
依赖:
51+
- 数据集里的字典字段 extra 内需要携带 answer 字段。
52+
"""
53+
rewards = [
54+
t.extra["reward"] if (t.extra and "reward" in t.extra) else random.random()
55+
for t in trajectories
56+
]
57+
return GroupGraderResult(
58+
rewards=rewards, status=PluginStatus.SUCCESS, error="", metrics={}
59+
)

0 commit comments

Comments
 (0)