Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ def update_policy(self, data: DataProto):

on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1

metrics = {}
metrics = {
"actor/pg_loss": 0.0,
"actor/kl_loss": 0.0,
}
for _ in range(self.config.ppo_epochs):
for batch_idx, mini_batch in enumerate(mini_batches):
if self.config.use_dynamic_bsz:
Expand Down Expand Up @@ -530,7 +533,7 @@ def update_policy(self, data: DataProto):
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor
micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

if self.config.use_dynamic_bsz:
Expand All @@ -543,7 +546,7 @@ def update_policy(self, data: DataProto):
else:
loss.backward()

micro_batch_metrics["actor/pg_loss"] = pg_loss.detach().item() * loss_scale_factor
metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor
append_to_dict(metrics, micro_batch_metrics)

grad_norm = self._optimizer_step()
Expand Down
6 changes: 4 additions & 2 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
def update_critic(self, data: DataProto):
# make sure we are in training mode
self.critic_module.train()
metrics = {}
metrics = {
"critic/vf_loss": 0.0,
}

select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"]
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
Expand Down Expand Up @@ -246,12 +248,12 @@ def update_critic(self, data: DataProto):

micro_batch_metrics.update(
{
"critic/vf_loss": vf_loss.detach().item() * loss_scale_factor,
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
)

metrics["critic/vf_loss"] += vf_loss.detach().item() * loss_scale_factor
append_to_dict(metrics, micro_batch_metrics)

grad_norm = self._optimizer_step()
Expand Down
Loading