Skip to content

Commit eb25885

Browse files
committed
Replace container/heap with native implementation.
1 parent d79be35 commit eb25885

File tree

3 files changed

+83
-49
lines changed

3 files changed

+83
-49
lines changed

heap.go

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,109 @@
11
package graph
22

3-
import (
4-
"container/heap"
5-
)
6-
73
type prioQueue struct {
84
heap []int // vertices in heap order
95
index []int // index of each vertex in the heap
106
cost []int64
117
}
128

13-
func emptyQueue(cost []int64) *prioQueue {
9+
func emptyPrioQueue(cost []int64) *prioQueue {
1410
return &prioQueue{
1511
index: make([]int, len(cost)),
1612
cost: cost,
1713
}
1814
}
1915

20-
func newQueue(cost []int64) *prioQueue {
16+
func newPrioQueue(cost []int64) *prioQueue {
2117
n := len(cost)
22-
h := &prioQueue{
18+
q := &prioQueue{
2319
heap: make([]int, n),
2420
index: make([]int, n),
2521
cost: cost,
2622
}
27-
for i := range h.heap {
28-
h.heap[i] = i
29-
h.index[i] = i
23+
for i := range q.heap {
24+
q.heap[i] = i
25+
q.index[i] = i
3026
}
31-
return h
27+
return q
3228
}
3329

34-
func (m *prioQueue) Len() int { return len(m.heap) }
30+
// Len returns the number of elements in the queue.
31+
func (q *prioQueue) Len() int {
32+
return len(q.heap)
33+
}
3534

36-
func (m *prioQueue) Less(i, j int) bool {
37-
return m.cost[m.heap[i]] < m.cost[m.heap[j]]
35+
// Push pushes v onto the queue.
36+
// The time complexity is O(log n) where n = q.Len().
37+
func (q *prioQueue) Push(v int) {
38+
n := q.Len()
39+
q.heap = append(q.heap, v)
40+
q.index[v] = n
41+
q.up(n)
3842
}
3943

40-
func (m *prioQueue) Swap(i, j int) {
41-
m.heap[i], m.heap[j] = m.heap[j], m.heap[i]
42-
m.index[m.heap[i]] = i
43-
m.index[m.heap[j]] = j
44+
// Pop removes the minimum element from the queue and returns it.
45+
// The time complexity is O(log n) where n = q.Len().
46+
func (q *prioQueue) Pop() int {
47+
n := q.Len() - 1
48+
q.swap(0, n)
49+
q.down(0, n)
50+
51+
v := q.heap[n]
52+
q.index[v] = -1
53+
q.heap = q.heap[:n]
54+
return v
4455
}
4556

46-
func (pq *prioQueue) Push(x interface{}) {
47-
n := len(pq.heap)
48-
v := x.(int)
49-
pq.heap = append(pq.heap, v)
50-
pq.index[v] = n
57+
// Contains tells whether v is in the queue.
58+
func (q *prioQueue) Contains(v int) bool {
59+
return q.index[v] >= 0
5160
}
5261

53-
func (m *prioQueue) Pop() interface{} {
54-
n := len(m.heap) - 1
55-
v := m.heap[n]
56-
m.index[v] = -1
57-
m.heap = m.heap[:n]
58-
return v
62+
// Fix re-establishes the ordering after the cost for v has changed.
63+
// The time complexity is O(log n) where n = q.Len().
64+
func (q *prioQueue) Fix(v int) {
65+
if i := q.index[v]; !q.down(i, q.Len()) {
66+
q.up(i)
67+
}
5968
}
6069

61-
func (m *prioQueue) Update(v int) {
62-
heap.Fix(m, m.index[v])
70+
func (q *prioQueue) less(i, j int) bool {
71+
return q.cost[q.heap[i]] < q.cost[q.heap[j]]
6372
}
6473

65-
func (m *prioQueue) Contains(v int) bool {
66-
return m.index[v] >= 0
74+
func (q *prioQueue) swap(i, j int) {
75+
q.heap[i], q.heap[j] = q.heap[j], q.heap[i]
76+
q.index[q.heap[i]] = i
77+
q.index[q.heap[j]] = j
78+
}
79+
80+
func (q *prioQueue) up(j int) {
81+
for {
82+
i := (j - 1) / 2 // parent
83+
if i == j || !q.less(j, i) {
84+
break
85+
}
86+
q.swap(i, j)
87+
j = i
88+
}
89+
}
90+
91+
func (q *prioQueue) down(i0, n int) bool {
92+
i := i0
93+
for {
94+
j1 := 2*i + 1
95+
if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
96+
break
97+
}
98+
j := j1 // left child
99+
if j2 := j1 + 1; j2 < n && q.less(j2, j1) {
100+
j = j2 // = 2*i + 2 // right child
101+
}
102+
if !q.less(j, i) {
103+
break
104+
}
105+
q.swap(i, j)
106+
i = j
107+
}
108+
return i > i0
67109
}

mst.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package graph
22

3-
import (
4-
"container/heap"
5-
)
6-
73
// MST computes a minimum spanning tree for each connected component
84
// of an undirected weighted graph.
95
// The forest of spanning trees is returned as a slice of parent pointers:
@@ -22,13 +18,13 @@ func MST(g Iterator) (parent []int) {
2218
}
2319

2420
// Prim's algorithm
25-
Q := newQueue(cost)
21+
Q := newPrioQueue(cost)
2622
for Q.Len() > 0 {
27-
v := heap.Pop(Q).(int)
23+
v := Q.Pop()
2824
g.Visit(v, func(w int, c int64) (skip bool) {
2925
if Q.Contains(w) && c < cost[w] {
3026
cost[w] = c
31-
Q.Update(w)
27+
Q.Fix(w)
3228
parent[w] = v
3329
}
3430
return

path.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package graph
22

3-
import (
4-
"container/heap"
5-
)
6-
73
// ShortestPath computes a shortest path from v to w.
84
// Only edges with non-negative costs are included.
95
// The number dist is the length of the path, or -1 if w cannot be reached.
@@ -44,10 +40,10 @@ func ShortestPaths(g Iterator, v int) (parent []int, dist []int64) {
4440
dist[v] = 0
4541

4642
// Dijkstra's algorithm
47-
Q := emptyQueue(dist)
48-
heap.Push(Q, v)
43+
Q := emptyPrioQueue(dist)
44+
Q.Push(v)
4945
for Q.Len() > 0 {
50-
v = heap.Pop(Q).(int)
46+
v := Q.Pop()
5147
g.Visit(v, func(w int, d int64) (skip bool) {
5248
if d < 0 {
5349
return
@@ -56,10 +52,10 @@ func ShortestPaths(g Iterator, v int) (parent []int, dist []int64) {
5652
switch {
5753
case dist[w] == -1:
5854
dist[w], parent[w] = alt, v
59-
heap.Push(Q, w)
55+
Q.Push(w)
6056
case alt < dist[w]:
6157
dist[w], parent[w] = alt, v
62-
Q.Update(w)
58+
Q.Fix(w)
6359
}
6460
return
6561
})

0 commit comments

Comments
 (0)