File tree Expand file tree Collapse file tree 4 files changed +5
-5
lines changed
torchtitan/experiments/rl Expand file tree Collapse file tree 4 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -6,7 +6,7 @@ This directory contains code for RL training using TorchTitan model definitions
66The integration consists of the following components:
77
881 . ** vLLM Model Wrapper** (` models/vllm_wrapper.py ` ): Adapts TorchTitan models for vLLM's inference engine
9- 2 . ** RL Training Loop** (` simple_grpo_sum_digits .py` ): GRPO-based RL training with Monarch actors
9+ 2 . ** RL Training Loop** (` tasks/sum_digits/simple_grpo .py` ): GRPO-based RL training with Monarch actors
10103 . ** Inference Script** (` inference_example.py ` ): Standalone inference using the vLLM engine
1111
1212
@@ -57,7 +57,7 @@ torchrun --nproc_per_node=2 torchtitan/experiments/rl/inference_example.py
5757
58586 . Run simple GRPO RL loop to learn sum digits task
5959``` bash
60- python torchtitan/experiments/rl/simple_grpo_sum_digits .py --module rl --config rl_grpo_qwen3_0_6b
60+ python torchtitan/experiments/rl/tasks/sum_digits/simple_grpo .py --module rl --config rl_grpo_qwen3_0_6b
6161```
6262
6363** NOTE:** If you downloaded your HF model to a different path than the one in step 4, specify it in your command with ` --hf_assets_path=<path_to_model_checkpoint> ` .
Original file line number Diff line number Diff line change 2020 VLLMGenerator ,
2121)
2222from torchtitan .experiments .rl .actors .trainer import PolicyTrainer
23- from torchtitan .experiments .rl .simple_grpo_sum_digits import RLTrainer
23+ from torchtitan .experiments .rl .tasks . sum_digits . simple_grpo import RLTrainer
2424from torchtitan .models .qwen3 import model_registry
2525
2626
Original file line number Diff line number Diff line change 1717The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training.
1818
1919Command to run:
20- python3 torchtitan/ experiments/rl/simple_grpo_sum_digits .py \
20+ python3 torchtitan. experiments.rl.tasks.sum_digits/simple_grpo .py \
2121 --module rl --config rl_grpo_qwen3_0_6b \
2222 --hf_assets_path=<path_to_model_checkpoint>
2323"""
4040from torchtitan .experiments .rl .actors .generator import VLLMGenerator
4141from torchtitan .experiments .rl .actors .grader import Grader
4242from torchtitan .experiments .rl .actors .trainer import PolicyTrainer
43- from torchtitan .experiments .rl .sum_digits import extract_answer , SumDigitsTask
43+ from torchtitan .experiments .rl .tasks . sum_digits . task import extract_answer , SumDigitsTask
4444from torchtitan .experiments .rl .types import Episode
4545from torchtitan .protocols .model_spec import ModelSpec
4646
File renamed without changes.
You can’t perform that action at this time.
0 commit comments