1616class TestBall (unittest .TestCase ):
1717 @run_if_cuda
1818 def test_simple_gpu (self ):
19- a = (
20- torch .tensor (
21- [[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]
22- )
23- .to (torch .float )
24- .cuda ()
25- )
19+ a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float ).cuda ()
2620 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float ).cuda ()
2721 idx , dist = ball_query (1.01 , 2 , a , b )
2822 torch .testing .assert_allclose (idx .cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
29- torch .testing .assert_allclose (
30- dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ()
31- )
23+ torch .testing .assert_allclose (dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
3224
3325 def test_simple_cpu (self ):
34- a = torch .tensor (
35- [[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]
36- ).to (torch .float )
26+ a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float )
3727 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float )
3828 idx , dist = ball_query (1.01 , 2 , a , b , sort = True )
3929 torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
4030 torch .testing .assert_allclose (dist , torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
4131
4232 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [1 , 1 , 0 ]]]).to (torch .float )
4333 idx , dist = ball_query (1.01 , 3 , a , a , sort = True )
44- torch .testing .assert_allclose (
45- idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]])
46- )
34+ torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]]))
4735
4836 @run_if_cuda
4937 def test_larger_gpu (self ):
@@ -73,40 +61,30 @@ def test_cpu_gpu_equality(self):
7361class TestBallPartial (unittest .TestCase ):
7462 @run_if_cuda
7563 def test_simple_gpu (self ):
76- x = (
77- torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [0.2 , 0 , 0 ], [0.1 , 0 , 0 ]])
78- .to (torch .float )
79- .cuda ()
80- )
64+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [0.2 , 0 , 0 ], [0.1 , 0 , 0 ]]).to (torch .float ).cuda ()
8165 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float ).cuda ()
8266 batch_x = torch .from_numpy (np .asarray ([0 , 0 , 0 , 1 ])).long ().cuda ()
8367 batch_y = torch .from_numpy (np .asarray ([0 ])).long ().cuda ()
8468
85- idx , dist2 = ball_query (
86- 0.2 , 4 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
87- )
69+ idx , dist2 = ball_query (0.2 , 4 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
8870
8971 idx = idx .detach ().cpu ().numpy ()
9072 dist2 = dist2 .detach ().cpu ().numpy ()
9173
9274 idx_answer = np .asarray ([[1 , 2 , - 1 , - 1 ]])
93- dist2_answer = np .asarray ([[0.0100 ,0.04 ,- 1 ,- 1 ]]).astype (np .float32 )
75+ dist2_answer = np .asarray ([[0.0100 , 0.04 , - 1 , - 1 ]]).astype (np .float32 )
9476
9577 npt .assert_array_almost_equal (idx , idx_answer )
9678 npt .assert_array_almost_equal (dist2 , dist2_answer )
9779
9880 def test_simple_cpu (self ):
99- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (
100- torch .float
101- )
81+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (torch .float )
10282 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
10383
10484 batch_x = torch .from_numpy (np .asarray ([0 , 0 , 0 , 0 ])).long ()
10585 batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
10686
107- idx , dist2 = ball_query (
108- 1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
109- )
87+ idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
11088
11189 idx = idx .detach ().cpu ().numpy ()
11290 dist2 = dist2 .detach ().cpu ().numpy ()
@@ -118,74 +96,28 @@ def test_simple_cpu(self):
11896 npt .assert_array_almost_equal (dist2 , dist2_answer )
11997
12098 def test_breaks (self ):
121- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (
122- torch .float
123- )
99+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (torch .float )
124100 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
125101
126102 batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ()
127103 batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
128104
129105 with self .assertRaises (RuntimeError ):
130- idx , dist2 = ball_query (
131- 1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
132- )
106+ idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
133107
134108 def test_random_cpu (self , cuda = False ):
135109 a = torch .randn (100 , 3 ).to (torch .float )
136110 b = torch .randn (50 , 3 ).to (torch .float )
137- batch_a = torch .tensor (
138- [0 for i in range (a .shape [0 ] // 2 )]
139- + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])]
140- )
141- batch_b = torch .tensor (
142- [0 for i in range (b .shape [0 ] // 2 )]
143- + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])]
144- )
111+ batch_a = torch .tensor ([0 for i in range (a .shape [0 ] // 2 )] + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])])
112+ batch_b = torch .tensor ([0 for i in range (b .shape [0 ] // 2 )] + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])])
145113 R = 1
146114
147- idx , dist = ball_query (
148- R ,
149- 15 ,
150- a ,
151- b ,
152- mode = "PARTIAL_DENSE" ,
153- batch_x = batch_a ,
154- batch_y = batch_b ,
155- sort = True ,
156- )
157- idx1 , dist = ball_query (
158- R ,
159- 15 ,
160- a ,
161- b ,
162- mode = "PARTIAL_DENSE" ,
163- batch_x = batch_a ,
164- batch_y = batch_b ,
165- sort = True ,
166- )
115+ idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True ,)
116+ idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True ,)
167117 torch .testing .assert_allclose (idx1 , idx )
168118 with self .assertRaises (AssertionError ):
169- idx , dist = ball_query (
170- R ,
171- 15 ,
172- a ,
173- b ,
174- mode = "PARTIAL_DENSE" ,
175- batch_x = batch_a ,
176- batch_y = batch_b ,
177- sort = False ,
178- )
179- idx1 , dist = ball_query (
180- R ,
181- 15 ,
182- a ,
183- b ,
184- mode = "PARTIAL_DENSE" ,
185- batch_x = batch_a ,
186- batch_y = batch_b ,
187- sort = False ,
188- )
119+ idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False ,)
120+ idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False ,)
189121 torch .testing .assert_allclose (idx1 , idx )
190122
191123 self .assertEqual (idx .shape [0 ], b .shape [0 ])
@@ -200,29 +132,19 @@ def test_random_cpu(self, cuda=False):
200132 if p >= 0 and p < len (batch_a ):
201133 assert p in idx3_sk [i ]
202134
135+ @run_if_cuda
203136 def test_random_gpu (self ):
204137 a = torch .randn (100 , 3 ).to (torch .float ).cuda ()
205138 b = torch .randn (50 , 3 ).to (torch .float ).cuda ()
206139 batch_a = torch .tensor (
207- [0 for i in range (a .shape [0 ] // 2 )]
208- + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])]
140+ [0 for i in range (a .shape [0 ] // 2 )] + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])]
209141 ).cuda ()
210142 batch_b = torch .tensor (
211- [0 for i in range (b .shape [0 ] // 2 )]
212- + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])]
143+ [0 for i in range (b .shape [0 ] // 2 )] + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])]
213144 ).cuda ()
214145 R = 1
215146
216- idx , dist = ball_query (
217- R ,
218- 15 ,
219- a ,
220- b ,
221- mode = "PARTIAL_DENSE" ,
222- batch_x = batch_a ,
223- batch_y = batch_b ,
224- sort = False ,
225- )
147+ idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False ,)
226148
227149 # Comparison to see if we have the same result
228150 tree = KDTree (a .cpu ().detach ().numpy ())
0 commit comments