@@ -53,14 +53,41 @@ def linkage_merge_join(linkage_merge, left, right):
53
53
54
54
55
55
@numba .njit ()
56
- def mst_to_linkage_tree (sorted_mst , sample_weights = None ):
56
+ def mst_to_linkage_tree (sorted_mst ):
57
57
result = np .empty ((sorted_mst .shape [0 ], sorted_mst .shape [1 ] + 1 ))
58
58
59
59
n_samples = sorted_mst .shape [0 ] + 1
60
- if sample_weights is None :
61
- linkage_merge = create_linkage_merge_data (n_samples )
62
- else :
63
- linkage_merge = create_linkage_merge_data_w_sample_weights (sample_weights )
60
+ linkage_merge = create_linkage_merge_data (n_samples )
61
+
62
+ for index in range (sorted_mst .shape [0 ]):
63
+
64
+ left = np .intp (sorted_mst [index , 0 ])
65
+ right = np .intp (sorted_mst [index , 1 ])
66
+ delta = sorted_mst [index , 2 ]
67
+
68
+ left_component = linkage_merge_find (linkage_merge , left )
69
+ right_component = linkage_merge_find (linkage_merge , right )
70
+
71
+ if left_component > right_component :
72
+ result [index ][0 ] = left_component
73
+ result [index ][1 ] = right_component
74
+ else :
75
+ result [index ][1 ] = left_component
76
+ result [index ][0 ] = right_component
77
+
78
+ result [index ][2 ] = delta
79
+ result [index ][3 ] = linkage_merge .size [left_component ] + linkage_merge .size [right_component ]
80
+
81
+ linkage_merge_join (linkage_merge , left_component , right_component )
82
+
83
+ return result
84
+
85
+
86
+ @numba .njit ()
87
+ def mst_to_linkage_tree_w_sample_weights (sorted_mst , sample_weights ):
88
+ result = np .empty ((sorted_mst .shape [0 ], sorted_mst .shape [1 ] + 1 ))
89
+
90
+ linkage_merge = create_linkage_merge_data_w_sample_weights (sample_weights )
64
91
65
92
for index in range (sorted_mst .shape [0 ]):
66
93
0 commit comments