Skip to content

Commit 9322d4b

Browse files
authored
Add previously failing case, introduce stochastic gradient descent optimizer, make scale-invariant via initial hop distance calculation (#15)
* notes from testing * wip * compute hops to end and use as distance measure * use hop distance as heuristic * make parameters easier to load for gradient descent * introduce lr velocity * give sense of progress * wip * remove clamping * gradient descent patches * parameter adjustment * stochastic optimization and new hyperparameters * format
1 parent b7fa8dc commit 9322d4b

14 files changed

+14716
-68
lines changed

fixtures/jumper-graph-solver/jumper-graph-solver03.fixture.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ export default () => {
2323
innerColChannelPointCount: 3,
2424
innerRowChannelPointCount: 3,
2525
regionsBetweenPads: true,
26-
orientation
26+
orientation,
2727
})
2828

2929
const graphWithConnections = createProblemFromBaseGraph({
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { GenericSolverDebugger } from "@tscircuit/solver-utils/react"
2+
import { JumperGraphSolver } from "lib/JumperGraphSolver/JumperGraphSolver"
3+
import type { JPort, JRegion } from "lib/index"
4+
import inputData from "../../tests/jumper-graph-solver/jumper-graph-solver05-input.json"
5+
6+
export default () => {
7+
return (
8+
<GenericSolverDebugger
9+
createSolver={() =>
10+
new JumperGraphSolver({
11+
inputGraph: {
12+
regions: inputData.graph.regions as JRegion[],
13+
ports: inputData.graph.ports as unknown as JPort[],
14+
},
15+
inputConnections: inputData.connections,
16+
})
17+
}
18+
/>
19+
)
20+
}

lib/HyperGraphSolver.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ export class HyperGraphSolver<
285285
*/
286286
routeSolvedHook(solvedRoute: SolvedRoute) {}
287287

288+
/**
289+
* OPTIONALLY OVERRIDE THIS
290+
*
291+
* You can override this to perform actions when a new route begins, e.g.
292+
* you may want to log or track which connection is being processed.
293+
*/
294+
routeStartedHook(connection: Connection) {}
295+
288296
ripSolvedRoute(solvedRoute: SolvedRoute) {
289297
for (const port of solvedRoute.path.map((candidate) => candidate.port)) {
290298
port.ripCount = (port.ripCount ?? 0) + 1
@@ -305,6 +313,7 @@ export class HyperGraphSolver<
305313
this.currentEndRegion = this.currentConnection.endRegion
306314
this.candidateQueue = new PriorityQueue<Candidate>()
307315
this.visitedPointsForCurrentConnection.clear()
316+
this.routeStartedHook(this.currentConnection)
308317
for (const port of this.currentConnection.startRegion.ports) {
309318
this.candidateQueue.enqueue({
310319
port,

lib/JumperGraphSolver/JumperGraphSolver.ts

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,25 @@ import { distance } from "@tscircuit/math-utils"
1414
import { computeDifferentNetCrossings } from "./computeDifferentNetCrossings"
1515
import { computeCrossingAssignments } from "./computeCrossingAssignments"
1616

17+
export const JUMPER_GRAPH_SOLVER_DEFAULTS = {
18+
portUsagePenalty: 0.06393718451067248,
19+
portUsagePenaltySq: 0.06194817180037216,
20+
crossingPenalty: 6.0761550028071145,
21+
crossingPenaltySq: 0.1315528159128946,
22+
ripCost: 40.00702225250195,
23+
greedyMultiplier: 0.4316469416682083,
24+
}
25+
1726
export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
1827
UNIT_OF_COST = "distance"
1928

20-
portUsagePenalty = 0.197
21-
portUsagePenaltySq = 0
22-
crossingPenalty = 6.007
23-
crossingPenaltySq = 0.111
24-
override ripCost = 40
29+
portUsagePenalty = JUMPER_GRAPH_SOLVER_DEFAULTS.portUsagePenalty
30+
portUsagePenaltySq = JUMPER_GRAPH_SOLVER_DEFAULTS.portUsagePenaltySq
31+
crossingPenalty = JUMPER_GRAPH_SOLVER_DEFAULTS.crossingPenalty
32+
crossingPenaltySq = JUMPER_GRAPH_SOLVER_DEFAULTS.crossingPenaltySq
33+
override ripCost = JUMPER_GRAPH_SOLVER_DEFAULTS.ripCost
2534
baseMaxIterations = 4000
26-
additionalMaxIterationsPerConnection = 2000
35+
additionalMaxIterationsPerConnection = 4000
2736

2837
constructor(input: {
2938
inputGraph: HyperGraph | SerializedHyperGraph
@@ -35,7 +44,7 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
3544
additionalMaxIterationsPerConnection?: number
3645
}) {
3746
super({
38-
greedyMultiplier: 0.45,
47+
greedyMultiplier: JUMPER_GRAPH_SOLVER_DEFAULTS.greedyMultiplier,
3948
rippingEnabled: true,
4049
...input,
4150
})
@@ -50,10 +59,52 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
5059
this.MAX_ITERATIONS =
5160
this.baseMaxIterations +
5261
input.inputConnections.length * this.additionalMaxIterationsPerConnection
62+
63+
this.populateDistanceToEndMaps()
64+
}
65+
66+
private populateDistanceToEndMaps() {
67+
// Get all unique end regions from connections
68+
const endRegions = new Set(this.connections.map((c) => c.endRegion))
69+
70+
// For each end region, compute hop distances from all ports using BFS
71+
for (const endRegion of endRegions) {
72+
const regionDistanceMap = new Map<string, number>()
73+
const queue: Array<{ region: JRegion; distance: number }> = []
74+
75+
regionDistanceMap.set(endRegion.regionId, 0)
76+
queue.push({ region: endRegion as JRegion, distance: 0 })
77+
78+
while (queue.length > 0) {
79+
const { region, distance: dist } = queue.shift()!
80+
81+
for (const port of region.ports) {
82+
const otherRegion = (
83+
port.region1 === region ? port.region2 : port.region1
84+
) as JRegion
85+
if (!regionDistanceMap.has(otherRegion.regionId)) {
86+
regionDistanceMap.set(otherRegion.regionId, dist + 1)
87+
queue.push({ region: otherRegion, distance: dist + 1 })
88+
}
89+
}
90+
}
91+
92+
// Populate each port's distanceToEndMap for this end region
93+
for (const port of this.graph.ports) {
94+
if (!port.distanceToEndMap) {
95+
port.distanceToEndMap = {}
96+
}
97+
const d1 = regionDistanceMap.get(port.region1.regionId) ?? Infinity
98+
const d2 = regionDistanceMap.get(port.region2.regionId) ?? Infinity
99+
port.distanceToEndMap[endRegion.regionId] = Math.min(d1, d2)
100+
}
101+
}
53102
}
54103

55104
override estimateCostToEnd(port: JPort): number {
56-
return distance(port.d, this.currentEndRegion!.d.center)
105+
const endRegionId = this.currentEndRegion!.regionId
106+
const hopDistance = port.distanceToEndMap![endRegionId]!
107+
return hopDistance
57108
}
58109
override getPortUsagePenalty(port: JPort): number {
59110
const ripCount = port.ripCount ?? 0
@@ -84,6 +135,8 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
84135

85136
override routeSolvedHook(solvedRoute: SolvedRoute) {}
86137

138+
override routeStartedHook(connection: Connection) {}
139+
87140
override visualize(): GraphicsObject {
88141
return visualizeJumperGraphSolver(this)
89142
}

lib/types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ export type RegionPort = {
1616
* ports that are likely to block off connections
1717
*/
1818
ripCount?: number
19+
20+
/**
21+
* Optionally can be used by solvers to keep track of the distance to
22+
* each end era.
23+
*/
24+
distanceToEndMap?: Record<RegionId, number>
1925
}
2026

2127
export type Region = {
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import { JumperGraphSolver } from "../../lib/JumperGraphSolver/JumperGraphSolver"
2+
import type { Parameters, EvaluationResult, SampleConfig } from "./types"
3+
import type { PregeneratedProblem } from "./problem-generator"
4+
5+
/**
6+
* Evaluates parameters on pregenerated problems.
7+
* Uses structuredClone to avoid mutating the original problem data.
8+
*/
9+
export function evaluateParametersOnProblems(
10+
params: Parameters,
11+
problemsByConfig: Map<number, PregeneratedProblem[]>,
12+
configs: SampleConfig[],
13+
progressLabel?: string,
14+
): EvaluationResult {
15+
let totalRouted = 0
16+
let totalConnections = 0
17+
let solvedCount = 0
18+
19+
for (let i = 0; i < configs.length; i++) {
20+
const config = configs[i]
21+
if (progressLabel) {
22+
process.stdout.write(`\r${progressLabel}: ${i + 1}/${configs.length}`)
23+
}
24+
25+
const problems = problemsByConfig.get(config.seed)
26+
if (!problems) {
27+
throw new Error(`No pregenerated problems for seed ${config.seed}`)
28+
}
29+
30+
let bestRoutedFraction = 0
31+
let solved = false
32+
33+
for (const { problem } of problems) {
34+
// Use structuredClone to avoid mutating the original problem
35+
const clonedProblem = structuredClone(problem)
36+
const totalConns = clonedProblem.connections.length
37+
38+
const solver = new JumperGraphSolver({
39+
inputGraph: {
40+
regions: clonedProblem.regions,
41+
ports: clonedProblem.ports,
42+
},
43+
inputConnections: clonedProblem.connections,
44+
portUsagePenalty: params.portUsagePenalty,
45+
crossingPenalty: params.crossingPenalty,
46+
ripCost: params.ripCost,
47+
})
48+
49+
// Apply additional parameters that aren't in constructor
50+
;(solver as any).portUsagePenaltySq = params.portUsagePenaltySq
51+
;(solver as any).crossingPenaltySq = params.crossingPenaltySq
52+
;(solver as any).greedyMultiplier = params.greedyMultiplier
53+
54+
solver.solve()
55+
56+
const routedFraction = solver.solvedRoutes.length / totalConns
57+
58+
if (solver.solved) {
59+
solved = true
60+
bestRoutedFraction = 1.0
61+
break
62+
} else if (routedFraction > bestRoutedFraction) {
63+
bestRoutedFraction = routedFraction
64+
}
65+
}
66+
67+
// Estimate total connections from numCrossings
68+
const estimatedConns = Math.ceil(
69+
(1 + Math.sqrt(1 + 8 * config.numCrossings)) / 2,
70+
)
71+
totalConnections += estimatedConns
72+
totalRouted += bestRoutedFraction * estimatedConns
73+
74+
if (solved) {
75+
solvedCount++
76+
}
77+
}
78+
79+
if (progressLabel) {
80+
process.stdout.write("\r" + " ".repeat(50) + "\r")
81+
}
82+
83+
return {
84+
continuousScore: totalRouted / totalConnections,
85+
successRate: solvedCount / configs.length,
86+
totalRouted,
87+
totalConnections,
88+
}
89+
}

0 commit comments

Comments
 (0)