1616class TestBall (unittest .TestCase ):
1717 @run_if_cuda
1818 def test_simple_gpu (self ):
19- a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float ).cuda ()
19+ a = (
20+ torch .tensor (
21+ [[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]
22+ )
23+ .to (torch .float )
24+ .cuda ()
25+ )
2026 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float ).cuda ()
2127 idx , dist = ball_query (1.01 , 2 , a , b )
2228 torch .testing .assert_allclose (idx .cpu (), torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
23- torch .testing .assert_allclose (dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
29+ torch .testing .assert_allclose (
30+ dist .cpu (), torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ()
31+ )
2432
2533 def test_simple_cpu (self ):
26- a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]).to (torch .float )
34+ a = torch .tensor (
35+ [[[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]], [[0 , 0 , 0 ], [1 , 0 , 0 ], [2 , 0 , 0 ]]]
36+ ).to (torch .float )
2737 b = torch .tensor ([[[0 , 0 , 0 ]], [[3 , 0 , 0 ]]]).to (torch .float )
2838 idx , dist = ball_query (1.01 , 2 , a , b , sort = True )
2939 torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 ]], [[2 , 2 ]]]))
3040 torch .testing .assert_allclose (dist , torch .tensor ([[[0 , 1 ]], [[1 , - 1 ]]]).float ())
3141
3242 a = torch .tensor ([[[0 , 0 , 0 ], [1 , 0 , 0 ], [1 , 1 , 0 ]]]).to (torch .float )
3343 idx , dist = ball_query (1.01 , 3 , a , a , sort = True )
34- torch .testing .assert_allclose (idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]]))
44+ torch .testing .assert_allclose (
45+ idx , torch .tensor ([[[0 , 1 , 0 ], [1 , 0 , 2 ], [2 , 1 , 2 ]]])
46+ )
3547
3648 @run_if_cuda
3749 def test_larger_gpu (self ):
@@ -61,33 +73,40 @@ def test_cpu_gpu_equality(self):
6173class TestBallPartial (unittest .TestCase ):
6274 @run_if_cuda
6375 def test_simple_gpu (self ):
64- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [0.1 , 0 , 0 ]]).to (torch .float ).cuda ()
76+ x = (
77+ torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [0.2 , 0 , 0 ], [0.1 , 0 , 0 ]])
78+ .to (torch .float )
79+ .cuda ()
80+ )
6581 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float ).cuda ()
66- batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ().cuda ()
82+ batch_x = torch .from_numpy (np .asarray ([0 , 0 , 0 , 1 ])).long ().cuda ()
6783 batch_y = torch .from_numpy (np .asarray ([0 ])).long ().cuda ()
6884
69- batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ().cuda ()
70- batch_y = torch .from_numpy (np .asarray ([0 ])).long ().cuda ()
71-
72- idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
85+ idx , dist2 = ball_query (
86+ 0.2 , 4 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
87+ )
7388
7489 idx = idx .detach ().cpu ().numpy ()
7590 dist2 = dist2 .detach ().cpu ().numpy ()
7691
77- idx_answer = np .asarray ([[1 , - 1 ]])
78- dist2_answer = np .asarray ([[0.0100 , - 1.0000 ]]).astype (np .float32 )
92+ idx_answer = np .asarray ([[1 , 2 , - 1 , - 1 ]])
93+ dist2_answer = np .asarray ([[0.0100 ,0.04 , - 1 , - 1 ]]).astype (np .float32 )
7994
8095 npt .assert_array_almost_equal (idx , idx_answer )
8196 npt .assert_array_almost_equal (dist2 , dist2_answer )
8297
8398 def test_simple_cpu (self ):
84- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (torch .float )
99+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (
100+ torch .float
101+ )
85102 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
86103
87104 batch_x = torch .from_numpy (np .asarray ([0 , 0 , 0 , 0 ])).long ()
88105 batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
89106
90- idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
107+ idx , dist2 = ball_query (
108+ 1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
109+ )
91110
92111 idx = idx .detach ().cpu ().numpy ()
93112 dist2 = dist2 .detach ().cpu ().numpy ()
@@ -98,30 +117,75 @@ def test_simple_cpu(self):
98117 npt .assert_array_almost_equal (idx , idx_answer )
99118 npt .assert_array_almost_equal (dist2 , dist2_answer )
100119
101-
102120 def test_breaks (self ):
103- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (torch .float )
121+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (
122+ torch .float
123+ )
104124 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
105125
106126 batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ()
107127 batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
108-
128+
109129 with self .assertRaises (RuntimeError ):
110- idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
130+ idx , dist2 = ball_query (
131+ 1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y
132+ )
111133
112- def test_random_cpu (self ):
134+ def test_random_cpu (self , cuda = False ):
113135 a = torch .randn (100 , 3 ).to (torch .float )
114136 b = torch .randn (50 , 3 ).to (torch .float )
115- batch_a = torch .tensor ([0 for i in range (a .shape [0 ] // 2 )] + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])])
116- batch_b = torch .tensor ([0 for i in range (b .shape [0 ] // 2 )] + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])])
137+ batch_a = torch .tensor (
138+ [0 for i in range (a .shape [0 ] // 2 )]
139+ + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])]
140+ )
141+ batch_b = torch .tensor (
142+ [0 for i in range (b .shape [0 ] // 2 )]
143+ + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])]
144+ )
117145 R = 1
118146
119- idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True )
120- idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = True )
147+ idx , dist = ball_query (
148+ R ,
149+ 15 ,
150+ a ,
151+ b ,
152+ mode = "PARTIAL_DENSE" ,
153+ batch_x = batch_a ,
154+ batch_y = batch_b ,
155+ sort = True ,
156+ )
157+ idx1 , dist = ball_query (
158+ R ,
159+ 15 ,
160+ a ,
161+ b ,
162+ mode = "PARTIAL_DENSE" ,
163+ batch_x = batch_a ,
164+ batch_y = batch_b ,
165+ sort = True ,
166+ )
121167 torch .testing .assert_allclose (idx1 , idx )
122168 with self .assertRaises (AssertionError ):
123- idx , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False )
124- idx1 , dist = ball_query (R , 15 , a , b , mode = "PARTIAL_DENSE" , batch_x = batch_a , batch_y = batch_b , sort = False )
169+ idx , dist = ball_query (
170+ R ,
171+ 15 ,
172+ a ,
173+ b ,
174+ mode = "PARTIAL_DENSE" ,
175+ batch_x = batch_a ,
176+ batch_y = batch_b ,
177+ sort = False ,
178+ )
179+ idx1 , dist = ball_query (
180+ R ,
181+ 15 ,
182+ a ,
183+ b ,
184+ mode = "PARTIAL_DENSE" ,
185+ batch_x = batch_a ,
186+ batch_y = batch_b ,
187+ sort = False ,
188+ )
125189 torch .testing .assert_allclose (idx1 , idx )
126190
127191 self .assertEqual (idx .shape [0 ], b .shape [0 ])
@@ -136,6 +200,38 @@ def test_random_cpu(self):
136200 if p >= 0 and p < len (batch_a ):
137201 assert p in idx3_sk [i ]
138202
203+ def test_random_gpu (self ):
204+ a = torch .randn (100 , 3 ).to (torch .float ).cuda ()
205+ b = torch .randn (50 , 3 ).to (torch .float ).cuda ()
206+ batch_a = torch .tensor (
207+ [0 for i in range (a .shape [0 ] // 2 )]
208+ + [1 for i in range (a .shape [0 ] // 2 , a .shape [0 ])]
209+ ).cuda ()
210+ batch_b = torch .tensor (
211+ [0 for i in range (b .shape [0 ] // 2 )]
212+ + [1 for i in range (b .shape [0 ] // 2 , b .shape [0 ])]
213+ ).cuda ()
214+ R = 1
215+
216+ idx , dist = ball_query (
217+ R ,
218+ 15 ,
219+ a ,
220+ b ,
221+ mode = "PARTIAL_DENSE" ,
222+ batch_x = batch_a ,
223+ batch_y = batch_b ,
224+ sort = False ,
225+ )
226+
227+ # Comparison to see if we have the same result
228+ tree = KDTree (a .cpu ().detach ().numpy ())
229+ idx3_sk = tree .query_radius (b .cpu ().detach ().numpy (), r = R )
230+ i = np .random .randint (len (batch_b ))
231+ for p in idx [i ].cpu ().detach ().numpy ():
232+ if p >= 0 and p < len (batch_a ):
233+ assert p in idx3_sk [i ]
234+
139235
140236if __name__ == "__main__" :
141237 unittest .main ()
0 commit comments