@@ -30,19 +30,6 @@ def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.)
3030 return output
3131
3232
33- def ensure_softmax (logits , dim = 1 ):
34- """
35- Overview:
36- Ensure that the input tensor is normalized along the specified dimension.
37- Arguments:
38- - logits (:obj:`torch.Tensor`): The input tensor.
39- - dim (:obj:`int`): The dimension along which to normalize the input tensor.
40- Returns:
41- - output (:obj:`torch.Tensor`): The normalized tensor.
42- """
43- return torch .softmax (logits , dim = dim )
44-
45-
4633def inverse_scalar_transform (
4734 logits : torch .Tensor ,
4835 scalar_support : DiscreteSupport ,
@@ -58,7 +45,7 @@ def inverse_scalar_transform(
5845 - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
5946 """
6047 if categorical_distribution :
61- value_probs = ensure_softmax (logits , dim = 1 )
48+ value_probs = torch . softmax (logits , dim = 1 )
6249 value_support = scalar_support .arange
6350
6451 value_support = value_support .to (device = value_probs .device )
@@ -94,7 +81,7 @@ def __init__(
9481
9582 def __call__ (self , logits : torch .Tensor , epsilon : float = 0.001 ) -> torch .Tensor :
9683 if self .categorical_distribution :
97- value_probs = ensure_softmax (logits , dim = 1 )
84+ value_probs = torch . softmax (logits , dim = 1 )
9885 value = value_probs .mul_ (self .value_support ).sum (1 , keepdim = True )
9986 else :
10087 value = logits
0 commit comments