2 changes: 1 addition & 1 deletion evals/redis_stat.py
@@ -129,4 +129,4 @@ def analyze_episodes_with_positions(tag):
}

# Run the analysis
results = analyze_episodes_with_positions("Qwen2.5-7B-Instruct_vs_sft_qwen25_7b_bigtom_step_1500-bigtom_0402")
results = analyze_episodes_with_positions("grpo_direct_step_400_vs_sft_qwen25_7b_sft_round_1_bc_data_top_2_step_1500-0420")
7 changes: 1 addition & 6 deletions scripts/accelerate_config_ppo.yaml
@@ -2,18 +2,13 @@ compute_environment: LOCAL_MACHINE
debug: true
distributed_type: MULTI_GPU
downcast_bf16: 'no'
dynamo_config:
dynamo_backend: EAGER
dynamo_mode: default
dynamo_use_dynamic: false
dynamo_use_fullgraph: false
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 6
num_processes: 8
rdzv_backend: static # Keep this unless running multi-node
same_network: true
tpu_env: []
2 changes: 1 addition & 1 deletion scripts/accelerate_config_rm.yaml
@@ -8,7 +8,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 5
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
4 changes: 2 additions & 2 deletions scripts/inference_rm.sh
@@ -8,8 +8,8 @@ CUDA_VISIBLE_DEVICES=9 python inference_rm.py \
--example_path "/data/haofeiy2/sotopia-rl/data/sotopia_pi_gpt4_rm_overfit.json"


CUDA_VISIBLE_DEVICES=8 python inference_rm.py \
CUDA_VISIBLE_DEVICES=5 python inference_rm.py \
--model_path "/mnt/data_from_server1/models/Qwen2.5-7B-Instruct" \
--adapter_path "/data/haofeiy2/sotopia-rl/rm_overfit_test/checkpoint-100" \
--adapter_path "/data/haofeiy2/sotopia-rl/rm_token_length/checkpoint-800" \
--template_path "/data/haofeiy2/sotopia-rl/evals/qwen2.5-7b.jinja" \
--example_path "/data/haofeiy2/sotopia-rl/data/sotopia_pi_gpt4_rm_overfit.json"
5 changes: 2 additions & 3 deletions scripts/train_ppo.py
@@ -20,7 +20,7 @@
help="Number of PPO epochs per update")
parser.add_argument("--learning_rate", type=float, default=5e-6,
help="Learning rate for optimizer")
parser.add_argument("--gamma", type=float, default=0.99,
parser.add_argument("--gamma", type=float, default=1.0,
help="Discount factor")
parser.add_argument("--lam", type=float, default=0.95,
help="GAE lambda for advantage estimation")
@@ -69,6 +69,5 @@
help="Use LoRA for training PPO")

args = parser.parse_args()
accelerator = Accelerator()
trainer = SotopiaPPOTrainer(args, accelerator)
trainer = SotopiaPPOTrainer(args)
trainer.train()
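
Note on the --gamma change above: together with --lam, this flag feeds generalized advantage estimation (GAE) in the PPO trainer, and raising the default to 1.0 makes the return undiscounted so that lam alone controls the bias-variance trade-off. A minimal sketch of the GAE recurrence for reference (illustrative only, not code from this diff; the function name and array inputs are hypothetical):

import numpy as np

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: 1-D numpy arrays of per-step rewards and value estimates
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); bootstrap the final state with 0
    next_values = np.append(values[1:], 0.0)
    deltas = rewards + gamma * next_values - values
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages  # returns-to-go for the value loss are advantages + values
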
37 changes: 19 additions & 18 deletions scripts/train_ppo.sh
@@ -1,21 +1,22 @@
CUDA_VISIBLE_DEVICES=1,7,8 accelerate launch \
--config_file /data/haofeiy2/sotopia-rl/scripts/accelerate_config.yaml \
--main_process_port 29511 \
/data/haofeiy2/sotopia-rl/scripts/train_ppo.py \
--model_name /mnt/data_from_server1/models/Qwen2.5-7B-Instruct \
--policy_adapter_path /data/haofeiy2/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--ref_adapter_path /data/haofeiy2/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--reward_adapter_path /data/haofeiy2/sotopia-rl/rm_reward_direct_default_without_that_n_error_as_the_end/checkpoint-4480 \
--value_adapter_path /data/haofeiy2/sotopia-rl/rm_reward_direct_default_without_that_n_error_as_the_end/checkpoint-4480 \
--learning_rate 1e-5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--config_file /data/disk0/sotopia-rl/scripts/accelerate_config_ppo.yaml \
--main_process_port 29529 \
/data/disk0/sotopia-rl/scripts/train_ppo.py \
--model_name /data/disk0/models/Qwen2.5-7B-Instruct \
--policy_adapter_path /data/disk0/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--ref_adapter_path /data/disk0/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--reward_adapter_path /data/disk0/sotopia-rl/rm_token_length_normalized/checkpoint-350 \
--value_adapter_path /data/disk0/sotopia-rl/ppo_token_length_normalized_checkpoint_510_value_adapter \
--learning_rate 5e-5 \
--per_device_train_batch_size 6 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--num_mini_batches 1 \
--ppo_data_path /data/haofeiy2/sotopia-rl/data/sotopia_pi_round1_qwen_sft_all_with_instruct_string.json \
--template_path /data/haofeiy2/sotopia-rl/evals/qwen2.5-7b.jinja \
--num_ppo_epochs 2 \
--ppo_data_path /data/disk0/sotopia-rl/data/sotopia_pi_round1_qwen_sft_all_with_instruct_string.json \
--template_path /data/disk0/sotopia-rl/evals/qwen2.5-7b.jinja \
--num_train_epochs 5 \
--gamma 0.99 \
--lam 0.95 \
--output_dir /data/haofeiy2/sotopia-rl/ppo_origin_qwen25_7b_reward_direct_default_no_goal_gpt-4o_without_goal_leak_with_sft_self_play_data_use_sotopia_pi_full_data_0408
--max_length 4096 \
--num_ppo_epochs 2 \
--gamma 1.00 \
--use_lora_train_ppo \
--output_dir /data/disk0/sotopia-rl/ppo_token_length_normalized
18 changes: 9 additions & 9 deletions scripts/train_rm.sh
@@ -1,18 +1,18 @@
CUDA_VISIBLE_DEVICES=5,6,7,8,9 accelerate launch \
--config_file /data/haofeiy2/sotopia-rl/scripts/accelerate_config_rm.yaml \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
--config_file /data/disk0/sotopia-rl/scripts/accelerate_config_rm.yaml \
--main_process_port 29500 \
/data/haofeiy2/sotopia-rl/scripts/train_rm.py \
--model_name /mnt/data_from_server1/models/Qwen2.5-7B-Instruct \
/data/disk0/sotopia-rl/scripts/train_rm.py \
--model_name /data/disk0/models/Qwen2.5-7B-Instruct \
--learning_rate 1e-5 \
--max_length 4096 \
--train_batch_size 1 \
--train_batch_size 4 \
--val_batch_size 1 \
--accumulation_steps 8 \
--accumulation_steps 2 \
--num_epochs 30 \
--evaluation_steps 50 \
--reward_data_path /data/haofeiy2/sotopia-rl/data/sotopia_pi_bc_episodes_reward_token_length.json \
--template_path /data/haofeiy2/sotopia-rl/evals/qwen2.5-7b.jinja \
--checkpoint_dir /data/haofeiy2/sotopia-rl/rm_token_length
--reward_data_path /data/disk0/sotopia-rl/data/sotopia_pi_bc_episodes_reward_token_length_binary.json \
--template_path /data/disk0/sotopia-rl/evals/qwen2.5-7b.jinja \
--checkpoint_dir /data/disk0/sotopia-rl/rm_token_length_binary

CUDA_VISIBLE_DEVICES=5,6,7,8,9 accelerate launch \
--config_file /data/haofeiy2/sotopia-rl/scripts/accelerate_config_rm.yaml \
164 changes: 164 additions & 0 deletions sotopia_rl/grpo_trainer_math.py
@@ -0,0 +1,164 @@
import os
import torch
import wandb
from datasets import load_dataset
from torch.utils.data import random_split
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)
from accelerate import PartialState
from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
from jinja2 import Environment, FileSystemLoader
from trl import get_kbit_device_map, GRPOConfig, GRPOTrainer
from accelerate import Accelerator
from sotopia_rl.data import GRPODataset
from functools import partial
from typing import List

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'

SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"


from transformers import GPTNeoXForCausalLM

class PatchedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
def forward(self, *args, logits_to_keep=None, **kwargs):
return super().forward(*args, **kwargs)

class SotopiaGRPOTrainer:
def __init__(self, args, accelerator: Accelerator):
self.args = args
self.accelerator = accelerator

self._init_wandb()
self._setup_tokenizer()
self._setup_dataset()
self._create_quantization_config()

self._setup_grpo_trainer()

def save_model(self, output_dir: str, _internal_call: bool = False):
self.model.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)
print(f"Saved PEFT model to {output_dir}")

self.grpo_trainer.save_model = save_model.__get__(self.grpo_trainer, type(self.grpo_trainer))

def _init_wandb(self):
wandb.init(
project=self.args.wandb_project,
name=self.args.wandb_run_name,
config={k: v for k, v in vars(self.args).items() if isinstance(v, (int, float, str))}
)

def _setup_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained("/data/disk0/models/EleutherAI_pythia-1b-deduped__sft__tldr")
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids('[PAD]')
if self.tokenizer.chat_template is None:
self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE


def _setup_dataset(self):
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
print("processing")
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

def prepare_dataset(dataset, tokenizer):
def tokenize(element):
input_ids = tokenizer.apply_chat_template(
element["messages"][:1],
padding=False,
add_generation_prompt=True,
)
return {"input_ids": input_ids, "lengths": len(input_ids), "prompt": element["messages"][:1]}

return dataset.map(
tokenize,
remove_columns=dataset.column_names,
num_proc=4,
)

with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, self.tokenizer)
if eval_dataset is not None:
eval_dataset = prepare_dataset(eval_dataset, self.tokenizer)
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=4)
if eval_dataset is not None:
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=4)

assert train_dataset[0]["input_ids"][-1] != self.tokenizer.eos_token_id, "The last token should not be an EOS token"

self.train_dataset = train_dataset
self.val_dataset = eval_dataset
print(f"Dataset loaded and processed: {len(self.train_dataset)} train, {len(self.val_dataset or [])} validation")

def _create_quantization_config(self):
self.quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)

def _setup_grpo_trainer(self):
num_processes = self.accelerator.num_processes
global_batch_size = self.args.per_device_train_batch_size * num_processes

num_generations = 4 # manually chosen value
print(f"Using num_generations = {num_generations} (global_batch_size = {global_batch_size})")

policy_model = AutoModelForCausalLM.from_pretrained(
"/data/disk0/models/EleutherAI_pythia-1b-deduped__sft__tldr",
torch_dtype='auto',
num_labels=1,
)

reward_model = AutoModelForSequenceClassification.from_pretrained(
"/data/disk0/models/EleutherAI_pythia-1b-deduped__reward__tldr",
torch_dtype='auto',
num_labels=1,
)

training_args = GRPOConfig(
logging_steps = 1,
report_to = "wandb",
per_device_train_batch_size=self.args.per_device_train_batch_size,
per_device_eval_batch_size=self.args.per_device_eval_batch_size,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
num_train_epochs=self.args.num_train_epochs,
learning_rate=self.args.learning_rate,
output_dir=self.args.output_dir,
save_steps=self.args.save_steps,
num_generations=num_generations
)

self.grpo_trainer = GRPOTrainer(
args=training_args,
model=policy_model,
reward_funcs=reward_model,
processing_class=self.tokenizer,
reward_processing_classes=self.tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.val_dataset,
)
print("GRPOtrainer setup complete")

def train(self):
try:
print("Starting GRPO training...")
train_stats = self.grpo_trainer.train()
return train_stats
except Exception as e:
print(f"Training error: {str(e)}")
raise
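
The new grpo_trainer_math.py only defines the trainer class; no launcher script for it is included in this diff. A minimal driver along the following lines would exercise it (illustrative sketch under that assumption; the flag names simply mirror the attributes the class reads from args, and the default values are placeholders):

import argparse
from accelerate import Accelerator
from sotopia_rl.grpo_trainer_math import SotopiaGRPOTrainer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_project", type=str, default="sotopia-rl")
    parser.add_argument("--wandb_run_name", type=str, default="grpo_math_debug")
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="./grpo_math_out")
    args = parser.parse_args()

    # The class is constructed as __init__(self, args, accelerator: Accelerator)
    trainer = SotopiaGRPOTrainer(args, Accelerator())
    trainer.train()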