Skip to content

Commit a65daf1

Browse files
generalltrean
andauthored
use matrix API to bootstrap graph view (#216)
* use matrix API to bootstrap graph view * add score to links (but do not use it yet) * spanning tree (#225) * Fix graph (#226) * tests for getMinimalSpanningTree * fixes * canvas resize fix * Update src/lib/graph-visualization-helpers.js --------- Co-authored-by: trean <[email protected]>
1 parent ccfbc4f commit a65daf1

File tree

5 files changed

+239
-14
lines changed

5 files changed

+239
-14
lines changed

src/components/GraphVisualisation/GraphVisualisation.jsx

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import { deduplicatePoints, getSimilarPoints, initGraph } from '../../lib/graph-
44
import ForceGraph from 'force-graph';
55
import { useClient } from '../../context/client-context';
66
import { useSnackbar } from 'notistack';
7+
import { debounce } from 'lodash';
8+
import { resizeObserverWithCallback } from '../../lib/common-helpers';
79

8-
const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => {
10+
const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef, sampleLinks }) => {
911
const graphRef = useRef(null);
1012
const { client: qdrantClient } = useClient();
1113
const { enqueueSnackbar } = useSnackbar();
@@ -53,7 +55,9 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
5355
onDataDisplay(node);
5456
})
5557
.autoPauseRedraw(false)
56-
.nodeCanvasObjectMode((node) => (node?.id === highlightedNode?.id ? 'before' : undefined))
58+
.nodeCanvasObjectMode((node) => {
59+
return node?.id === highlightedNode?.id ? 'before' : undefined;
60+
})
5761
.nodeCanvasObject((node, ctx) => {
5862
if (!node) return;
5963
// add ring for last hovered nodes
@@ -62,18 +66,33 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
6266
ctx.fillStyle = node.id === highlightedNode?.id ? '#817' : 'transparent';
6367
ctx.fill();
6468
})
69+
.linkLabel('score')
6570
.linkColor(() => '#a6a6a6');
71+
72+
graphRef.current.d3Force('charge').strength(-10);
6673
}, [initNode, options]);
6774

6875
useEffect(() => {
76+
if (!wrapperRef) return;
77+
78+
const debouncedResizeCallback = debounce((width, height) => {
79+
graphRef.current.width(width).height(height);
80+
}, 500);
81+
6982
graphRef.current.width(wrapperRef?.clientWidth).height(wrapperRef?.clientHeight);
83+
resizeObserverWithCallback(debouncedResizeCallback).observe(wrapperRef);
84+
85+
return () => {
86+
resizeObserverWithCallback(debouncedResizeCallback).unobserve(wrapperRef);
87+
};
7088
}, [wrapperRef, initNode, options]);
7189

7290
useEffect(() => {
7391
const initNewGraph = async () => {
7492
const graphData = await initGraph(qdrantClient, {
7593
...options,
7694
initNode,
95+
sampleLinks,
7796
});
7897
if (graphRef.current && options) {
7998
const initialActiveNode = graphData.nodes[0];
@@ -83,9 +102,14 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
83102
}
84103
};
85104
initNewGraph().catch((e) => {
86-
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
105+
console.error(e);
106+
if (e.getActualType) {
107+
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
108+
} else {
109+
enqueueSnackbar(e.message, { variant: 'error' });
110+
}
87111
});
88-
}, [initNode, options]);
112+
}, [initNode, options, sampleLinks]);
89113

90114
return <div id="graph"></div>;
91115
};
@@ -95,6 +119,7 @@ GraphVisualisation.propTypes = {
95119
options: PropTypes.object.isRequired,
96120
onDataDisplay: PropTypes.func.isRequired,
97121
wrapperRef: PropTypes.object,
122+
sampleLinks: PropTypes.array,
98123
};
99124

100125
export default GraphVisualisation;

src/lib/common-helpers.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export const resizeObserverWithCallback = (callback) => {
2+
return new ResizeObserver((entries) => {
3+
for (const entry of entries) {
4+
const { target } = entry;
5+
const { width, height } = target.getBoundingClientRect();
6+
if (typeof callback === 'function') callback(width, height);
7+
}
8+
});
9+
};

src/lib/graph-visualization-helpers.js

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,42 @@
1-
export const initGraph = async (qdrantClient, { collectionName, initNode, limit, filter, using }) => {
2-
if (!initNode) {
1+
import { axiosInstance } from '../common/axios';
2+
3+
export const initGraph = async (
4+
qdrantClient,
5+
{ collectionName, initNode, limit, filter, using, sampleLinks, tree = false }
6+
) => {
7+
let nodes = [];
8+
let links = [];
9+
10+
if (sampleLinks) {
11+
const uniquePoints = new Set();
12+
13+
for (const link of sampleLinks) {
14+
links.push({ source: link.a, target: link.b, score: link.score });
15+
uniquePoints.add(link.a);
16+
uniquePoints.add(link.b);
17+
}
18+
19+
if (tree) {
20+
// ToDo acs should depend on metric type
21+
links = getMinimalSpanningTree(links, true);
22+
}
23+
24+
nodes = await getPointsWithPayload(qdrantClient, { collectionName, pointIds: Array.from(uniquePoints) });
25+
} else if (initNode) {
26+
initNode.clicked = true;
27+
nodes = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });
28+
links = nodes.map((point) => ({ source: initNode.id, target: point.id, score: point.score }));
29+
nodes = [initNode, ...nodes];
30+
} else {
331
return {
432
nodes: [],
533
links: [],
634
};
735
}
8-
initNode.clicked = true;
9-
10-
const points = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });
1136

1237
const graphData = {
13-
nodes: [initNode, ...points],
14-
links: points.map((point) => ({ source: initNode.id, target: point.id })),
38+
nodes,
39+
links,
1540
};
1641
return graphData;
1742
};
@@ -44,9 +69,94 @@ export const getFirstPoint = async (qdrantClient, { collectionName, filter }) =>
4469
return points[0];
4570
};
4671

72+
const getPointsWithPayload = async (qdrantClient, { collectionName, pointIds }) => {
73+
const points = await qdrantClient.retrieve(collectionName, {
74+
ids: pointIds,
75+
with_payload: true,
76+
with_vector: false,
77+
});
78+
79+
return points;
80+
};
81+
82+
export const getSamplePoints = async ({ collectionName, filter, sample, using, limit }) => {
83+
// ToDo: replace it with qdrantClient when it will be implemented
84+
85+
const response = await axiosInstance({
86+
method: 'POST',
87+
url: `collections/${collectionName}/points/search/matrix/pairs`,
88+
data: {
89+
filter,
90+
sample,
91+
using,
92+
limit,
93+
},
94+
});
95+
96+
return response.data.result.pairs;
97+
};
98+
4799
export const deduplicatePoints = (existingPoints, foundPoints) => {
48100
// Returns array of found points that are not in existing points
49101
// deduplication is done by id
50102
const existingIds = new Set(existingPoints.map((point) => point.id));
51103
return foundPoints.filter((point) => !existingIds.has(point.id));
52104
};
105+
106+
export const getMinimalSpanningTree = (links, acs = true) => {
107+
// Sort links by score (assuming each link has a score property)
108+
109+
let sortedLinks = [];
110+
if (acs) {
111+
sortedLinks = links.sort((a, b) => b.score - a.score);
112+
} else {
113+
sortedLinks = links.sort((a, b) => a.score - b.score);
114+
}
115+
// Helper function to find the root of a node
116+
const findRoot = (parent, i) => {
117+
if (parent[i] === i) {
118+
return i;
119+
}
120+
return findRoot(parent, parent[i]);
121+
};
122+
123+
// Helper function to perform union of two sets
124+
const union = (parent, rank, x, y) => {
125+
const rootX = findRoot(parent, x);
126+
const rootY = findRoot(parent, y);
127+
128+
if (rank[rootX] < rank[rootY]) {
129+
parent[rootX] = rootY;
130+
} else if (rank[rootX] > rank[rootY]) {
131+
parent[rootY] = rootX;
132+
} else {
133+
parent[rootY] = rootX;
134+
rank[rootX]++;
135+
}
136+
};
137+
138+
const parent = {};
139+
const rank = {};
140+
const mstLinks = [];
141+
142+
// Initialize parent and rank arrays
143+
links.forEach((link) => {
144+
parent[link.source] = link.source;
145+
parent[link.target] = link.target;
146+
rank[link.source] = 0;
147+
rank[link.target] = 0;
148+
});
149+
150+
// Kruskal's algorithm
151+
sortedLinks.forEach((link) => {
152+
const sourceRoot = findRoot(parent, link.source);
153+
const targetRoot = findRoot(parent, link.target);
154+
155+
if (sourceRoot !== targetRoot) {
156+
mstLinks.push(link);
157+
union(parent, rank, sourceRoot, targetRoot);
158+
}
159+
});
160+
161+
return mstLinks;
162+
};
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import { describe, it, expect } from 'vitest';
2+
import { getMinimalSpanningTree } from '../graph-visualization-helpers';
3+
4+
describe('getMinimalSpanningTree', () => {
5+
it('should return the minimal spanning tree for a given set of links (ascending order)', () => {
6+
const links = [
7+
{ source: 'A', target: 'B', score: 1 },
8+
{ source: 'B', target: 'C', score: 2 },
9+
{ source: 'A', target: 'C', score: 3 },
10+
{ source: 'C', target: 'D', score: 4 },
11+
{ source: 'B', target: 'D', score: 5 },
12+
];
13+
14+
const expectedMST = [
15+
{ source: 'B', target: 'D', score: 5 },
16+
{ source: 'C', target: 'D', score: 4 },
17+
{ source: 'A', target: 'C', score: 3 },
18+
];
19+
20+
const result = getMinimalSpanningTree(links, true);
21+
expect(result).toEqual(expectedMST);
22+
});
23+
24+
it('should return the minimal spanning tree for a given set of links (descending order)', () => {
25+
const links = [
26+
{ source: 'A', target: 'B', score: 1 },
27+
{ source: 'B', target: 'C', score: 2 },
28+
{ source: 'A', target: 'C', score: 3 },
29+
{ source: 'C', target: 'D', score: 4 },
30+
{ source: 'B', target: 'D', score: 5 },
31+
];
32+
33+
const expectedMST = [
34+
{ source: 'A', target: 'B', score: 1 },
35+
{ source: 'B', target: 'C', score: 2 },
36+
{ source: 'C', target: 'D', score: 4 },
37+
];
38+
39+
const result = getMinimalSpanningTree(links, false);
40+
expect(result).toEqual(expectedMST);
41+
});
42+
43+
it('should return an empty array if no links are provided', () => {
44+
const links = [];
45+
const expectedMST = [];
46+
const result = getMinimalSpanningTree(links, true);
47+
expect(result).toEqual(expectedMST);
48+
});
49+
50+
it('should handle a single link correctly', () => {
51+
const links = [{ source: 'A', target: 'B', score: 1 }];
52+
const expectedMST = [{ source: 'A', target: 'B', score: 1 }];
53+
const result = getMinimalSpanningTree(links, true);
54+
expect(result).toEqual(expectedMST);
55+
});
56+
});

src/pages/Graph.jsx

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { useWindowResize } from '../hooks/windowHooks';
99
import PointPreview from '../components/Common/PointPreview';
1010
import CodeEditorWindow from '../components/FilterEditorWindow';
1111
import { useClient } from '../context/client-context';
12-
import { getFirstPoint } from '../lib/graph-visualization-helpers';
12+
import { getFirstPoint, getSamplePoints } from '../lib/graph-visualization-helpers';
1313
import { useSnackbar } from 'notistack';
1414

1515
const explanation = `
@@ -19,12 +19,15 @@ const explanation = `
1919
// Available parameters:
2020
//
2121
// - 'limit': number of records to use on each step.
22+
// - 'sample': bootstrap graph with sample data from collection.
2223
//
2324
// - 'filter': filter expression to select vectors for visualization.
2425
// See https://qdrant.tech/documentation/concepts/filtering/
2526
//
2627
// - 'using': specify which vector to use for visualization
2728
// if there are multiple.
29+
//
30+
// - 'tree': if true, will use show spanning tree instead of full graph.
2831
2932
`;
3033

@@ -45,6 +48,8 @@ function Graph() {
4548
const location = useLocation();
4649
const { newInitNode, vectorName } = location.state || {};
4750
const [initNode, setInitNode] = useState(null);
51+
const [sampleLinks, setSampleLinks] = useState(null);
52+
4853
const [options, setOptions] = useState({
4954
limit: 5,
5055
filter: null,
@@ -92,8 +97,17 @@ function Graph() {
9297
const handleRunCode = async (data, collectionName) => {
9398
// scroll
9499
try {
95-
const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter });
96-
setInitNode(firstPoint);
100+
if (data.sample) {
101+
const sampleLinks = await getSamplePoints({
102+
collectionName: collectionName,
103+
...data,
104+
});
105+
setSampleLinks(sampleLinks);
106+
setInitNode(null);
107+
} else {
108+
const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter });
109+
setInitNode(firstPoint);
110+
}
97111
setOptions({
98112
collectionName: collectionName,
99113
...data,
@@ -130,6 +144,16 @@ function Graph() {
130144
type: 'string',
131145
enum: vectorNames,
132146
},
147+
sample: {
148+
description: 'Bootstrap graph with sample data from collection',
149+
type: 'integer',
150+
nullable: true,
151+
},
152+
tree: {
153+
description: 'Show spanning tree instead of full graph',
154+
type: 'boolean',
155+
nullable: true,
156+
},
133157
},
134158
});
135159

@@ -170,6 +194,7 @@ function Graph() {
170194
initNode={initNode}
171195
onDataDisplay={handlePointDisplay}
172196
wrapperRef={VisualizeChartWrapper.current}
197+
sampleLinks={sampleLinks}
173198
/>
174199
</Box>
175200
</Box>

0 commit comments

Comments
 (0)