import tgpu, {
  type RenderFlag,
  type SampledFlag,
  type StorageFlag,
  type TgpuBindGroup,
  type TgpuTexture,
} from 'typegpu';
import { fullScreenTriangle } from 'typegpu/common';
import * as d from 'typegpu/data';
import { MODEL_HEIGHT, MODEL_WIDTH, prepareSession } from './model.ts';
import {
  blockDim,
  blurLayout,
  drawWithMaskLayout,
  generateMaskLayout,
  prepareModelInputLayout,
  sampleBiasSlot,
  useGaussianSlot,
} from './schemas.ts';
import {
  computeFn,
  drawWithMaskFragment,
  generateMaskFromOutput,
  prepareModelInput,
} from './shaders.ts';

// Background segmentation uses the u2netp model (https://github.com/xuebinqin/U-2-Net)
// by Xuebin Qin et al., licensed under the Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)

// Blocked on an upstream ONNX Runtime fix landing in a release: https://github.com/microsoft/onnxruntime/issues/26480
if (/^((?!chrome|android).)*safari/i.test(navigator.userAgent)) {
  throw new Error('Unfortunately, ONNX does not work on Safari or iOS yet.');
}

// setup

const canvas = document.querySelector('canvas') as HTMLCanvasElement;
const video = document.querySelector('video') as HTMLVideoElement;

if (navigator.mediaDevices.getUserMedia) {
  video.srcObject = await navigator.mediaDevices.getUserMedia({
    video: {
      facingMode: 'user',
      width: { ideal: 1280 },
      height: { ideal: 720 },
      frameRate: { ideal: 60 },
    },
  });
} else {
  throw new Error('getUserMedia not supported');
}

const adapter = await navigator.gpu?.requestAdapter();
const device = await adapter?.requestDevice();

if (!adapter || !device) {
  throw new Error('Failed to initialize WebGPU device.');
}

// monkey patching ONNX: https://github.com/microsoft/onnxruntime/issues/26107
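// ONNX Runtime's WebGPU backend normally requests its own adapter and device;
// overriding these calls makes it reuse ours, which is what lets the session
// below bind our storage buffers directly.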
const oldRequestAdapter = navigator.gpu.requestAdapter;
const oldRequestDevice = adapter.requestDevice;
navigator.gpu.requestAdapter = async () => adapter;
adapter.requestDevice = async () => device;
const root = await tgpu.initFromDevice({ device });
const context = canvas.getContext('webgpu') as GPUCanvasContext;
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();

context.configure({
  device,
  format: presentationFormat,
  alphaMode: 'premultiplied',
});

// resources

let blurStrength = 5;
let useGaussianBlur = false;

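// zeroBuffer/oneBuffer back the blur layout's 'flip' binding, which tells the
// blur shader (computeFn in shaders.ts) which axis to blur along; the uniforms
// below feed the slots consumed by the composite fragment shader.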
const zeroBuffer = root.createBuffer(d.u32, 0).$usage('uniform');
const oneBuffer = root.createBuffer(d.u32, 1).$usage('uniform');
const useGaussianUniform = root.createUniform(d.u32, 0);
const sampleBiasUniform = root.createUniform(d.f32, 0);

const sampler = root['~unstable'].createSampler({
  magFilter: 'linear',
  minFilter: 'linear',
  mipmapFilter: 'linear',
});

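// Segmentation mask at model resolution; a compute pass writes it and the
// composite pass samples it with linear filtering.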
const maskTexture = root['~unstable'].createTexture({
  size: [MODEL_WIDTH, MODEL_HEIGHT],
  format: 'rgba8unorm',
  dimension: '2d',
}).$usage('sampled', 'render', 'storage');

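// Model I/O: u2netp takes a 3-channel float image and returns a single-channel
// foreground map, both at MODEL_WIDTH x MODEL_HEIGHT (the exact channel layout
// is whatever the prepareModelInput shader writes).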
const modelInputBuffer = root
  .createBuffer(d.arrayOf(d.f32, 3 * MODEL_WIDTH * MODEL_HEIGHT))
  .$usage('storage');

const modelOutputBuffer = root
  .createBuffer(d.arrayOf(d.f32, 1 * MODEL_WIDTH * MODEL_HEIGHT))
  .$usage('storage');

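// Full-resolution ping-pong pair for the gaussian blur; the 10 mip levels
// serve the mipmap blur path, which approximates blur by sampling a biased
// mip level instead of running compute passes.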
let blurredTextures: (
  & TgpuTexture<{
    size: [number, number];
    format: 'rgba8unorm';
    mipLevelCount: 10;
  }>
  & StorageFlag
  & SampledFlag
  & RenderFlag
)[];

const generateMaskBindGroup = root.createBindGroup(generateMaskLayout, {
  maskTexture,
  outputBuffer: modelOutputBuffer,
});

let blurBindGroups: TgpuBindGroup<typeof blurLayout.entries>[];

// pipelines

const prepareModelInputPipeline = root['~unstable']
  .createGuardedComputePipeline(
    prepareModelInput,
  );

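// The session binds our unwrapped GPU buffers directly, so inference input and
// output never leave the GPU.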
const session = await prepareSession(
  root.unwrap(modelInputBuffer),
  root.unwrap(modelOutputBuffer),
);

const generateMaskFromOutputPipeline = root['~unstable']
  .createGuardedComputePipeline(
    generateMaskFromOutput,
  );

const blurPipeline = root['~unstable']
  .withCompute(computeFn)
  .createPipeline();

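// Full-screen pass compositing the sharp video over its blurred background;
// the slots let the fragment shader choose between the gaussian texture and
// mip-biased sampling at runtime.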
const drawWithMaskPipeline = root['~unstable']
  .with(useGaussianSlot, useGaussianUniform)
  .with(sampleBiasSlot, sampleBiasUniform)
  .withVertex(fullScreenTriangle, {})
  .withFragment(drawWithMaskFragment, { format: presentationFormat })
  .createPipeline();

// recalculating mask
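// The mask is recomputed per video frame in three steps: downscale the frame
// into the model input buffer, run inference, then expand the model output
// into the mask texture.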

let calculateMaskCallbackId: number | undefined;

async function processCalculateMask() {
  if (video.readyState < 2) {
    calculateMaskCallbackId = video.requestVideoFrameCallback(
      processCalculateMask,
    );
    return;
  }

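  // External textures expire after the frame they were imported in,
  // so this bind group has to be recreated on every callback.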
  prepareModelInputPipeline
    .with(root.createBindGroup(prepareModelInputLayout, {
      inputTexture: device.importExternalTexture({ source: video }),
      outputBuffer: modelInputBuffer,
      sampler,
    }))
    .dispatchThreads(MODEL_WIDTH, MODEL_HEIGHT);

  await session.run();

  generateMaskFromOutputPipeline
    .with(generateMaskBindGroup)
    .dispatchThreads(MODEL_WIDTH, MODEL_HEIGHT);

  calculateMaskCallbackId = video.requestVideoFrameCallback(
    processCalculateMask,
  );
}
calculateMaskCallbackId = video.requestVideoFrameCallback(processCalculateMask);

// frame

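// Recreates size-dependent resources (the blur textures and their bind groups)
// and updates element sizing whenever the video resolution changes.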
function onVideoChange(size: { width: number; height: number }) {
  const aspectRatio = size.width / size.height;
  video.style.height = `${video.clientWidth / aspectRatio}px`;
  if (canvas.parentElement) {
    canvas.parentElement.style.aspectRatio = `${aspectRatio}`;
    canvas.parentElement.style.height =
      `min(100cqh, calc(100cqw/(${aspectRatio})))`;
  }
  blurredTextures = [0, 1].map(() =>
    root['~unstable'].createTexture({
      size: [size.width, size.height],
      format: 'rgba8unorm',
      dimension: '2d',
      mipLevelCount: 10,
    }).$usage('sampled', 'render', 'storage')
  );
  blurBindGroups = [
    root.createBindGroup(blurLayout, {
      flip: zeroBuffer,
      inTexture: blurredTextures[0],
      outTexture: blurredTextures[1].createView(
        d.textureStorage2d('rgba8unorm', 'read-only'),
        { mipLevelCount: 1 },
      ),
      sampler,
    }),
    root.createBindGroup(blurLayout, {
      flip: oneBuffer,
      inTexture: blurredTextures[1],
      outTexture: blurredTextures[0].createView(
        d.textureStorage2d('rgba8unorm', 'read-only'),
        { mipLevelCount: 1 },
      ),
      sampler,
    }),
  ];
}

let videoFrameCallbackId: number | undefined;
let lastFrameSize: { width: number; height: number } | undefined;

async function processVideoFrame(
  _: number,
  metadata: VideoFrameCallbackMetadata,
) {
  if (video.readyState < 2) {
    videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
    return;
  }

  const frameWidth = metadata.width;
  const frameHeight = metadata.height;

  if (
    !lastFrameSize ||
    lastFrameSize.width !== frameWidth ||
    lastFrameSize.height !== frameHeight
  ) {
    lastFrameSize = { width: frameWidth, height: frameHeight };
    onVideoChange(lastFrameSize);
  }

  blurredTextures[0].write(video);

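  // Gaussian path: each iteration runs a horizontal then a vertical pass,
  // ping-ponging between the two textures (note the transposed dispatch
  // dimensions). Mipmap path: just rebuild the mip chain; the fragment
  // shader then samples it with a mip bias instead.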
  if (useGaussianBlur) {
    for (let i = 0; i < blurStrength * 2; i++) {
      blurPipeline
        .with(blurBindGroups[0])
        .dispatchWorkgroups(
          Math.ceil(frameWidth / blockDim),
          Math.ceil(frameHeight / 4),
        );
      blurPipeline
        .with(blurBindGroups[1])
        .dispatchWorkgroups(
          Math.ceil(frameHeight / blockDim),
          Math.ceil(frameWidth / 4),
        );
    }
  } else {
    blurredTextures[0].generateMipmaps();
  }

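  // Composite into the canvas: the fragment shader blends the sharp external
  // texture with the blurred one based on the mask.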
  drawWithMaskPipeline
    .withColorAttachment({
      view: context.getCurrentTexture().createView(),
      clearValue: [1, 1, 1, 1],
      loadOp: 'clear',
      storeOp: 'store',
    })
    .with(root.createBindGroup(drawWithMaskLayout, {
      inputTexture: device.importExternalTexture({ source: video }),
      inputBlurredTexture: blurredTextures[0],
      maskTexture,
      sampler,
    }))
    .draw(3);

  videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);
}
videoFrameCallbackId = video.requestVideoFrameCallback(processVideoFrame);

// #region Example controls & Cleanup

export const controls = {
  'blur type': {
    initial: 'mipmaps',
    options: ['mipmaps', 'gaussian'],
    async onSelectChange(value: string) {
      useGaussianBlur = value === 'gaussian';
      useGaussianUniform.write(useGaussianBlur ? 1 : 0);
    },
  },
  'blur strength': {
    initial: blurStrength,
    min: 0,
    max: 10,
    step: 1,
    onSliderChange(newValue: number) {
      blurStrength = newValue;
      sampleBiasUniform.write(blurStrength);
    },
  },
};

export function onCleanup() {
  if (videoFrameCallbackId !== undefined) {
    video.cancelVideoFrameCallback(videoFrameCallbackId);
  }
  if (calculateMaskCallbackId !== undefined) {
    video.cancelVideoFrameCallback(calculateMaskCallbackId);
  }
  session.release();
  if (video.srcObject) {
    for (const track of (video.srcObject as MediaStream).getTracks()) {
      track.stop();
    }
  }
  // Undo the monkey patching; adapter is guaranteed non-null by the startup check.
  navigator.gpu.requestAdapter = oldRequestAdapter;
  adapter.requestDevice = oldRequestDevice;

  root.destroy();
}

// #endregion