@@ -14,24 +14,24 @@ def test_simple_gpu(self):
1414 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float ).cuda ()
1515 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float ).cuda ()
1616 idx , dist = ball_query (1.01 , 2 , a , b )
17- torch .testing .assert_allclose (idx .long (). cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
17+ torch .testing .assert_allclose (idx .cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
1818 torch .testing .assert_allclose (dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
1919
2020 def test_simple_cpu (self ):
2121 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float )
2222 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float )
2323 idx , dist = ball_query (1.01 , 2 , a , b )
24- torch .testing .assert_allclose (idx . long () , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
24+ torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
2525 torch .testing .assert_allclose (dist , torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
2626
2727 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [1 , 1 , 0 ]]]).to (torch .float )
2828 idx , dist = ball_query (1.01 , 3 , a , a )
29- torch .testing .assert_allclose (idx . long (), torch .tensor ([[[0 , 1 , 0 ],[1 ,0 , 2 ],[2 ,1 , 2 ]]]))
29+ torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]]))
3030
3131 @run_if_cuda
3232 def test_larger_gpu (self ):
3333 a = torch .randn (32 , 4096 , 3 ).to (torch .float ).cuda ()
34- idx ,dist = ball_query (1 , 64 , a , a )
34+ idx , dist = ball_query (1 , 64 , a , a )
3535 self .assertGreaterEqual (idx .min (), 0 )
3636
3737 @run_if_cuda
0 commit comments