-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcluster.py
More file actions
100 lines (78 loc) · 3.45 KB
/
cluster.py
File metadata and controls
100 lines (78 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import numpy as np
def constrained_agglomerative_clustering(
    probs: torch.Tensor,
    edge_index: torch.Tensor,
    demands: torch.Tensor,
    capacity: float,
    prob_threshold: float = 0.5,
):
    """
    Greedy Constrained Agglomerative Clustering based on GNN edge probabilities.

    Edges are visited in descending order of probability. Every edge whose
    probability is at least ``prob_threshold`` merges the clusters of its two
    endpoints, provided their combined demand does not exceed ``capacity``.
    Node 0 is treated as the depot: it is never placed in a cluster and edges
    touching it are ignored.

    Args:
        probs (Tensor): [num_edges] probabilities from GNN. Any shape holding
            num_edges elements (e.g. [num_edges, 1]) is flattened. NaN entries
            never trigger a merge. Lists / numpy arrays are also accepted.
        edge_index (Tensor): [2, num_edges] endpoint node indices; column ``k``
            is the edge scored by ``probs[k]``.
        demands (Tensor): [num_nodes] array of demands (index 0 = depot).
        capacity (float): Maximum allowed total demand per cluster/route.
        prob_threshold (float): Minimum probability required to merge clusters.

    Returns:
        List[List[int]]: A list of clusters, where each cluster is a list of
        customer node IDs (the depot is excluded). Each cluster's first node is
        its representative, and clusters are ordered by that representative's
        ID. A node whose own demand exceeds ``capacity`` is left as a
        singleton cluster (no merge can make it feasible).
    """
    # Pull everything to host-side Python lists once: per-element .item() /
    # int() on a GPU tensor would force a device sync for every edge.
    demand_list = torch.as_tensor(demands).reshape(-1).tolist()
    prob_list = torch.as_tensor(probs).reshape(-1).tolist()
    rows = torch.as_tensor(edge_index).tolist()

    num_nodes = len(demand_list)

    # 1. Initialize clusters: each node (except depot 0) is its own cluster.
    # Node -> cluster ID mapping.
    node_to_cluster = {i: i for i in range(1, num_nodes)}
    # Cluster ID -> member nodes and total demand.
    clusters = {
        i: {"nodes": [i], "demand": float(demand_list[i])}
        for i in range(1, num_nodes)
    }

    # 2. Keep only edges that clear the threshold, sorted by probability
    # descending. Any comparison with NaN is False, so NaN probabilities are
    # dropped here instead of being sorted to the front and merged first.
    candidates = [k for k, p in enumerate(prob_list) if p >= prob_threshold]
    # list.sort is stable even with reverse=True, so equal probabilities are
    # processed in edge order and the result is deterministic.
    candidates.sort(key=prob_list.__getitem__, reverse=True)

    # 3. Greedily merge clusters.
    for k in candidates:
        u, v = int(rows[0][k]), int(rows[1][k])

        # We don't cluster the depot.
        if u == 0 or v == 0:
            continue

        c_u = node_to_cluster[u]
        c_v = node_to_cluster[v]

        # Already in the same cluster: nothing to merge.
        if c_u == c_v:
            continue

        merged_demand = clusters[c_u]["demand"] + clusters[c_v]["demand"]
        # Positive form of the check, so a NaN demand still blocks the merge.
        if merged_demand <= capacity:
            # Merge cluster V into cluster U and retire V.
            absorbed = clusters.pop(c_v)
            clusters[c_u]["nodes"].extend(absorbed["nodes"])
            clusters[c_u]["demand"] = merged_demand
            for node in absorbed["nodes"]:
                node_to_cluster[node] = c_u

    # Format the output into a list of routes.
    return [cluster_info["nodes"] for cluster_info in clusters.values()]
if __name__ == "__main__":
    # Simple synthetic smoke test.
    print("Testing Constrained Agglomerative Clustering...")

    # 4 nodes: 0 is the depot; 1, 2, 3 are customers.
    demands = torch.tensor([0.0, 10.0, 15.0, 20.0])
    capacity = 30.0

    # Edges: 1-2 (high prob), 2-3 (high prob), 1-3 (low prob).
    edge_index = torch.tensor([
        [1, 2, 1],
        [2, 3, 3],
    ])
    probs = torch.tensor([0.9, 0.8, 0.1])

    # Expected behavior:
    # 1. Edge 1-2 (prob 0.9): node 1 (dem 10) + node 2 (dem 15) = 25 <= 30. Merge.
    # 2. Edge 2-3 (prob 0.8): cluster {1,2} (dem 25) + cluster {3} (dem 20)
    #    = 45 > 30. Reject.
    # 3. Edge 1-3 (prob 0.1): below threshold, never considered.
    routes = constrained_agglomerative_clustering(probs, edge_index, demands, capacity)
    print(f"Final Routes: {routes}")

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let the success message print unconditionally.
    if {tuple(sorted(r)) for r in routes} != {(1, 2), (3,)}:
        raise AssertionError("Clustering logic failed!")
    print("Test Passed: Constrained Clustering correctly blocks over-capacity merges!")