1010SUPPORTED_FLOAT_DTYPES = {torch .float32 , torch .float64 }
1111
1212
13- @pytest .mark .parametrize ('dtype,device' , product (SUPPORTED_FLOAT_DTYPES , devices ))
13+ @pytest .mark .parametrize ('dtype,device' ,
14+ product (SUPPORTED_FLOAT_DTYPES , devices ))
1415def test_log_softmax (dtype , device ):
15- src = tensor ([0.25 , 0 , 0.25 , - 2.1 , 3.2 , 7 , - 1 , float ('-inf' )], dtype , device )
16+ src = tensor ([0.25 , 0 , 0.25 , - 2.1 , 3.2 , 7 , - 1 , float ('-inf' )],
17+ dtype , device )
1618 index = tensor ([0 , 1 , 0 , 1 , 1 , 2 , 4 , 4 ], torch .long , device )
1719
1820 out = scatter_log_softmax (src , index )
1921
2022 # Expected results per index
2123 idx0 = [np .log (0.5 ), np .log (0.5 )]
22- idx1 = torch .log_softmax (torch .tensor ([0.0 , - 2.1 , 3.2 ], dtype = dtype ), dim = - 1 ).tolist ()
24+ idx1 = torch .log_softmax (
25+ torch .tensor ([0.0 , - 2.1 , 3.2 ], dtype = dtype ),
26+ dim = - 1 ).tolist ()
2327 idx2 = 0.0 # Single element, has logprob=0
2428 # index=3 is empty. Should not matter.
2529 idx4 = [0.0 , float ('-inf' )] # log_softmax with -inf preserves the -inf
@@ -31,16 +35,20 @@ def test_log_softmax(dtype, device):
3135 )
3236
3337
34- @pytest .mark .parametrize ('dtype,device' , product (SUPPORTED_FLOAT_DTYPES , devices ))
38+ @pytest .mark .parametrize ('dtype,device' ,
39+ product (SUPPORTED_FLOAT_DTYPES , devices ))
3540def test_softmax (dtype , device ):
36- src = tensor ([0.25 , 0 , 0.25 , - 2.1 , 3.2 , 7 , - 1 , float ('-inf' )], dtype , device )
41+ src = tensor ([0.25 , 0 , 0.25 , - 2.1 , 3.2 , 7 , - 1 , float ('-inf' )],
42+ dtype , device )
3743 index = tensor ([0 , 1 , 0 , 1 , 1 , 2 , 4 , 4 ], torch .long , device )
3844
3945 out = scatter_softmax (src , index )
4046
4147 # Expected results per index
4248 idx0 = [0.5 , 0.5 ]
43- idx1 = torch .softmax (torch .tensor ([0.0 , - 2.1 , 3.2 ], dtype = dtype ), dim = - 1 ).tolist ()
49+ idx1 = torch .softmax (
50+ torch .tensor ([0.0 , - 2.1 , 3.2 ], dtype = dtype ),
51+ dim = - 1 ).tolist ()
4452 idx2 = 1 # Single element, has prob=1
4553 # index=3 is empty. Should not matter.
4654 idx4 = [1.0 , 0.0 ] # softmax with -inf yields zero probability
0 commit comments