Skip to content

Commit 171af79

Browse files
committed
More test coverage for argument origin tracking
1 parent 685479d commit 171af79

File tree

8 files changed

+203
-27
lines changed

8 files changed

+203
-27
lines changed

packages/typegpu/src/core/function/fnCore.ts

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
type Snippet,
88
} from '../../data/snippet.ts';
99
import {
10+
isNaturallyEphemeral,
1011
isPtr,
1112
isWgslData,
1213
isWgslStruct,
@@ -196,39 +197,49 @@ export function createFnCore(
196197
// of the argument based on the argument's referentiality.
197198
// In other words, if we pass a reference to a function, it's typed as a pointer,
198199
// otherwise it's typed as a value.
199-
const ref = isPtr(argType)
200+
const origin = isPtr(argType)
200201
? argType.addressSpace === 'storage'
201202
? argType.access === 'read' ? 'readonly' : 'mutable'
202203
: argType.addressSpace
204+
: isNaturallyEphemeral(argType)
205+
? 'runtime'
203206
: 'argument';
204207

205208
switch (astParam?.type) {
206209
case FuncParameterType.identifier: {
207210
const rawName = astParam.name;
208-
const snippet = snip(ctx.makeNameValid(rawName), argType, ref);
211+
const snippet = snip(ctx.makeNameValid(rawName), argType, origin);
209212
args.push(snippet);
210213
if (snippet.value !== rawName) {
211214
argAliases.push([rawName, snippet]);
212215
}
213216
break;
214217
}
215218
case FuncParameterType.destructuredObject: {
216-
args.push(snip(`_arg_${i}`, argType, ref));
217-
argAliases.push(...astParam.props.map(({ name, alias }) =>
218-
[
219+
args.push(snip(`_arg_${i}`, argType, origin));
220+
argAliases.push(...astParam.props.map(({ name, alias }) => {
221+
// Undecorating, as the struct type can contain builtins
222+
const destrType = undecorate(
223+
(argTypes[i] as WgslStruct).propTypes[name],
224+
);
225+
226+
const destrOrigin = isPtr(destrType)
227+
? destrType.addressSpace === 'storage'
228+
? destrType.access === 'read' ? 'readonly' : 'mutable'
229+
: destrType.addressSpace
230+
: isNaturallyEphemeral(destrType)
231+
? 'runtime'
232+
: 'argument';
233+
234+
return [
219235
alias,
220-
snip(
221-
`_arg_${i}.${name}`,
222-
(argTypes[i] as WgslStruct)
223-
.propTypes[name],
224-
ref,
225-
),
226-
] as [string, Snippet]
227-
));
236+
snip(`_arg_${i}.${name}`, destrType, destrOrigin),
237+
] as [string, Snippet];
238+
}));
228239
break;
229240
}
230241
case undefined:
231-
args.push(snip(`_arg_${i}`, argType, ref));
242+
args.push(snip(`_arg_${i}`, argType, origin));
232243
}
233244
}
234245

packages/typegpu/src/core/variable/tgpuVariable.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import type { AnyData } from '../../data/dataTypes.ts';
2-
import type { ref } from '../../data/ref.ts';
32
import { type ResolvedSnippet, snip } from '../../data/snippet.ts';
43
import { isNaturallyEphemeral } from '../../data/wgslTypes.ts';
54
import { IllegalVarAccessError } from '../../errors.ts';
@@ -26,7 +25,7 @@ export type VariableScope = 'private' | 'workgroup';
2625
export interface TgpuVar<
2726
TScope extends VariableScope = VariableScope,
2827
TDataType extends AnyData = AnyData,
29-
> extends TgpuNamable, ref<InferGPU<TDataType>> {
28+
> extends TgpuNamable {
3029
readonly [$gpuValueOf]: InferGPU<TDataType>;
3130
value: InferGPU<TDataType>;
3231
$: InferGPU<TDataType>;

packages/typegpu/src/data/ref.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,10 @@ export function derefSnippet(snippet: Snippet): Snippet {
161161
invariant(isPtr(snippet.dataType), 'Only pointers can be dereferenced');
162162

163163
const innerType = snippet.dataType.inner;
164-
// Dereferencing a pointer does not return a copy of the value, it's still a reference.
165-
const origin = isNaturallyEphemeral(innerType) ? 'runtime' : snippet.origin;
164+
const origin =
165+
isNaturallyEphemeral(innerType) && snippet.origin !== 'argument'
166+
? 'runtime'
167+
: snippet.origin;
166168

167169
if (snippet.value instanceof RefOperator) {
168170
return snip(stitch`${snippet.value.snippet}`, innerType, origin);

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ ${this.ctx.pre}}`;
174174
// Even types that aren't naturally referential (like vectors or structs) should
175175
// be treated as constant references when assigned to a const.
176176
varOrigin = 'constant-ref';
177+
} else if (origin === 'argument' && !wgsl.isNaturallyEphemeral(dataType)) {
178+
varOrigin = 'argument';
177179
} else if (!wgsl.isNaturallyEphemeral(dataType)) {
178180
varOrigin = isEphemeralOrigin(origin) ? 'function' : origin;
179181
} else if (origin === 'constant' && varType === 'const') {
@@ -291,11 +293,22 @@ ${this.ctx.pre}}`;
291293
);
292294
}
293295

296+
if (
297+
rhsExpr.origin === 'argument' &&
298+
!wgsl.isNaturallyEphemeral(rhsExpr.dataType)
299+
) {
300+
throw new WgslTypeError(
301+
`'${lhsStr} = ${rhsStr}' is invalid, because argument references cannot be assigned.\n-----\nTry '${lhsStr} = ${
302+
this.ctx.resolve(rhsExpr.dataType).value
303+
}(${rhsStr})' to copy the value instead.\n-----`,
304+
);
305+
}
306+
294307
if (!isEphemeralSnippet(rhsExpr)) {
295308
throw new WgslTypeError(
296309
`'${lhsStr} = ${rhsStr}' is invalid, because references cannot be assigned.\n-----\nTry '${lhsStr} = ${
297310
this.ctx.resolve(rhsExpr.dataType).value
298-
}(${rhsStr})' instead.\n-----`,
311+
}(${rhsStr})' to copy the value instead.\n-----`,
299312
);
300313
}
301314
}
@@ -812,7 +825,10 @@ ${this.ctx.pre}else ${alternate}`;
812825

813826
// Assigning a reference to a `const` variable means we store the pointer
814827
// of the rhs.
815-
if (!isEphemeralSnippet(eq)) {
828+
if (
829+
!isEphemeralSnippet(eq) ||
830+
(eq.origin === 'argument' && !wgsl.isNaturallyEphemeral(dataType))
831+
) {
816832
// Referential
817833
if (stmtType === NODE.let) {
818834
const rhsStr = this.ctx.resolve(eq.value).value;

packages/typegpu/tests/examples/individual/gravity.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ describe('gravity example', () => {
270270
var lightColor = vec3f(1, 0.8999999761581421, 0.8999999761581421);
271271
var textureColor = textureSample(celestialBodyTextures_9, sampler_10, input.uv, input.sphereTextureIndex).xyz;
272272
var ambient = ((textureColor * lightColor) * input.ambientLightFactor);
273-
var normal = input.normals;
273+
let normal = (&input.normals);
274274
var lightDirection = normalize((lightSource_11 - input.worldPosition));
275-
let cosTheta = dot(normal, lightDirection);
275+
let cosTheta = dot((*normal), lightDirection);
276276
var diffuse = ((textureColor * lightColor) * max(0f, cosTheta));
277277
var litColor = (ambient + diffuse);
278278
return vec4f(litColor.xyz, 1f);

packages/typegpu/tests/ref.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { describe, expect } from 'vitest';
33
import { it } from './utils/extendedIt.ts';
44
import { asWgsl } from './utils/parseResolved.ts';
55

6-
describe('ref', () => {
6+
describe('d.ref', () => {
77
it('fails when using a ref as an external', () => {
88
const sup = d.ref(0);
99

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import { describe, expect } from 'vitest';
2+
import * as d from '../../src/data/index.ts';
3+
import { it } from '../utils/extendedIt';
4+
import { asWgsl } from '../utils/parseResolved';
5+
6+
describe('function argument origin tracking', () => {
7+
it('should allow mutation of primitive arguments', () => {
8+
const foo = (a: number) => {
9+
'use gpu';
10+
a += 1;
11+
};
12+
13+
const main = () => {
14+
'use gpu';
15+
foo(1);
16+
};
17+
18+
expect(asWgsl(main)).toMatchInlineSnapshot(`
19+
"fn foo(a: i32) {
20+
a += 1i;
21+
}
22+
23+
fn main() {
24+
foo(1i);
25+
}"
26+
`);
27+
});
28+
29+
it('should allow mutation of destructured primitive arguments', () => {
30+
const Foo = d.struct({ a: d.f32 });
31+
32+
const foo = ({ a }: { a: number }) => {
33+
'use gpu';
34+
a += 1;
35+
};
36+
37+
const main = () => {
38+
'use gpu';
39+
foo(Foo({ a: 1 }));
40+
};
41+
42+
expect(asWgsl(main)).toMatchInlineSnapshot(`
43+
"struct Foo {
44+
a: f32,
45+
}
46+
47+
fn foo(_arg_0: Foo) {
48+
_arg_0.a += 1f;
49+
}
50+
51+
fn main() {
52+
foo(Foo(1f));
53+
}"
54+
`);
55+
});
56+
57+
it('should fail on mutation of non-primitive arguments', () => {
58+
const foo = (a: d.v3f) => {
59+
'use gpu';
60+
a.x += 1;
61+
};
62+
63+
const main = () => {
64+
'use gpu';
65+
foo(d.vec3f(1, 2, 3));
66+
};
67+
68+
expect(() => asWgsl(main)).toThrowErrorMatchingInlineSnapshot(`
69+
[Error: Resolution of the following tree failed:
70+
- <root>
71+
- fn*:main
72+
- fn*:main()
73+
- fn*:foo(vec3f): 'a.x += 1f' is invalid, because non-pointer arguments cannot be mutated.]
74+
`);
75+
});
76+
77+
it('should fail on transitive mutation of non-primitive arguments', () => {
78+
const foo = (a: d.v3f) => {
79+
'use gpu';
80+
const b = a;
81+
b.x += 1;
82+
};
83+
84+
const main = () => {
85+
'use gpu';
86+
foo(d.vec3f(1, 2, 3));
87+
};
88+
89+
expect(() => asWgsl(main)).toThrowErrorMatchingInlineSnapshot(`
90+
[Error: Resolution of the following tree failed:
91+
- <root>
92+
- fn*:main
93+
- fn*:main()
94+
- fn*:foo(vec3f): '(*b).x += 1f' is invalid, because non-pointer arguments cannot be mutated.]
95+
`);
96+
});
97+
98+
it('should fail on create a let variable from an argument reference', () => {
99+
const foo = (a: d.v3f) => {
100+
'use gpu';
101+
let b = a;
102+
b = d.vec3f();
103+
return b;
104+
};
105+
106+
const main = () => {
107+
'use gpu';
108+
foo(d.vec3f(1, 2, 3));
109+
};
110+
111+
expect(() => asWgsl(main)).toThrowErrorMatchingInlineSnapshot(`
112+
[Error: Resolution of the following tree failed:
113+
- <root>
114+
- fn*:main
115+
- fn*:main()
116+
- fn*:foo(vec3f): 'let b = a' is invalid, because references cannot be assigned to 'let' variable declarations.
117+
-----
118+
- Try 'let b = vec3f(a)' if you need to reassign 'b' later
119+
- Try 'const b = a' if you won't reassign 'b' later.
120+
-----]
121+
`);
122+
});
123+
124+
it('should fail on assigning an argument reference to a variable', () => {
125+
const foo = (a: d.v3f) => {
126+
'use gpu';
127+
let b = d.vec3f();
128+
b = a;
129+
return b;
130+
};
131+
132+
const main = () => {
133+
'use gpu';
134+
foo(d.vec3f(1, 2, 3));
135+
};
136+
137+
expect(() => asWgsl(main)).toThrowErrorMatchingInlineSnapshot(`
138+
[Error: Resolution of the following tree failed:
139+
- <root>
140+
- fn*:main
141+
- fn*:main()
142+
- fn*:foo(vec3f): 'b = a' is invalid, because argument references cannot be assigned.
143+
-----
144+
Try 'b = vec3f(a)' to copy the value instead.
145+
-----]
146+
`);
147+
});
148+
});

packages/typegpu/tests/tgslFn.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ describe('TGSL tgpu.fn function', () => {
164164
@vertex fn vertex_fn(input: vertex_fn_Input) -> vertex_fn_Output {
165165
let vi = f32(input.vi);
166166
let ii = f32(input.ii);
167-
var color = input.color;
168-
return vertex_fn_Output(vec4f(color.w, ii, vi, 1f), vec2f(color.w, vi));
167+
let color = (&input.color);
168+
return vertex_fn_Output(vec4f((*color).w, ii, vi, 1f), vec2f((*color).w, vi));
169169
}"
170170
`);
171171
});
@@ -391,9 +391,9 @@ describe('TGSL tgpu.fn function', () => {
391391
}
392392
393393
@fragment fn fragmentFn(input: fragmentFn_Input) -> fragmentFn_Output {
394-
var pos = input.pos;
394+
let pos = (&input.pos);
395395
var sampleMask = 0;
396-
if (((input.sampleMask > 0u) && (pos.x > 0f))) {
396+
if (((input.sampleMask > 0u) && ((*pos).x > 0f))) {
397397
sampleMask = 1i;
398398
}
399399
return fragmentFn_Output(u32(sampleMask), 1f, vec4f());

0 commit comments

Comments
 (0)