4,783 changes: 2,149 additions & 2,634 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -18,13 +18,14 @@ pandas-stubs = "^2.2.0"
msgpack = "^1.0.8"
setuptools = "^74.1.1"
torch = "^2.4.0"
trl = "*"
trl = ">=0.16.0"
peft = ">=0.11.1"
datasets = "2.20.0"
sentencepiece = "*"
bitsandbytes = "*"
wandb = "*"
accelerate = "*"
accelerate = "0.34.2"
deepspeed = "*"

ruff = {version = "*", optional = true}
vllm = {version = "=0.6.2", optional = true}
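The dependency changes above pin trl to a minimum of 0.16.0, pin accelerate to exactly 0.34.2, and add deepspeed. A minimal sketch (not part of this PR) for checking that an environment matches these pins before launching training; the packaging library is assumed to be available:

from importlib.metadata import version
from packaging.version import Version  # assumed available; install "packaging" if missing

assert Version(version("trl")) >= Version("0.16.0"), "trl must be >= 0.16.0"
assert Version(version("accelerate")) == Version("0.34.2"), "accelerate is pinned to 0.34.2"
version("deepspeed")  # raises PackageNotFoundError if deepspeed is not installed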
23 changes: 15 additions & 8 deletions scripts/accelerate_config.yaml
@@ -1,21 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: true
distributed_type: MULTI_GPU
downcast_bf16: 'no'
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
communication_data_type: "bfloat16"
dynamo_config:
dynamo_backend: EAGER
dynamo_mode: default
dynamo_use_dynamic: false
dynamo_use_fullgraph: false
enable_cpu_affinity: true
gpu_ids: all
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 3
rdzv_backend: static # Keep this unless running multi-node
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
tpu_use_sudo: false
use_cpu: false
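The config above switches the launcher from plain multi-GPU data parallelism to DeepSpeed ZeRO stage 3 with optimizer and parameter offload to CPU, bf16 mixed precision, bf16 communication, and four processes. A rough programmatic equivalent of the DeepSpeed section, shown only for orientation (the repo drives this through accelerate launch --config_file, not through code like this):

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,                    # shard parameters, gradients, and optimizer state
    offload_optimizer_device="cpu",  # keep optimizer state in CPU RAM
    offload_param_device="cpu",      # keep parameters in CPU RAM
    zero3_init_flag=True,            # initialize large models directly in sharded form
    zero3_save_16bit_model=True,     # consolidate a 16-bit model when saving
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)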
13 changes: 8 additions & 5 deletions scripts/train_ppo.sh
@@ -1,21 +1,24 @@
CUDA_VISIBLE_DEVICES=1,7,8 accelerate launch \
CUDA_VISIBLE_DEVICES=0,1,2,4 accelerate launch \
--config_file /data/haofeiy2/sotopia-rl/scripts/accelerate_config.yaml \
--main_process_port 29511 \
/data/haofeiy2/sotopia-rl/scripts/train_ppo.py \
--model_name /mnt/data_from_server1/models/Qwen2.5-7B-Instruct \
--model_name /data/models/Qwen2.5-7B-Instruct \
--policy_adapter_path /data/haofeiy2/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--ref_adapter_path /data/haofeiy2/sotopia-rl/sft_qwen25_7b_sft_round_1_bc_data_top_2/checkpoint-1500 \
--reward_adapter_path /data/haofeiy2/sotopia-rl/rm_reward_direct_default_without_that_n_error_as_the_end/checkpoint-4480 \
--value_adapter_path /data/haofeiy2/sotopia-rl/rm_reward_direct_default_without_that_n_error_as_the_end/checkpoint-4480 \
--learning_rate 1e-5 \
--learning_rate 5e-6 \
--max_length 4096 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--local_rollout_forward_batch_size 1 \
--num_mini_batches 1 \
--ppo_data_path /data/haofeiy2/sotopia-rl/data/sotopia_pi_round1_qwen_sft_all_with_instruct_string.json \
--ppo_data_path /data/haofeiy2/sotopia-rl/data/sotopia_pi_round1_qwen_sft_pi_with_instruct_string.json \
--template_path /data/haofeiy2/sotopia-rl/evals/qwen2.5-7b.jinja \
--num_ppo_epochs 2 \
--num_train_epochs 5 \
--gamma 0.99 \
--lam 0.95 \
--output_dir /data/haofeiy2/sotopia-rl/ppo_origin_qwen25_7b_reward_direct_default_no_goal_gpt-4o_without_goal_leak_with_sft_self_play_data_use_sotopia_pi_full_data_0408
--output_dir /data/haofeiy2/sotopia-rl/ppo_origin_qwen25_7b_reward_direct_default_no_goal_gpt-4o_without_goal_leak_with_sft_self_play_data_use_sotopia_pi_filtered_data_0411 \
--use_lora_train_ppo
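The --gamma and --lam flags set the discount factor and the GAE smoothing parameter used for advantage estimation during PPO. The actual computation lives inside trl's PPO trainer; the sketch below only illustrates what these two numbers control, using hypothetical per-step rewards and value estimates:

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation; values needs one trailing bootstrap entry."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # one-step TD error
        gae = delta + gamma * lam * gae                         # exponentially weighted sum of TD errors
        advantages[t] = gae
    return advantages

# Toy rollout: reward only at the final step, zero bootstrap value at the end.
print(compute_gae([0.0, 0.0, 1.0], [0.2, 0.4, 0.6, 0.0]))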
2 changes: 1 addition & 1 deletion scripts/train_sft.py
@@ -13,7 +13,7 @@
parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--sft_data_path", type=str, required=True, help="Path to SFT data")
parser.add_argument("--template_path", type=str, required=True, help="Path to the Jinja template file")
parser.add_argument("--max_length", type=int, default=4096, help="Max sequence length")
parser.add_argument("--max_length", type=int, default=3000, help="Max sequence length")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
parser.add_argument("--evaluation_steps", type=int, default=100, help="Evaluation interval in steps")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
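The default --max_length drops from 4096 to 3000 tokens. A minimal sketch of how such a cap is typically applied at tokenization time (illustrative only; the model id is a placeholder and the real truncation logic lives in the repo's SFTDataset):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder model id
encoded = tokenizer(
    "example prompt and response text",
    max_length=3000,   # matches the new default above
    truncation=True,   # drop tokens beyond max_length instead of keeping the full sequence
    return_tensors="pt",
)
print(encoded["input_ids"].shape)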
37 changes: 12 additions & 25 deletions sotopia_rl/ppo_trainer.py
@@ -28,18 +28,8 @@ def __init__(self, args, accelerator: Accelerator):
self._setup_tokenizer()
self._setup_dataset()
self._create_quantization_config()

self._setup_generation_models()
self._setup_classification_models()

self.policy, self.ref_policy, self.reward_model, self.value_model = self.accelerator.prepare(
self.policy, self.ref_policy, self.reward_model, self.value_model
)
self.policy = self.accelerator.unwrap_model(self.policy)
self.ref_policy = self.accelerator.unwrap_model(self.ref_policy)
self.reward_model = self.accelerator.unwrap_model(self.reward_model)
self.value_model = self.accelerator.unwrap_model(self.value_model)

self._setup_ppo_trainer()

def save_model(self, output_dir: str, _internal_call: bool = False):
@@ -90,19 +80,19 @@ def _setup_dataset(self):
print(f"Dataset split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation")

def _create_quantization_config(self):
self.quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
self.model_dtype = torch.bfloat16
self.bit_quant_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)


def _setup_generation_models(self):
base_gen_ref = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.ref_policy = PeftModelForCausalLM.from_pretrained(
base_gen_ref,
@@ -114,9 +104,8 @@ def _setup_generation_models(self):
if self.args.use_lora_train_ppo:
base_gen_policy = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.policy = PeftModelForCausalLM.from_pretrained(
base_gen_policy,
@@ -127,7 +116,7 @@ else:
else:
self.policy = AutoModelForCausalLM.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
)

requires_grad_num = 0
@@ -146,10 +135,9 @@
def _setup_classification_models(self):
base_reward_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
num_labels=1,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.reward_model = PeftModelForSequenceClassification.from_pretrained(
base_reward_model,
@@ -164,10 +152,9 @@
if self.args.use_lora_train_ppo:
base_value_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
num_labels=1,
quantization_config=self.quant_config,
device_map=get_kbit_device_map(),
)
self.value_model = PeftModelForSequenceClassification.from_pretrained(
base_value_model,
@@ -178,7 +165,7 @@ else:
else:
self.value_model = AutoModelForSequenceClassification.from_pretrained(
self.args.model_name,
torch_dtype='auto',
torch_dtype=self.model_dtype,
num_labels=1,
)

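The trainer changes above replace the 4-bit NF4 quantization config with an 8-bit bitsandbytes config, fix the model dtype to bfloat16, and drop the explicit accelerator.prepare/unwrap round-trip for the four models. A minimal sketch of loading a base model under the new 8-bit config, with a placeholder model id (not a value from this PR):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,              # int8 weight quantization via bitsandbytes
    llm_int8_threshold=6.0,         # outlier threshold for mixed int8/fp16 matmuls
    llm_int8_has_fp16_weight=False, # keep int8 weights only, no fp16 copy retained
)
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",     # placeholder model id
    torch_dtype=torch.bfloat16,     # matches self.model_dtype in the trainer
    quantization_config=bnb_config,
)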
1 change: 0 additions & 1 deletion sotopia_rl/sft_trainer.py
@@ -15,7 +15,6 @@
import wandb
from sotopia_rl.data import SFTDataset


os.environ['NCCL_P2P_DISABLE'] = '1'

class SotopiaSFTTrainer: