
Commit 6918555

docs: Binary image segmentation example (#1795)
1 parent 6e8a4dc commit 6918555

14 files changed (+778 -169 lines)

apps/typegpu-docs/astro.config.mjs

Lines changed: 5 additions & 1 deletion

@@ -46,7 +46,10 @@ export default defineConfig({
       'process.env.NODE_DEBUG_NATIVE': '""',
     },
     optimizeDeps: {
-      exclude: ['@rolldown/browser'],
+      exclude: [
+        '@rolldown/browser',
+        'onnxruntime-web',
+      ],
     },
     // Allowing query params, for invalidation
     plugins: [
@@ -65,6 +68,7 @@ export default defineConfig({
       noExternal: [
         'wgsl-wasm-transpiler-bundler',
         '@rolldown/browser',
+        'onnxruntime-web',
       ],
     },
   },
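A note on this change: Vite's optimizeDeps.exclude keeps onnxruntime-web out of dependency pre-bundling, and listing it under ssr.noExternal makes Vite bundle it during SSR instead of importing it as an external module, mirroring how @rolldown/browser was already handled. This is presumably what lets the package's WASM assets resolve correctly in the docs site.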

apps/typegpu-docs/package.json

Lines changed: 3 additions & 1 deletion

@@ -38,11 +38,13 @@
     "fuse.js": "catalog:frontend",
     "jotai": "^2.15.0",
     "jotai-location": "^0.6.2",
+    "lodash": "^4.17.21",
     "lucide-react": "^0.536.0",
     "lz-string": "^1.5.0",
     "monaco-editor": "^0.53.0",
     "morphcharts": "^1.3.2",
     "motion": "^12.23.24",
+    "onnxruntime-web": "1.23.0-dev.20250917-21fbad8a65",
     "pathe": "^2.0.3",
     "react": "^19.1.0",
     "react-dom": "^19.1.0",
@@ -66,8 +68,8 @@
     "@types/babel__standalone": "^7.1.9",
     "@types/babel__template": "^7.4.4",
     "@types/babel__traverse": "^7.20.7",
-    "@types/node": "^24.7.0",
     "@types/dom-mediacapture-record": "^1.0.22",
+    "@types/node": "^24.7.0",
     "@vitejs/plugin-basic-ssl": "^2.1.0",
     "@webgpu/types": "catalog:types",
     "astro-vtbot": "^2.1.6",
Binary file not shown.
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+<canvas></canvas>
+<video
+  autoplay
+  playsinline
+  class="absolute top-0 left-0 w-px h-px opacity-0 pointer-events-none"
+>
+</video>
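The video element is kept technically rendered but imperceptible (1 px square, opacity 0) rather than display: none, presumably so the browser keeps decoding frames and the element remains a valid source for importExternalTexture and requestVideoFrameCallback while only the processed canvas is visible.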
Lines changed: 331 additions & 0 deletions

@@ -0,0 +1,331 @@
+import tgpu, {
+  type RenderFlag,
+  type SampledFlag,
+  type StorageFlag,
+  type TgpuBindGroup,
+  type TgpuTexture,
+} from 'typegpu';
+import { fullScreenTriangle } from 'typegpu/common';
+import * as d from 'typegpu/data';
+import { MODEL_HEIGHT, MODEL_WIDTH, prepareSession } from './model.ts';
+import {
+  blockDim,
+  blurLayout,
+  drawWithMaskLayout,
+  generateMaskLayout,
+  prepareModelInputLayout,
+  sampleBiasSlot,
+  useGaussianSlot,
+} from './schemas.ts';
+import {
+  computeFn,
+  drawWithMaskFragment,
+  generateMaskFromOutput,
+  prepareModelInput,
+} from './shaders.ts';
+
+// Background segmentation uses the u2netp model (https://github.com/xuebinqin/U-2-Net)
+// by Xuebin Qin et al., licensed under the Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)
+
+// We need to wait for this issue to be closed and a fix released: https://github.com/microsoft/onnxruntime/issues/26480
+if (/^((?!chrome|android).)*safari/i.test(navigator.userAgent)) {
+  throw new Error('Unfortunately, ONNX does not work on Safari or iOS yet.');
+}
+
+// setup
+
+const canvas = document.querySelector('canvas') as HTMLCanvasElement;
+const video = document.querySelector('video') as HTMLVideoElement;
+
+if (navigator.mediaDevices.getUserMedia) {
+  video.srcObject = await navigator.mediaDevices.getUserMedia({
+    video: {
+      facingMode: 'user',
+      width: { ideal: 1280 },
+      height: { ideal: 720 },
+      frameRate: { ideal: 60 },
+    },
+  });
+} else {
+  throw new Error('getUserMedia not supported');
+}
+
+const adapter = await navigator.gpu?.requestAdapter();
+const device = await adapter?.requestDevice() as GPUDevice;
+
+if (!device || !adapter) {
+  throw new Error('Failed to initialize device.');
+}
+
+// monkey patching ONNX: https://github.com/microsoft/onnxruntime/issues/26107
+const oldRequestAdapter = navigator.gpu.requestAdapter;
+const oldRequestDevice = adapter.requestDevice;
+navigator.gpu.requestAdapter = async () => adapter;
+adapter.requestDevice = async () => device;
+const root = await tgpu.initFromDevice({ device });
+const context = canvas.getContext('webgpu') as GPUCanvasContext;
+const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
+
+context.configure({
+  device,
+  format: presentationFormat,
+  alphaMode: 'premultiplied',
+});
+
+// resources
+
+let blurStrength = 5;
+let useGaussianBlur = false;
+
+const zeroBuffer = root.createBuffer(d.u32, 0).$usage('uniform');
+const oneBuffer = root.createBuffer(d.u32, 1).$usage('uniform');
+const useGaussianUniform = root.createUniform(d.u32, 0);
+const sampleBiasUniform = root.createUniform(d.f32, 0);
+
+const sampler = root['~unstable'].createSampler({
+  magFilter: 'linear',
+  minFilter: 'linear',
+  mipmapFilter: 'linear',
+});
+
+const maskTexture = root['~unstable'].createTexture({
+  size: [MODEL_WIDTH, MODEL_HEIGHT],
+  format: 'rgba8unorm',
+  dimension: '2d',
+}).$usage('sampled', 'render', 'storage');
+
+const modelInputBuffer = root
+  .createBuffer(d.arrayOf(d.f32, 3 * MODEL_WIDTH * MODEL_HEIGHT))
+  .$usage('storage');
+
+const modelOutputBuffer = root
+  .createBuffer(d.arrayOf(d.f32, 1 * MODEL_WIDTH * MODEL_HEIGHT))
+  .$usage('storage');
+
+let blurredTextures: (
+  & TgpuTexture<{
+    size: [number, number];
+    format: 'rgba8unorm';
+    mipLevelCount: 10;
+  }>
+  & StorageFlag
+  & SampledFlag
+  & RenderFlag
+)[];
+
+const generateMaskBindGroup = root.createBindGroup(generateMaskLayout, {
+  maskTexture,
+  outputBuffer: modelOutputBuffer,
+});
+
+let blurBindGroups: TgpuBindGroup<typeof blurLayout.entries>[];
+
+// pipelines
+
+const prepareModelInputPipeline = root['~unstable']
+  .createGuardedComputePipeline(
+    prepareModelInput,
+  );
+
+const session = await prepareSession(
+  root.unwrap(modelInputBuffer),
+  root.unwrap(modelOutputBuffer),
+);
+
+const generateMaskFromOutputPipeline = root['~unstable']
+  .createGuardedComputePipeline(
+    generateMaskFromOutput,
+  );
+
+const blurPipeline = root['~unstable']
+  .withCompute(computeFn)
+  .createPipeline();
+
+const drawWithMaskPipeline = root['~unstable']
+  .with(useGaussianSlot, useGaussianUniform)
+  .with(sampleBiasSlot, sampleBiasUniform)
+  .withVertex(fullScreenTriangle, {})
+  .withFragment(drawWithMaskFragment, { format: presentationFormat })
+  .createPipeline();
+
+// recalculating mask
+
+let calculateMaskCallbackId: number | undefined;
+
+async function processCalculateMask() {
+  if (video.readyState < 2) {
+    calculateMaskCallbackId = video.requestVideoFrameCallback(
+      processCalculateMask,
+    );
+    return;
+  }
+
+  prepareModelInputPipeline
+    .with(root.createBindGroup(prepareModelInputLayout, {
+      inputTexture: device.importExternalTexture({ source: video }),
+      outputBuffer: modelInputBuffer,
+      sampler,
+    }))
+    .dispatchThreads(MODEL_WIDTH, MODEL_HEIGHT);
+
+  await session.run();
+
+  generateMaskFromOutputPipeline
+    .with(generateMaskBindGroup)
+    .dispatchThreads(MODEL_WIDTH, MODEL_HEIGHT);
+
+  calculateMaskCallbackId = video.requestVideoFrameCallback(
+    processCalculateMask,
+  );
+}
+calculateMaskCallbackId = video.requestVideoFrameCallback(processCalculateMask);
+
+// frame
+
+function onVideoChange(size: { width: number; height: number }) {
+  const aspectRatio = size.width / size.height;
+  video.style.height = `${video.clientWidth / aspectRatio}px`;
+  if (canvas.parentElement) {
+    canvas.parentElement.style.aspectRatio = `${aspectRatio}`;
+    canvas.parentElement.style.height =
+      `min(100cqh, calc(100cqw/(${aspectRatio})))`;
+  }
+  blurredTextures = [0, 1].map(() =>
+    root['~unstable'].createTexture({
+      size: [size.width, size.height],
+      format: 'rgba8unorm',
+      dimension: '2d',
+      mipLevelCount: 10,
+    }).$usage('sampled', 'render', 'storage')
+  );
+  blurBindGroups = [
+    root.createBindGroup(blurLayout, {
+      flip: zeroBuffer,
+      inTexture: blurredTextures[0],
+      outTexture: blurredTextures[1].createView(
+        d.textureStorage2d('rgba8unorm', 'read-only'),
+        { mipLevelCount: 1 },
+      ),
+      sampler,
+    }),
+    root.createBindGroup(blurLayout, {
+      flip: oneBuffer,
+      inTexture: blurredTextures[1],
+      outTexture: blurredTextures[0].createView(
+        d.textureStorage2d('rgba8unorm', 'read-only'),
+        { mipLevelCount: 1 },
+      ),
+      sampler,
+    }),
+  ];
+}
+
+let videoFrameCallbackId: number | undefined;
+let lastFrameSize: { width: number; height: number } | undefined;
+
+async function processVideoFrame(
+  _: number,
+  metadata: VideoFrameCallbackMetadata,
+) {
+  if (video.readyState < 2) {
+    videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
+    return;
+  }
+
+  const frameWidth = metadata.width;
+  const frameHeight = metadata.height;
+
+  if (
+    !lastFrameSize ||
+    lastFrameSize.width !== frameWidth ||
+    lastFrameSize.height !== frameHeight
+  ) {
+    lastFrameSize = { width: frameWidth, height: frameHeight };
+    onVideoChange(lastFrameSize);
+  }
+
+  blurredTextures[0].write(video);
+
+  if (useGaussianBlur) {
+    for (const _ of Array(blurStrength * 2)) {
+      blurPipeline
+        .with(blurBindGroups[0])
+        .dispatchWorkgroups(
+          Math.ceil(frameWidth / blockDim),
+          Math.ceil(frameHeight / 4),
+        );
+      blurPipeline
+        .with(blurBindGroups[1])
+        .dispatchWorkgroups(
+          Math.ceil(frameHeight / blockDim),
+          Math.ceil(frameWidth / 4),
+        );
+    }
+  } else {
+    blurredTextures[0].generateMipmaps();
+  }
+
+  drawWithMaskPipeline
+    .withColorAttachment({
+      view: context.getCurrentTexture().createView(),
+      clearValue: [1, 1, 1, 1],
+      loadOp: 'clear',
+      storeOp: 'store',
+    })
+    .with(root.createBindGroup(drawWithMaskLayout, {
+      inputTexture: device.importExternalTexture({ source: video }),
+      inputBlurredTexture: blurredTextures[0],
+      maskTexture,
+      sampler,
+    }))
+    .draw(3);
+
+  videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
+}
+videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
+
+// #region Example controls & Cleanup
+
+export const controls = {
+  'blur type': {
+    initial: 'mipmaps',
+    options: ['mipmaps', 'gaussian'],
+    async onSelectChange(value: string) {
+      useGaussianBlur = value === 'gaussian';
+      useGaussianUniform.write(useGaussianBlur ? 1 : 0);
+    },
+  },
+  'blur strength': {
+    initial: blurStrength,
+    min: 0,
+    max: 10,
+    step: 1,
+    onSliderChange(newValue: number) {
+      blurStrength = newValue;
+      sampleBiasUniform.write(blurStrength);
+    },
+  },
+};
+
+export function onCleanup() {
+  if (videoFrameCallbackId !== undefined) {
+    video.cancelVideoFrameCallback(videoFrameCallbackId);
+  }
+  if (calculateMaskCallbackId !== undefined) {
+    video.cancelVideoFrameCallback(calculateMaskCallbackId);
+  }
+  session.release();
+  if (video.srcObject) {
+    for (const track of (video.srcObject as MediaStream).getTracks()) {
+      track.stop();
+    }
+  }
+  navigator.gpu.requestAdapter = oldRequestAdapter;
+  if (adapter) {
+    adapter.requestDevice = oldRequestDevice;
+  }
+
+  root.destroy();
+}
+
+// #endregion
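
Per frame, the example samples the camera into modelInputBuffer with a compute pass (prepareModelInput), runs the u2netp session on that buffer, expands the result into maskTexture (generateMaskFromOutput), and composites the sharp and blurred video through the mask in drawWithMask. The model.ts file is not part of this excerpt; below is a minimal sketch of what its prepareSession could look like, assuming onnxruntime-web's WebGPU execution provider with GPU-buffer IO binding. The model URL, the tensor names 'input' and 'output', and the 320x320 input size are assumptions for illustration, not the committed code.

// model.ts - hypothetical sketch, not the committed file
import * as ort from 'onnxruntime-web';

// u2netp is commonly run at 320x320 input (assumed here).
export const MODEL_WIDTH = 320;
export const MODEL_HEIGHT = 320;

export async function prepareSession(
  inputBuffer: GPUBuffer,
  outputBuffer: GPUBuffer,
) {
  const session = await ort.InferenceSession.create('/u2netp.onnx', {
    executionProviders: ['webgpu'],
  });

  // Bind the pre-allocated GPU buffers directly, so frames never
  // round-trip through the CPU between TypeGPU and ONNX Runtime.
  const feeds = {
    input: ort.Tensor.fromGpuBuffer(inputBuffer, {
      dataType: 'float32',
      dims: [1, 3, MODEL_HEIGHT, MODEL_WIDTH],
    }),
  };
  const fetches = {
    output: ort.Tensor.fromGpuBuffer(outputBuffer, {
      dataType: 'float32',
      dims: [1, 1, MODEL_HEIGHT, MODEL_WIDTH],
    }),
  };

  return {
    run: () => session.run(feeds, fetches),
    release: () => session.release(),
  };
}

GPU-buffer IO binding of this kind is also why the example monkey-patches requestAdapter and requestDevice: ONNX Runtime and TypeGPU must share a single GPUDevice for the same buffers to be valid in both.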
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+{
+  "title": "Background Segmentation",
+  "category": "image-processing",
+  "tags": ["experimental", "camera", "onnx"]
+}
