
Commit 0675e19

FirerozestAnGjIa520 authored and committed
polish(fir): polish softmax (opendilab#394)
1 parent 35972a3 commit 0675e19

File tree

1 file changed, +2 -15 lines changed


lzero/policy/scaling_transform.py

Lines changed: 2 additions & 15 deletions
@@ -30,19 +30,6 @@ def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.)
     return output
 
 
-def ensure_softmax(logits, dim=1):
-    """
-    Overview:
-        Ensure that the input tensor is normalized along the specified dimension.
-    Arguments:
-        - logits (:obj:`torch.Tensor`): The input tensor.
-        - dim (:obj:`int`): The dimension along which to normalize the input tensor.
-    Returns:
-        - output (:obj:`torch.Tensor`): The normalized tensor.
-    """
-    return torch.softmax(logits, dim=dim)
-
-
 def inverse_scalar_transform(
     logits: torch.Tensor,
     scalar_support: DiscreteSupport,
@@ -58,7 +45,7 @@ def inverse_scalar_transform(
         - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
     """
     if categorical_distribution:
-        value_probs = ensure_softmax(logits, dim=1)
+        value_probs = torch.softmax(logits, dim=1)
         value_support = scalar_support.arange
 
         value_support = value_support.to(device=value_probs.device)
@@ -94,7 +81,7 @@ def __init__(
 
     def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
         if self.categorical_distribution:
-            value_probs = ensure_softmax(logits, dim=1)
+            value_probs = torch.softmax(logits, dim=1)
             value = value_probs.mul_(self.value_support).sum(1, keepdim=True)
         else:
             value = logits

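For context, torch.softmax(logits, dim=1) already returns probabilities normalized along dimension 1, so dropping the ensure_softmax wrapper does not change the numerics. A minimal sketch of the retained code path, assuming made-up tensor shapes and a stand-in support tensor in place of scalar_support.arange:

import torch

# Hypothetical shapes: a batch of 4 value logits over a 21-bin support
# (illustrative numbers, not the LightZero defaults).
logits = torch.randn(4, 21)
value_support = torch.arange(-10, 11, dtype=torch.float32)  # stand-in for scalar_support.arange

# ensure_softmax simply forwarded to this call, which normalizes along dim=1:
value_probs = torch.softmax(logits, dim=1)
assert torch.allclose(value_probs.sum(dim=1), torch.ones(4))

# Expected scalar value over the categorical support, mirroring
# value_probs.mul_(self.value_support).sum(1, keepdim=True) in the diff:
value = (value_probs * value_support).sum(1, keepdim=True)  # shape (4, 1)
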
0 commit comments