@@ -34,18 +34,16 @@ def test_simple(self):
3434 self .assertEqual (clusters , [[0 , 1 , 2 ], [4 , 5 , 6 ]])
3535
3636 def test_region_grow (self ):
37- clusters = region_grow (
37+ cluster_idx = region_grow (
3838 self .pos , self .labels , self .batch , radius = 2 , min_cluster_size = 1
3939 )
40- self .assertEqual (len (clusters [0 ]), 2 )
41- self .assertEqual (len (clusters [1 ]), 3 )
42- self .assertEqual (len (clusters [10 ]), 1 )
43- torch .testing .assert_allclose (clusters [0 ][0 ], torch .tensor ([0 , 1 ]))
44- torch .testing .assert_allclose (clusters [0 ][1 ], torch .tensor ([4 ]))
45- torch .testing .assert_allclose (clusters [1 ][0 ], torch .tensor ([2 ]))
46- torch .testing .assert_allclose (clusters [1 ][1 ], torch .tensor ([3 ]))
47- torch .testing .assert_allclose (clusters [1 ][2 ], torch .tensor ([5 , 6 ]))
48- torch .testing .assert_allclose (clusters [10 ][0 ], torch .tensor ([7 ]))
40+ self .assertEqual (len (cluster_idx ), 6 )
41+ torch .testing .assert_allclose (cluster_idx [0 ], torch .tensor ([0 , 1 ]))
42+ torch .testing .assert_allclose (cluster_idx [1 ], torch .tensor ([4 ]))
43+ torch .testing .assert_allclose (cluster_idx [2 ], torch .tensor ([2 ]))
44+ torch .testing .assert_allclose (cluster_idx [3 ], torch .tensor ([3 ]))
45+ torch .testing .assert_allclose (cluster_idx [4 ], torch .tensor ([5 , 6 ]))
46+ torch .testing .assert_allclose (cluster_idx [5 ], torch .tensor ([7 ]))
4947
5048
5149if __name__ == "__main__" :
0 commit comments