Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 140 additions & 10 deletions lib/JumperGraphSolver/JumperGraphSolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,23 @@ import { visualizeJumperGraphSolver } from "./visualizeJumperGraphSolver"
import { distance } from "@tscircuit/math-utils"
import { computeDifferentNetCrossings } from "./computeDifferentNetCrossings"
import { computeCrossingAssignments } from "./computeCrossingAssignments"
import { computeSameNetCrossings } from "./computeSameNetCrossings"

export const JUMPER_GRAPH_SOLVER_DEFAULTS = {
portUsagePenalty: 0.06393718451067248,
portUsagePenaltySq: 0.06194817180037216,
crossingPenalty: 6.0761550028071145,
crossingPenaltySq: 0.1315528159128946,
ripCost: 40.00702225250195,
greedyMultiplier: 0.4316469416682083,
portUsagePenalty: 0.03365268465229554,
portUsagePenaltySq: 0.001,
crossingPenalty: 6.160693673577123,
crossingPenaltySq: 0.06126198189275256,
ripCost: 39.97123937131205,
greedyMultiplier: 0.5293456817395028,
// New tunable parameters
hopDistanceMultiplier: 0.9401216030689439,
ripCountExponent: 1.9553157504245895,
crossingExponent: 1.8768621932810199,
sameNetCrossingPenalty: 0.5282307574660395,
congestionRadius: 0.4188103484385634,
congestionPenaltyMultiplier: 0.12244797149538383,
connectionOrderWeight: 0.04368861909045532,
}

export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
Expand All @@ -34,6 +43,16 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
baseMaxIterations = 4000
additionalMaxIterationsPerConnection = 4000

// New tunable parameters
hopDistanceMultiplier = JUMPER_GRAPH_SOLVER_DEFAULTS.hopDistanceMultiplier
ripCountExponent = JUMPER_GRAPH_SOLVER_DEFAULTS.ripCountExponent
crossingExponent = JUMPER_GRAPH_SOLVER_DEFAULTS.crossingExponent
sameNetCrossingPenalty = JUMPER_GRAPH_SOLVER_DEFAULTS.sameNetCrossingPenalty
congestionRadius = JUMPER_GRAPH_SOLVER_DEFAULTS.congestionRadius
congestionPenaltyMultiplier =
JUMPER_GRAPH_SOLVER_DEFAULTS.congestionPenaltyMultiplier
connectionOrderWeight = JUMPER_GRAPH_SOLVER_DEFAULTS.connectionOrderWeight

