Skip to content

Commit a02a9d4

Browse files
committed
Fix returns_to_go shape in CalQL
1 parent 4e7bba7 commit a02a9d4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

d3rlpy/algos/qlearning/torch/cal_ql_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ def _compute_policy_is_values(
1818
value_obs=value_obs,
1919
returns_to_go=returns_to_go,
2020
)
21-
return torch.maximum(values, returns_to_go), log_probs
21+
return torch.maximum(values, returns_to_go.view(1, -1, 1)), log_probs

0 commit comments

Comments
 (0)