Skip to content

Commit 57e8d3f

Browse files
authored
Auto-use all struct types in the shell of WGSL-implemented functions (#752)
1 parent c164dd8 commit 57e8d3f

File tree

8 files changed

+235
-38
lines changed

8 files changed

+235
-38
lines changed

apps/typegpu-docs/src/content/examples/rendering/box-raytracing/index.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ const getBoxIntersection = tgpu
178178
return output;
179179
}
180180
`)
181-
.$uses({ RayStruct, IntersectionStruct })
182181
.$name('box_intersection');
183182

184183
const vertexFunction = tgpu

packages/typegpu/src/core/function/fnCore.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import { MissingLinksError } from '../../errors';
33
import type { ResolutionCtx, Resource } from '../../types';
44
import {
55
type ExternalMap,
6+
addArgTypesToExternals,
7+
addReturnTypeToExternals,
68
applyExternals,
79
replaceExternalsInWgsl,
810
} from '../resolve/externals';
@@ -32,6 +34,15 @@ export function createFnCore(
3234
*/
3335
const externalsToApply: ExternalMap[] = [];
3436

37+
if (typeof implementation === 'string') {
38+
addArgTypesToExternals(implementation, shell.argTypes, (externals) =>
39+
externalsToApply.push(externals),
40+
);
41+
addReturnTypeToExternals(implementation, shell.returnType, (externals) =>
42+
externalsToApply.push(externals),
43+
);
44+
}
45+
3546
return {
3647
label: undefined as string | undefined,
3748

packages/typegpu/src/core/function/ioOutputType.ts

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import {
88
} from '../../data/attributes';
99
import { getCustomLocation, isData } from '../../data/dataTypes';
1010
import type { BaseWgslData, Location } from '../../data/wgslTypes';
11-
import type { FnCore } from './fnCore';
12-
import type { IOData, IOLayout, IORecord, Implementation } from './fnTypes';
11+
import type { IOData, IOLayout, IORecord } from './fnTypes';
1312

1413
export type WithLocations<T extends IORecord> = {
1514
[Key in keyof T]: IsBuiltin<T[Key]> extends true
@@ -52,28 +51,10 @@ export function withLocations<T extends IOData>(
5251
);
5352
}
5453

55-
export function createOutputType<T extends IOData>(
56-
core: FnCore,
57-
implementation: Implementation,
58-
returnType: IOLayout<T>,
59-
) {
60-
const Output: IOLayoutToOutputSchema<IOLayout<T>> = (
54+
export function createOutputType<T extends IOData>(returnType: IOLayout<T>) {
55+
return (
6156
isData(returnType)
6257
? location(0, returnType)
6358
: struct(withLocations(returnType) as Record<string, T>)
6459
) as IOLayoutToOutputSchema<IOLayout<T>>;
65-
66-
if (typeof implementation === 'string') {
67-
const outputName = implementation
68-
.match(/->(?<output>.*?){/s)
69-
?.groups?.output?.trim();
70-
71-
if (outputName && !/\s/g.test(outputName)) {
72-
core.applyExternals({
73-
[outputName]: Output,
74-
});
75-
}
76-
}
77-
78-
return Output;
7960
}

packages/typegpu/src/core/function/tgpuFragmentFn.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import type { OmitBuiltins } from '../../builtin';
22
import { type Vec4f, isWgslStruct } from '../../data/wgslTypes';
33
import type { TgpuNamable } from '../../namable';
44
import type { ResolutionCtx, TgpuResolvable } from '../../types';
5+
import { addReturnTypeToExternals } from '../resolve/externals';
56
import { createFnCore } from './fnCore';
67
import type {
78
ExoticIO,
@@ -98,7 +99,12 @@ function createFragmentFn(
9899
type This = TgpuFragmentFn;
99100

100101
const core = createFnCore(shell, implementation);
101-
const outputType = createOutputType(core, implementation, shell.returnType);
102+
const outputType = createOutputType(shell.returnType);
103+
if (typeof implementation === 'string') {
104+
addReturnTypeToExternals(implementation, outputType, (externals) =>
105+
core.applyExternals(externals),
106+
);
107+
}
102108

103109
return {
104110
shell,

packages/typegpu/src/core/function/tgpuVertexFn.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import type { OmitBuiltins } from '../../builtin';
22
import { isWgslStruct } from '../../data/wgslTypes';
33
import type { TgpuNamable } from '../../namable';
44
import type { ResolutionCtx, TgpuResolvable } from '../../types';
5+
import { addReturnTypeToExternals } from '../resolve/externals';
56
import { createFnCore } from './fnCore';
67
import type {
78
ExoticIO,
@@ -97,7 +98,12 @@ function createVertexFn(
9798
type This = TgpuVertexFn<IOLayout, IOLayout>;
9899

99100
const core = createFnCore(shell, implementation);
100-
const outputType = createOutputType(core, implementation, shell.returnType);
101+
const outputType = createOutputType(shell.returnType);
102+
if (typeof implementation === 'string') {
103+
addReturnTypeToExternals(implementation, outputType, (externals) =>
104+
core.applyExternals(externals),
105+
);
106+
}
101107

102108
return {
103109
shell,

packages/typegpu/src/core/resolve/externals.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { isWgslData } from '../../data/wgslTypes';
1+
import { isWgslData, isWgslStruct } from '../../data/wgslTypes';
22
import { isNamable } from '../../namable';
33
import { type ResolutionCtx, type Wgsl, isResolvable } from '../../types';
44
import { isSlot } from '../slot/slotTypes';
@@ -32,6 +32,41 @@ export function applyExternals(
3232
}
3333
}
3434

35+
export function addArgTypesToExternals(
36+
implementation: string,
37+
argTypes: unknown[],
38+
applyExternals: (externals: ExternalMap) => void,
39+
) {
40+
const argTypeNames = [
41+
...implementation.matchAll(/:\s*(?<arg>.*?)\s*[,)]/g),
42+
].map((found) => found.groups?.arg);
43+
44+
applyExternals(
45+
Object.fromEntries(
46+
argTypes.flatMap((argType, i) => {
47+
const argTypeName = argTypeNames ? argTypeNames[i] : undefined;
48+
return isWgslStruct(argType) && argTypeName !== undefined
49+
? [[argTypeName, argType]]
50+
: [];
51+
}),
52+
),
53+
);
54+
}
55+
56+
export function addReturnTypeToExternals(
57+
implementation: string,
58+
returnType: unknown,
59+
applyExternals: (externals: ExternalMap) => void,
60+
) {
61+
const outputName = implementation
62+
.match(/->(?<output>.*?){/s)
63+
?.groups?.output?.trim();
64+
65+
if (isWgslStruct(returnType) && outputName && !/\s/g.test(outputName)) {
66+
applyExternals({ [outputName]: returnType });
67+
}
68+
}
69+
3570
/**
3671
* Replaces all occurrences of external names in WGSL code with their resolved values.
3772
* It adds all necessary definitions to the resolution context.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import { describe, expect, it } from 'vitest';
2+
import {
3+
type ExternalMap,
4+
addArgTypesToExternals,
5+
} from '../src/core/resolve/externals';
6+
import * as d from '../src/data';
7+
8+
describe('addArgTypesToExternals', () => {
9+
const Particle = d.struct({
10+
position: d.vec3f,
11+
color: d.vec4f,
12+
});
13+
14+
const Light = d.struct({
15+
ambient: d.vec4f,
16+
intensity: d.f32,
17+
});
18+
19+
it('extracts struct argument types with their names', () => {
20+
const externals: ExternalMap[] = [];
21+
addArgTypesToExternals(
22+
'(a: vec4f, b: Particle, c: Light) {}',
23+
[d.vec4f, Particle, Light],
24+
(result) => externals.push(result),
25+
);
26+
expect(externals).toEqual([{ Particle, Light }]);
27+
});
28+
29+
it('gets the names from argument list in WGSL implementation', () => {
30+
const externals: ExternalMap[] = [];
31+
addArgTypesToExternals(
32+
'(b: P, a: vec4f, c: L) -> L {}',
33+
[Particle, d.vec4f, Light],
34+
(result) => externals.push(result),
35+
);
36+
expect(externals).toEqual([{ P: Particle, L: Light }]);
37+
});
38+
39+
it('works when builtins are present', () => {
40+
const externals: ExternalMap[] = [];
41+
addArgTypesToExternals(
42+
'(@builtin(workgroup_id) WorkGroupID : vec3u, a: vec4f, b: Particle, c: Light) {}',
43+
[d.vec3u, d.vec4f, Particle, Light],
44+
(result) => externals.push(result),
45+
);
46+
expect(externals).toEqual([{ Particle, Light }]);
47+
});
48+
49+
it('works with unusual whitespace', () => {
50+
const externals: ExternalMap[] = [];
51+
addArgTypesToExternals(
52+
` WorkGroupID : vec3u
53+
,
54+
a : A ,
55+
(@builtin(workgroup_id) b
56+
57+
: B,
58+
59+
c: C
60+
) -> vec4f {}`,
61+
[d.vec3u, Particle, Particle, Particle],
62+
(result) => externals.push(result),
63+
);
64+
expect(externals).toEqual([{ A: Particle, B: Particle, C: Particle }]);
65+
});
66+
});

0 commit comments

Comments
 (0)