constructor(input: {
inputGraph: HyperGraph | SerializedHyperGraph
inputConnections: (Connection | SerializedConnection)[]
Expand All @@ -42,6 +61,14 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
crossingPenalty?: number
baseMaxIterations?: number
additionalMaxIterationsPerConnection?: number
// New tunable parameters
hopDistanceMultiplier?: number
ripCountExponent?: number
crossingExponent?: number
sameNetCrossingPenalty?: number
congestionRadius?: number
congestionPenaltyMultiplier?: number
connectionOrderWeight?: number
}) {
super({
greedyMultiplier: JUMPER_GRAPH_SOLVER_DEFAULTS.greedyMultiplier,
Expand All @@ -56,11 +83,55 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
input.additionalMaxIterationsPerConnection ??
this.additionalMaxIterationsPerConnection

// Initialize new tunable parameters
this.hopDistanceMultiplier =
input.hopDistanceMultiplier ?? this.hopDistanceMultiplier
this.ripCountExponent = input.ripCountExponent ?? this.ripCountExponent
this.crossingExponent = input.crossingExponent ?? this.crossingExponent
this.sameNetCrossingPenalty =
input.sameNetCrossingPenalty ?? this.sameNetCrossingPenalty
this.congestionRadius = input.congestionRadius ?? this.congestionRadius
this.congestionPenaltyMultiplier =
input.congestionPenaltyMultiplier ?? this.congestionPenaltyMultiplier
this.connectionOrderWeight =
input.connectionOrderWeight ?? this.connectionOrderWeight

this.MAX_ITERATIONS =
this.baseMaxIterations +
input.inputConnections.length * this.additionalMaxIterationsPerConnection

this.populateDistanceToEndMaps()

// Sort connections by estimated difficulty if connectionOrderWeight is set
if (this.connectionOrderWeight !== 0) {
this.sortConnectionsByDifficulty()
}
}

/**
* Sort connections by estimated difficulty (hop distance).
* Positive connectionOrderWeight: easier (shorter) connections first
* Negative connectionOrderWeight: harder (longer) connections first
*/
private sortConnectionsByDifficulty() {
const getConnectionDifficulty = (conn: Connection): number => {
// Estimate difficulty as minimum hop distance between start and end
let minHops = Infinity
for (const port of conn.startRegion.ports) {
const hops = (port as JPort).distanceToEndMap?.[conn.endRegion.regionId]
if (hops !== undefined && hops < minHops) {
minHops = hops
}
}
return minHops === Infinity ? 0 : minHops
}

this.unprocessedConnections.sort((a, b) => {
const diffA = getConnectionDifficulty(a)
const diffB = getConnectionDifficulty(b)
// Positive weight: easier first (ascending), Negative: harder first (descending)
return (diffA - diffB) * Math.sign(this.connectionOrderWeight)
})
}

private populateDistanceToEndMaps() {
Expand Down Expand Up @@ -104,19 +175,78 @@ export class JumperGraphSolver extends HyperGraphSolver<JRegion, JPort> {
override estimateCostToEnd(port: JPort): number {
const endRegionId = this.currentEndRegion!.regionId
const hopDistance = port.distanceToEndMap![endRegionId]!
return hopDistance
return hopDistance * this.hopDistanceMultiplier
}

override getPortUsagePenalty(port: JPort): number {
const ripCount = port.ripCount ?? 0
return ripCount * this.portUsagePenalty + ripCount * this.portUsagePenaltySq
// Linear term + polynomial term with configurable exponent
const linearPenalty = ripCount * this.portUsagePenalty
const polynomialPenalty =
Math.pow(ripCount, this.ripCountExponent) * this.portUsagePenaltySq

// Congestion penalty: penalize ports near heavily-used ports
const congestionPenalty = this.computeCongestionPenalty(port)

return linearPenalty + polynomialPenalty + congestionPenalty
}

/**
* Compute congestion penalty based on nearby port usage.
* Ports within congestionRadius of heavily-used ports get penalized.
*/
private computeCongestionPenalty(port: JPort): number {
if (this.congestionRadius <= 0 || this.congestionPenaltyMultiplier <= 0) {
return 0
}

let totalNearbyRipCount = 0
const portX = port.d.x
const portY = port.d.y

for (const otherPort of this.graph.ports) {
if (otherPort === port) continue

const dx = (otherPort as JPort).d.x - portX
const dy = (otherPort as JPort).d.y - portY
const dist = Math.sqrt(dx * dx + dy * dy)

if (dist <= this.congestionRadius) {
// Weight by inverse distance (closer = more penalty)
const weight = 1 - dist / this.congestionRadius
totalNearbyRipCount += ((otherPort as JPort).ripCount ?? 0) * weight
}
}

return totalNearbyRipCount * this.congestionPenaltyMultiplier
}

override computeIncreasedRegionCostIfPortsAreUsed(
region: JRegion,
port1: JPort,
port2: JPort,
): number {
const crossings = computeDifferentNetCrossings(region, port1, port2)
return crossings * this.crossingPenalty + crossings * this.crossingPenaltySq
// Different-net crossings (the main penalty)
const differentNetCrossings = computeDifferentNetCrossings(
region,
port1,
port2,
)
const differentNetPenalty =
differentNetCrossings * this.crossingPenalty +
Math.pow(differentNetCrossings, this.crossingExponent) *
this.crossingPenaltySq

// Same-net crossings (smaller penalty, but still worth avoiding)
const sameNetCrossings = computeSameNetCrossings(
region,
port1,
port2,
this.currentConnection!.mutuallyConnectedNetworkId,
)
const sameNetPenalty = sameNetCrossings * this.sameNetCrossingPenalty

return differentNetPenalty + sameNetPenalty
}

override getRipsRequiredForPortUsage(
Expand Down
56 changes: 56 additions & 0 deletions lib/JumperGraphSolver/computeSameNetCrossings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import type { JPort, JRegion } from "./jumper-types"
import { perimeterT, chordsCross } from "./perimeterChordUtils"

/**
* Compute the number of crossings between a new port pair and existing
* assignments in the region that belong to the SAME network.
*
* Uses the circle/perimeter mapping approach: two connections MUST cross
* if their boundary points interleave around the perimeter.
*/
export function computeSameNetCrossings(
region: JRegion,
port1: JPort,
port2: JPort,
currentNetworkId: string,
): number {
const { minX: xmin, maxX: xmax, minY: ymin, maxY: ymax } = region.d.bounds

// Map the new port pair to perimeter coordinates
const t1 = perimeterT(port1.d, xmin, xmax, ymin, ymax)
const t2 = perimeterT(port2.d, xmin, xmax, ymin, ymax)
const newChord: [number, number] = [t1, t2]

// Count crossings with existing assignments from the same network
let crossings = 0
const assignments = region.assignments ?? []

for (const assignment of assignments) {
// Only count same-network crossings
if (assignment.connection.mutuallyConnectedNetworkId !== currentNetworkId) {
continue
}

const existingT1 = perimeterT(
(assignment.regionPort1 as JPort).d,
xmin,
xmax,
ymin,
ymax,
)
const existingT2 = perimeterT(
(assignment.regionPort2 as JPort).d,
xmin,
xmax,
ymin,
ymax,
)
const existingChord: [number, number] = [existingT1, existingT2]

if (chordsCross(newChord, existingChord)) {
crossings++
}
}

return crossings
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ const PARAM_SCALES: Parameters = {
crossingPenaltySq: 1,
ripCost: 50,
greedyMultiplier: 1,
// New parameters - scales chosen based on expected magnitude
hopDistanceMultiplier: 1,
ripCountExponent: 1,
crossingExponent: 1,
sameNetCrossingPenalty: 1,
congestionRadius: 1,
congestionPenaltyMultiplier: 0.1,
connectionOrderWeight: 1,
}

/**
Expand Down Expand Up @@ -169,6 +177,13 @@ function formatGradient(gradient: Parameters): string {
`d_crossSq=${gradient.crossingPenaltySq.toFixed(4)}`,
`d_rip=${gradient.ripCost.toFixed(6)}`,
`d_greedy=${gradient.greedyMultiplier.toFixed(4)}`,
`d_hopMult=${gradient.hopDistanceMultiplier.toFixed(4)}`,
`d_ripExp=${gradient.ripCountExponent.toFixed(4)}`,
`d_crossExp=${gradient.crossingExponent.toFixed(4)}`,
`d_sameNet=${gradient.sameNetCrossingPenalty.toFixed(4)}`,
`d_congRad=${gradient.congestionRadius.toFixed(4)}`,
`d_congMult=${gradient.congestionPenaltyMultiplier.toFixed(4)}`,
`d_connOrd=${gradient.connectionOrderWeight.toFixed(4)}`,
].join(", ")
}

Expand Down
30 changes: 30 additions & 0 deletions scripts/hyper-parameter-optimization/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ export interface Parameters {
crossingPenaltySq: number
ripCost: number
greedyMultiplier: number
// New tunable parameters
hopDistanceMultiplier: number
ripCountExponent: number
crossingExponent: number
sameNetCrossingPenalty: number
congestionRadius: number
congestionPenaltyMultiplier: number
connectionOrderWeight: number
}

export const PARAM_KEYS: (keyof Parameters)[] = [
Expand All @@ -14,6 +22,14 @@ export const PARAM_KEYS: (keyof Parameters)[] = [
"crossingPenaltySq",
"ripCost",
"greedyMultiplier",
// New tunable parameters
"hopDistanceMultiplier",
"ripCountExponent",
"crossingExponent",
"sameNetCrossingPenalty",
"congestionRadius",
"congestionPenaltyMultiplier",
"connectionOrderWeight",
]

export interface SampleConfig {
Expand All @@ -38,6 +54,13 @@ export function formatParams(params: Parameters): string {
`crossingPenaltySq=${params.crossingPenaltySq.toFixed(3)}`,
`ripCost=${params.ripCost.toFixed(3)}`,
`greedyMultiplier=${params.greedyMultiplier.toFixed(3)}`,
`hopDistMult=${params.hopDistanceMultiplier.toFixed(3)}`,
`ripCountExp=${params.ripCountExponent.toFixed(3)}`,
`crossingExp=${params.crossingExponent.toFixed(3)}`,
`sameNetCross=${params.sameNetCrossingPenalty.toFixed(3)}`,
`congRadius=${params.congestionRadius.toFixed(3)}`,
`congMult=${params.congestionPenaltyMultiplier.toFixed(3)}`,
`connOrderWt=${params.connectionOrderWeight.toFixed(3)}`,
].join(", ")
}

Expand All @@ -58,5 +81,12 @@ export function createZeroParams(): Parameters {
crossingPenaltySq: 0,
ripCost: 0,
greedyMultiplier: 0,
hopDistanceMultiplier: 0,
ripCountExponent: 0,
crossingExponent: 0,
sameNetCrossingPenalty: 0,
congestionRadius: 0,
congestionPenaltyMultiplier: 0,
connectionOrderWeight: 0,
}
}
2 changes: 1 addition & 1 deletion scripts/run-benchmark-2x2-1206x4-both-orientations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { createProblemFromBaseGraph } from "../lib/JumperGraphSolver/jumper-grap
import { JumperGraphSolver } from "../lib/JumperGraphSolver/JumperGraphSolver"
import { calculateGraphBounds } from "../lib/JumperGraphSolver/jumper-graph-generator/calculateGraphBounds"

const SAMPLES_PER_CROSSING_COUNT = 100
const SAMPLES_PER_CROSSING_COUNT = 2000
const MIN_CROSSINGS = 2
const MAX_CROSSINGS = 20

Expand Down
Loading