Skip to content

Commit 25a31f1

Browse files
authored
Merge pull request #3359 from square/fast_topo_graph
[swift] Improve performance of topological sort algorithm for partitions
2 parents b41991a + 18fe665 commit 25a31f1

File tree

3 files changed

+165
-12
lines changed

3 files changed

+165
-12
lines changed

wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/DirectedAcyclicGraph.kt

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,40 @@ internal class DirectedAcyclicGraph<N>(
4545
}
4646

4747
fun topologicalOrder(): List<N> {
48-
val seen = LinkedHashSet<N>() // Insertion order is important to produce the final list!
49-
val queue = ArrayDeque<N>().apply { addAll(seeds) }
48+
val incomingEdges = mutableMapOf<N, Int>()
49+
for (vertex in nodes) {
50+
if (vertex !in incomingEdges) {
51+
incomingEdges[vertex] = 0
52+
}
53+
for (edge in edges(vertex)) {
54+
incomingEdges[edge] = (incomingEdges[edge] ?: 0) + 1
55+
}
56+
}
57+
58+
val queue = ArrayDeque<N>()
59+
for ((vertex, edges) in incomingEdges) {
60+
if (edges == 0) queue += vertex
61+
}
62+
63+
val result = mutableListOf<N>()
64+
5065
while (queue.isNotEmpty()) {
51-
val currentName = queue.removeFirst()
52-
val currentDependencies = edges(currentName).toList()
53-
if (seen.containsAll(currentDependencies)) {
54-
seen += currentName
55-
queue += incomingEdges(currentName)
56-
} else {
57-
// All dependencies have not been seen so move to the back of the line to try again later.
58-
queue += currentName
66+
val vertex = queue.removeFirst()
67+
result += vertex
68+
69+
for (edge in edges(vertex)) {
70+
incomingEdges[edge] = (incomingEdges[edge] ?: 0) - 1
71+
if (incomingEdges[edge] == 0) {
72+
queue += edge
73+
}
5974
}
6075
}
61-
return seen.toList()
76+
77+
check(result.size == incomingEdges.size) {
78+
"Graph contains a cycle, topological sort not possible!"
79+
}
80+
81+
return result
6282
}
6383

6484
fun transitiveNodes(node: N): Set<N> {

wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/PartitionedSchema.kt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,29 @@ internal class PartitionedSchema(
3434
)
3535
}
3636

37+
internal fun computeDepths(modules: Map<String, Module>): Map<String, Int> {
38+
val memo = mutableMapOf<String, Int>()
39+
40+
fun dfs(name: String): Int {
41+
memo[name]?.let { return it }
42+
val depth = 1 + (modules.getValue(name).dependencies.maxOfOrNull { dfs(it) } ?: 0)
43+
memo[name] = depth
44+
return depth
45+
}
46+
47+
return modules.keys.associateWith { dfs(it) }
48+
}
49+
3750
internal fun Schema.partition(modules: Map<String, Module>): PartitionedSchema {
3851
val moduleGraph = DirectedAcyclicGraph(modules.keys) { modules.getValue(it).dependencies }
3952

4053
val errors = mutableListOf<String>()
4154
val partitions = mutableMapOf<String, Partition>()
42-
for (moduleName in moduleGraph.topologicalOrder()) {
55+
val topoGraph = moduleGraph.topologicalOrder()
56+
val depths = computeDepths(modules)
57+
val ordered = topoGraph.sortedBy { depths.getValue(it) }
58+
59+
for (moduleName in ordered) {
4360
val module = modules.getValue(moduleName)
4461

4562
val upstreamTypes = buildMap {
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (C) 2025 Square, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.squareup.wire.schema
17+
18+
import assertk.assertThat
19+
import assertk.assertions.containsExactly
20+
import assertk.assertions.hasMessage
21+
import assertk.assertions.isEqualTo
22+
import kotlin.test.Test
23+
import kotlin.test.fail
24+
25+
class DirectedAcyclicGraphTest {
26+
27+
@Test
28+
fun singleNode() {
29+
val nodes = listOf("A")
30+
val graph = DirectedAcyclicGraph(nodes) { emptyList() }
31+
val order = graph.topologicalOrder()
32+
assertThat(order).containsExactly("A")
33+
}
34+
35+
@Test
36+
fun simpleLinearDag() {
37+
// A -> B -> C
38+
val nodes = listOf("A", "B", "C")
39+
val edges = mapOf(
40+
"A" to listOf("B"),
41+
"B" to listOf("C"),
42+
"C" to emptyList(),
43+
)
44+
val graph = DirectedAcyclicGraph(nodes) { edges[it].orEmpty() }
45+
val order = graph.topologicalOrder()
46+
assertThat(order).isEqualTo(listOf("A", "B", "C"))
47+
}
48+
49+
@Test
50+
fun simpleNonLinearDag() {
51+
// A
52+
// C -> B
53+
val nodes = listOf("A", "B", "C")
54+
val edges = mapOf(
55+
"A" to emptyList(),
56+
"B" to emptyList(),
57+
"C" to listOf("B"),
58+
)
59+
val graph = DirectedAcyclicGraph(nodes) { edges[it].orEmpty() }
60+
val order = graph.topologicalOrder()
61+
assertThat(order).isEqualTo(listOf("A", "C", "B"))
62+
}
63+
64+
@Test
65+
fun branchingDag() {
66+
// A
67+
// / \
68+
// B C
69+
// \ /
70+
// D
71+
val nodes = listOf("A", "B", "C", "D")
72+
val edges = mapOf(
73+
"A" to listOf("B", "C"),
74+
"B" to listOf("D"),
75+
"C" to listOf("D"),
76+
"D" to emptyList(),
77+
)
78+
val graph = DirectedAcyclicGraph(nodes) { edges[it].orEmpty() }
79+
val order = graph.topologicalOrder()
80+
assertThat(order).isEqualTo(listOf("A", "B", "C", "D"))
81+
}
82+
83+
@Test
84+
fun multipleRoots() {
85+
// A -> C
86+
// B -> C
87+
val nodes = listOf("A", "B", "C")
88+
val edges = mapOf(
89+
"A" to listOf("C"),
90+
"B" to listOf("C"),
91+
"C" to emptyList(),
92+
)
93+
val graph = DirectedAcyclicGraph(nodes) { edges[it].orEmpty() }
94+
val order = graph.topologicalOrder()
95+
assertThat(order).isEqualTo(listOf("A", "B", "C"))
96+
}
97+
98+
@Test
99+
fun cycleThrowsError() {
100+
// A -> B -> C -> A
101+
val nodes = listOf("A", "B", "C")
102+
val edges = mapOf(
103+
"A" to listOf("B"),
104+
"B" to listOf("C"),
105+
"C" to listOf("A"),
106+
)
107+
val graph = DirectedAcyclicGraph(nodes) { edges[it].orEmpty() }
108+
109+
try {
110+
graph.topologicalOrder()
111+
fail()
112+
} catch (expected: IllegalStateException) {
113+
assertThat(expected).hasMessage("Graph contains a cycle, topological sort not possible!")
114+
}
115+
}
116+
}

0 commit comments

Comments
 (0)