86 changes: 60 additions & 26 deletions scripts/inference_rm.py
@@ -12,36 +12,55 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Test a reward model with a template and example data"
     )
-    parser.add_argument("--model_path", type=str, required=True, help="Path to base model or HF model name")
-    parser.add_argument("--adapter_path", type=str, required=True, help="Path to saved checkpoint directory")
-    parser.add_argument("--template_path", type=str, required=True, help="Path to Jinja template file")
-    parser.add_argument("--example_path", type=str, required=True, help="Path to example data JSON")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path to base model or HF model name",
+    )
+    parser.add_argument(
+        "--adapter_path",
+        type=str,
+        required=True,
+        help="Path to saved checkpoint directory",
+    )
+    parser.add_argument(
+        "--template_path", type=str, required=True, help="Path to Jinja template file"
+    )
+    parser.add_argument(
+        "--example_path", type=str, required=True, help="Path to example data JSON"
+    )
     return parser.parse_args()


 def load_model_and_tokenizer(args):
     print(f"Loading base model: {args.model_path}")
     tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     print("Using full precision model")
-    base_model = AutoModelForSequenceClassification.from_pretrained(
-        args.model_path,
-        torch_dtype=torch.float32, #important
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.adapter_path,
+        torch_dtype=torch.float32,  # important
         device_map="auto",
         num_labels=1,  # For regression task
-        pad_token_id=tokenizer.pad_token_id # very important to add this
+        pad_token_id=tokenizer.pad_token_id,  # very important to add this
     )

-    adapter_path = os.path.join(args.adapter_path, 'adapter_model')
-    if os.path.exists(adapter_path + '.safetensors') or os.path.exists(adapter_path + '.bin'):
-        print(f"Loading adapter from: {args.adapter_path}")
-        model = PeftModelForSequenceClassification.from_pretrained(base_model, args.adapter_path)
-    else:
-        print(f"No adapter found at {adapter_path}, using base model")
-        model = base_model
+    def print_named_parameters(model, keyword="score"):
+        for name, param in model.named_parameters():
+            if keyword in name:
+                print(
+                    f"{name}: mean={param.data.mean():.4f}, std={param.data.std():.4f}"
+                )
-    else:
-        print("did not load score_weights")

+    print_named_parameters(model, keyword="score")

     model.eval()

     return model, tokenizer
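Note that the rewritten loader points AutoModelForSequenceClassification.from_pretrained at args.adapter_path: it now expects the checkpoint directory to contain a merged full model (the layout produced by the custom save_model in rm_trainer.py below), not a base model plus a separate LoRA adapter. A minimal sketch of the two layouts this PR moves between, with placeholder paths:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModelForSequenceClassification

base_name = "base-model"  # placeholder
ckpt_dir = "ckpts/rm"     # placeholder

tokenizer = AutoTokenizer.from_pretrained(base_name)

# New layout (this PR): the checkpoint dir holds a merged full model.
# pad_token_id must be set, or batched scoring with padded inputs fails
# at the classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    ckpt_dir,
    num_labels=1,
    torch_dtype=torch.float32,
    pad_token_id=tokenizer.pad_token_id,
)

# Old layout: base model plus LoRA adapter loaded on top.
base = AutoModelForSequenceClassification.from_pretrained(
    base_name, num_labels=1, pad_token_id=tokenizer.pad_token_id
)
model = PeftModelForSequenceClassification.from_pretrained(base, ckpt_dir)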


 def load_template(template_path):
     template_dir = os.path.dirname(template_path)
     template_file = os.path.basename(template_path)
@@ -50,29 +69,42 @@ def load_template(template_path):
         template_dir = "."

     env = Environment(loader=FileSystemLoader(template_dir))
-    env.filters['tojson'] = lambda obj: json.dumps(obj)
+    env.filters["tojson"] = lambda obj: json.dumps(obj)
     return env.get_template(template_file)

-def evaluate_prompt(model, tokenizer, prompt):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

+def evaluate_prompt(model, tokenizer, prompt, index=None):
+    print(f"\n[DEBUG] Prompt [{index}]:")
+    print(prompt)
+    print("[DEBUG] Decoded Input IDs:")
+    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
+    print(tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=False))
+
+    # Check input length
+    print(f"[DEBUG] Input length: {encoded['input_ids'].shape[-1]} tokens")
+    # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

     device = next(model.parameters()).device

-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    inputs = {k: v.to(device) for k, v in encoded.items()}

     with torch.no_grad():
         outputs = model(**inputs)

     # Get reward score directly from the logits
-    reward = outputs.logits.squeeze().cpu().item()
+    # reward = outputs.logits.squeeze().cpu().item()
+    logits = outputs.logits.squeeze()
+    print(f"[DEBUG] Raw logits: {logits}")
+    reward = logits.cpu().item()
     return reward


 def main():
     args = parse_args()

     model, tokenizer = load_model_and_tokenizer(args)

-    with open(args.example_path, 'r') as f:
+    with open(args.example_path, "r") as f:
         example_data = json.load(f)

     template = load_template(args.template_path)
@@ -81,20 +113,22 @@ def main():

         rendered_prompt = template.render(
             messages=[
-                {"role": "user", "content": example['input']},
-                {"role": "assistant", "content": example['output']},
+                {"role": "user", "content": example["input"]},
+                {"role": "assistant", "content": example["output"]},
             ],
-            add_generation_prompt=False
+            add_generation_prompt=False,
         )

-        reward = evaluate_prompt(model, tokenizer, rendered_prompt)
-        gth_reward = example.get('value')
+        # reward = evaluate_prompt(model, tokenizer, rendered_prompt)
+        reward = evaluate_prompt(model, tokenizer, rendered_prompt, index=i + 1)
+        gth_reward = example.get("value")

         print(f"REWARD SCORE: {reward:.6f}")
         if gth_reward is not None:
             print(f"GTH REWARD: {gth_reward:.6f}")
         else:
             print("GTH REWARD: Not available")


 if __name__ == "__main__":
     main()
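For reference, main() iterates over --example_path as a JSON list of records, each carrying "input" and "output" strings plus an optional numeric "value" used as the ground-truth reward. A sketch of a minimal example file; the field contents below are invented placeholders:

import json

examples = [
    {
        "input": "How do I sort a list in Python?",
        "output": "Use sorted(xs) for a new list or xs.sort() in place.",
        "value": 0.87,  # optional ground-truth reward
    },
    # Without "value", the script prints "GTH REWARD: Not available".
    {"input": "Summarize this paragraph.", "output": "..."},
]

with open("examples.json", "w") as f:
    json.dump(examples, f, indent=2)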
133 changes: 97 additions & 36 deletions sotopia_rl/rm_trainer.py
@@ -3,7 +3,7 @@
 import torch
 import torch.distributed as dist
 from jinja2 import Environment, FileSystemLoader
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, get_peft_model, PeftModelForSequenceClassification
 from torch.nn import MSELoss
 from torch.utils.data import random_split
 from transformers import (
@@ -13,13 +15,
     TrainingArguments,
 )
 from accelerate import Accelerator
+from typing import Optional

 import wandb
 from sotopia_rl.data import RMDataset
+import torch._dynamo

+torch._dynamo.config.suppress_errors = True

-os.environ['NCCL_P2P_DISABLE'] = '1'
+os.environ["NCCL_P2P_DISABLE"] = "1"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


@@ -33,30 +35,44 @@ def __init__(self, args, accelerator, **kwargs):
         train_dataset, eval_dataset = self.setup_dataset(tokenizer)

         # Initialize wandb only on the main process
-        if self.accelerator.is_main_process:
-            wandb.init(
-                project=args.wandb_project,
-                name=args.wandb_run_name,
-                config={k: v for k, v in vars(args).items() if isinstance(v, (int, float, str))}
-            )
+        # if self.accelerator.is_main_process:
+        wandb.init(
+            project=args.wandb_project,
+            name=args.wandb_run_name,
+            config={
+                k: v for k, v in vars(args).items() if isinstance(v, (int, float, str))
+            },
+        )

         peft_config = LoraConfig(
             r=args.lora_r,
             lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
-            target_modules=args.target_modules.split(",")
+            target_modules=args.target_modules.split(","),
         )

         base_model = AutoModelForSequenceClassification.from_pretrained(
             args.model_name,
             num_labels=1,
-            torch_dtype='auto',
+            torch_dtype="auto",
         )

-        model = get_peft_model(base_model, peft_config)
-        model.config.pad_token_id = tokenizer.pad_token_id # important to set the config pad_token_id
-
-        model = self.accelerator.prepare_model(model)
+        # self.model = get_peft_model(base_model, peft_config)
+        self.model = PeftModelForSequenceClassification(base_model, peft_config)
+        param_check = self.model.base_model.model.score.weight
+        print(
+            f"whether score.weight is trainable: {param_check.requires_grad}, shape: {param_check.shape}"
+        )
+        print(f"mean={param_check.data.mean():.4f}, std={param_check.data.std():.4f}")
+        count = 0
+        for name, param in self.model.named_parameters():
+            if param.requires_grad:
+                count += 1
+                print(f"{name} shape={param.shape}")
+        print(f"Total trainable parameters: {count}")
+        self.model.config.pad_token_id = (
+            tokenizer.pad_token_id
+        )  # important to set the config pad_token_id
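The swap from get_peft_model to constructing PeftModelForSequenceClassification directly is what the debug prints above are probing: whether the score head comes out trainable. The more conventional route, sketched below assuming base_model and args as in this constructor, is to set task_type so that peft keeps the classification head among the modules it saves and trains:

from peft import LoraConfig, TaskType, get_peft_model

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # keeps the score head trainable via modules_to_save
    r=args.lora_r,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    target_modules=args.target_modules.split(","),
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # summary count of trainable vs. total params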

         # Set up the TrainingArguments with DeepSpeed support
         training_args = TrainingArguments(
@@ -66,6 +82,7 @@ def __init__(self, args, accelerator, **kwargs):
             num_train_epochs=args.num_epochs,
             logging_steps=1,
             save_steps=args.evaluation_steps,
+            save_strategy="steps",
             eval_steps=args.evaluation_steps,
             logging_dir="./logs",
             gradient_accumulation_steps=args.accumulation_steps,
@@ -78,26 +95,49 @@ def __init__(self, args, accelerator, **kwargs):
             ddp_find_unused_parameters=False,
         )

-        collate_fn = train_dataset.dataset.collate_fn if hasattr(train_dataset, 'dataset') else None
+        print("Trainable parameters:")
+
+        collate_fn = (
+            train_dataset.dataset.collate_fn
+            if hasattr(train_dataset, "dataset")
+            else None
+        )

         super().__init__(
-            model=model,
+            model=self.model,
             args=training_args,
             processing_class=tokenizer,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             data_collator=collate_fn,
-            **kwargs
+            **kwargs,
         )
         self.loss_fn = MSELoss()

-        if args.checkpoint_path:
-            self.load_lora_checkpoint(args.checkpoint_path)
+        self.model.base_model.model.score.weight.register_hook(
+            lambda grad: print(
+                f"[HOOK] Gradient norm for score.weight: {grad.norm().item()}"
+            )
+        )
+        self.model.base_model.model.model.layers[
+            3
+        ].self_attn.v_proj.lora_B.default.weight.register_hook(
+            lambda grad: print(
+                f"[HOOK] base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: {grad.norm().item()}"
+            )
+        )
+        for name, param in self.model.named_parameters():
+            if param.requires_grad:
+                print(f"{name} shape={param.shape}")

     def setup_dataset(self, tokenizer):
-        env = Environment(loader=FileSystemLoader("/".join(self.args.template_path.split("/")[:-1])))
+        env = Environment(
+            loader=FileSystemLoader("/".join(self.args.template_path.split("/")[:-1]))
+        )
         template = env.get_template(self.args.template_path.split("/")[-1])
-        dataset = RMDataset(self.args.reward_data_path, tokenizer, template, self.args.max_length)
+        dataset = RMDataset(
+            self.args.reward_data_path, tokenizer, template, self.args.max_length
+        )

         if self.accelerator.is_main_process:
             print(f"dataset: {len(dataset)}")
@@ -107,7 +147,9 @@ def setup_dataset(self, tokenizer):

         # Use deterministic splitter with seed to ensure same split across processes
         generator = torch.Generator().manual_seed(42)
-        train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
+        train_dataset, val_dataset = random_split(
+            dataset, [train_size, val_size], generator=generator
+        )
         return train_dataset, val_dataset

     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
@@ -117,21 +159,40 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):

         outputs = model(input_ids, attention_mask=attention_masks)
         predicted_rewards = outputs.logits.squeeze(-1)  # Shape: (batch_size,)
+        print(">>> Predicted:", predicted_rewards.detach().cpu().numpy())
+        print(">>> GroundTruth:", true_rewards.detach().cpu().numpy())
         loss = self.loss_fn(predicted_rewards, true_rewards)

         return (loss, outputs) if return_outputs else loss
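The squeeze(-1) before the loss is load-bearing: logits arrive as (batch, 1) from the single-label head while true_rewards is (batch,), and MSELoss would silently broadcast the mismatched pair. A toy check:

import torch
from torch.nn import MSELoss

logits = torch.randn(4, 1)   # head output: (batch, num_labels=1)
targets = torch.randn(4)     # true rewards: (batch,)

loss_fn = MSELoss()
loss = loss_fn(logits.squeeze(-1), targets)  # correct: 4 paired errors
# loss_fn(logits, targets) warns and broadcasts to a (4, 4) comparison,
# averaging 16 pairwise errors instead of 4.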

-    def save_lora_checkpoint(self, output_dir=None, **kwargs):
-        if self.accelerator.is_main_process:
-            self.model.save_pretrained(output_dir)
-
-    def load_lora_checkpoint(self, checkpoint_path):
-        adapter_model_path = os.path.join(checkpoint_path, 'adapter_model.safetensors')
-        peft_config = LoraConfig.from_pretrained(checkpoint_path)
-
-        if os.path.exists(adapter_model_path):
-            self.model.load_adapter(checkpoint_path, adapter_name='lora', peft_config=peft_config)
-        else:
-            if self.accelerator.is_main_process:
-                print(f"No adapter model found at {adapter_model_path}.")
-
+    def save_model(
+        self, output_dir: Optional[str] = None, _internal_call: bool = False
+    ):
+        """
+        Override the default save_model to save merged full model (LoRA merged + score.weight).
+        """
+        print("[Custom save_model] Called.")
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if not self.accelerator.is_main_process:
+            return
+        print("[Custom save_model] Saving model...")
+
+        os.makedirs(output_dir, exist_ok=True)
+
+        try:
+            from copy import deepcopy
+
+            param_check = self.model.base_model.model.score.weight
+            print(
+                f"mean={param_check.data.mean():.4f}, std={param_check.data.std():.4f}"
+            )
+            model_copy = deepcopy(self.model)
+            merged_model = model_copy.merge_and_unload()
+            merged_model.save_pretrained(output_dir)
+            if hasattr(self, "tokenizer"):
+                self.tokenizer.save_pretrained(output_dir)
+            print(f"[Custom save_model] Full model saved to: {output_dir}")
+        except Exception as e:
+            print(f"[Custom save_model] Failed to merge and save full model: {str(e)}")