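```python
import torch

# Sample and DynamicFilterOutput are provided by the training framework.
def check_reward_nonzero_std(args, samples: list[Sample], **kwargs):
    # Keep the group only if its rewards are not all identical (std > 0).
    rewards = [sample.get_reward_value(args) for sample in samples]
    keep = torch.tensor(rewards, dtype=torch.float).std() > 0.0
    return DynamicFilterOutput(
        keep=keep,
        reason=None if keep else f"zero_std_{round(rewards[0], 1)}",
    )
```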
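This check has a floating-point precision issue when rewards are not just 0/1. For a group whose rewards are all identical, the computed std can come out slightly above 0.0, so a group that should be filtered out is kept.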
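For example:
```python
>>> torch.tensor([0.1]*16, dtype=torch.float).std() > 0.0
tensor(True)
>>> torch.tensor([0.25]*16, dtype=torch.float).std() > 0.0
tensor(False)
>>> torch.tensor([0.1]*16, dtype=torch.float64).std() > 0.0
tensor(False)
>>> torch.tensor([0.1]*1024, dtype=torch.float64).std() > 0.0
tensor(True)
```
Every tensor above holds identical values, so the true std is 0 in all four cases. Switching to float64 alone does not fix it: the 1024-element float64 case still reports a nonzero std.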
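Why identical rewards can produce a nonzero std: the mean is a rounded sum divided by n, which need not equal the repeated value exactly, so every deviation `x - mean` becomes a tiny nonzero number. The sketch below reproduces this in pure Python (Python floats are IEEE 754 doubles, like `torch.float64`). `naive_std` is a hypothetical helper doing a plain two-pass computation with the n - 1 denominator that `torch.Tensor.std` uses by default. Torch sums in a different order, so which lengths and values trigger the effect differs from this sketch; the mechanism is the same.

```python
import math

def naive_std(values: list[float]) -> float:
    """Hypothetical helper: two-pass sample std with an n - 1 denominator."""
    n = len(values)
    mean = sum(values) / n  # rounded sum divided by n; may differ from values[0]
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return math.sqrt(var)

rewards = [0.1] * 10  # ten identical rewards, so the true std is exactly 0
mean = sum(rewards) / len(rewards)
print(mean == 0.1)         # False: the computed mean is one ulp below 0.1
print(naive_std(rewards))  # tiny but nonzero, so "std > 0.0" would keep the group
```

Values that are exactly representable, such as 0.25, do not show the effect here, which matches the `[0.25]*16` case above.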
Suggestion: compute the std in higher precision (float64) and compare it against a small epsilon instead of exactly 0.0.
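A minimal sketch of the suggested fix. It assumes a hypothetical standalone helper `check_rewards_nonzero_std` that takes the list of reward floats and returns `(keep, reason)` instead of a `DynamicFilterOutput`, and an assumed default tolerance of `1e-6`. It uses plain Python floats (double precision, same as `torch.float64`) so it runs without torch; inside the original filter the equivalent check would be `torch.tensor(rewards, dtype=torch.float64).std() > eps`.

```python
import math
from typing import Optional

def reward_std(rewards: list[float]) -> float:
    """Sample std (n - 1 denominator, like torch's default) in double precision."""
    n = len(rewards)
    if n < 2:
        return 0.0  # a single reward has no spread
    mean = sum(rewards) / n
    var = sum((r - mean) ** 2 for r in rewards) / (n - 1)
    return math.sqrt(var)

def check_rewards_nonzero_std(
    rewards: list[float], eps: float = 1e-6
) -> tuple[bool, Optional[str]]:
    """Hypothetical helper: keep the group only if the reward std exceeds eps.

    eps is an assumed tolerance: far above float64 rounding noise for
    rewards of order 1, far below any reward spread that matters.
    """
    std = reward_std(rewards)
    keep = std > eps
    reason = None if keep else f"zero_std_{round(rewards[0], 1)}"
    return keep, reason

print(check_rewards_nonzero_std([0.1] * 16))      # (False, 'zero_std_0.1')
print(check_rewards_nonzero_std([0.0, 1.0] * 8))  # (True, None)
```

With a tolerance, any spread below `eps` counts as zero, so `eps` should be smaller than the smallest reward difference the reward function can produce. An alternative that avoids the mean entirely is `max(rewards) - min(rewards) > eps`, which is exactly 0.0 for identical values.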