Skip to content

Commit 21776b3

Browse files
committed
Test for updating a whole struct, returning refs
1 parent c853f5c commit 21776b3

File tree

7 files changed

+170
-21
lines changed

7 files changed

+170
-21
lines changed

packages/typegpu/src/data/ref.ts

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@ export interface ref<T> {
4141
// TODO: Restrict calls to this function only from within TypeGPU functions
4242
export const ref: DualFn<<T>(value: T) => ref<T>> = (() => {
4343
const gpuImpl = (value: Snippet) => {
44-
return snip(new RefOperator(value), UnknownData, /* origin */ 'runtime');
44+
const ptrType = createPtrFromOrigin(
45+
value.origin,
46+
value.dataType as StorableData,
47+
);
48+
return snip(
49+
new RefOperator(value, ptrType),
50+
ptrType ?? UnknownData,
51+
/* origin */ 'runtime',
52+
);
4553
};
4654

4755
const jsImpl = <T>(value: T) => new refImpl(value);
@@ -99,6 +107,11 @@ class refImpl<T> implements ref<T> {
99107
}
100108
}
101109

110+
/**
111+
* The result of calling `d.ref(...)`. The code responsible for
112+
* generating shader code can check if the value of a snippet is
113+
* an instance of `RefOperator`, and act accordingly.
114+
*/
102115
export class RefOperator implements SelfResolvable {
103116
readonly [$internal]: true;
104117
readonly snippet: Snippet;
@@ -114,14 +127,10 @@ export class RefOperator implements SelfResolvable {
114127
*/
115128
readonly ptrType: Ptr | undefined;
116129

117-
constructor(snippet: Snippet) {
130+
constructor(snippet: Snippet, ptrType: Ptr | undefined) {
118131
this[$internal] = true;
119132
this.snippet = snippet;
120-
121-
this.ptrType = createPtrFromOrigin(
122-
snippet.origin,
123-
snippet.dataType as StorableData,
124-
);
133+
this.ptrType = ptrType;
125134
}
126135

127136
get [$ownSnippet](): Snippet {

packages/typegpu/src/tgsl/conversion.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
type I32,
1111
isMat,
1212
isVec,
13+
Ptr,
1314
type U32,
1415
type WgslStruct,
1516
} from '../data/wgslTypes.ts';
@@ -73,6 +74,8 @@ function getImplicitConversionRank(
7374

7475
if (
7576
trueSrc.type === 'ptr' &&
77+
// Only dereferencing implicit pointers, otherwise we'd have a types mismatch between TS and WGSL
78+
trueSrc.implicit &&
7679
getAutoConversionRank(trueSrc.inner as AnyData, trueDst).rank <
7780
Number.POSITIVE_INFINITY
7881
) {
@@ -240,7 +243,11 @@ function applyActionToSnippet(
240243

241244
switch (action.action) {
242245
case 'ref':
243-
return snip(new RefOperator(snippet), targetType, snippet.origin);
246+
return snip(
247+
new RefOperator(snippet, targetType as Ptr),
248+
targetType,
249+
snippet.origin,
250+
);
244251
case 'deref':
245252
return derefSnippet(snippet);
246253
case 'cast': {

packages/typegpu/src/tgsl/generationHelpers.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,11 +394,11 @@ export function concretizeSnippets(args: Snippet[]): Snippet[] {
394394
export type GenerationCtx = ResolutionCtx & {
395395
readonly pre: string;
396396
/**
397-
* Used by `generateTypedExpression` to signal downstream
397+
* Used by `typedExpression` to signal downstream
398398
* expression resolution what type is expected of them.
399399
*
400400
* It is used exclusively for inferring the types of structs and arrays.
401-
* It is modified exclusively by `generateTypedExpression` function.
401+
* It is modified exclusively by `typedExpression` function.
402402
*/
403403
expectedType: AnyData | undefined;
404404

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,10 @@ ${this.ctx.pre}}`;
153153
dataType: wgsl.StorableData,
154154
): string {
155155
const varName = this.ctx.makeNameValid(id);
156+
const ptrType = ptrFn(dataType);
156157
const snippet = snip(
157-
new RefOperator(snip(varName, dataType, 'function')),
158-
ptrFn(dataType),
158+
new RefOperator(snip(varName, dataType, 'function'), ptrType),
159+
ptrType,
159160
'function',
160161
);
161162
this.ctx.defineVariable(id, snippet);
@@ -697,6 +698,12 @@ ${this.ctx.pre}}`;
697698
)
698699
: this.expression(returnNode);
699700

701+
if (returnSnippet.value instanceof RefOperator) {
702+
throw new WgslTypeError(
703+
stitch`Cannot return references, returning '${returnSnippet.value.snippet}'`,
704+
);
705+
}
706+
700707
if (
701708
!expectedReturnType &&
702709
!isEphemeralSnippet(returnSnippet) &&

packages/typegpu/tests/ref.test.ts

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ describe('ref', () => {
4444
`);
4545
});
4646

47-
it('fails when creating a ref with a reference', () => {
47+
it('fails when creating a ref with a reference, and assigning it to a variable', () => {
4848
const hello = () => {
4949
'use gpu';
5050
const position = d.vec3f(1, 2, 3);
@@ -58,4 +58,118 @@ describe('ref', () => {
5858
- fn*:hello(): Cannot store d.ref() in a variable if it references another value. Copy the value passed into d.ref() instead.]
5959
`);
6060
});
61+
62+
it('allows updating a whole struct from another function', () => {
63+
type Entity = d.Infer<typeof Entity>;
64+
const Entity = d.struct({ pos: d.vec3f });
65+
66+
const clearEntity = (entity: d.ref<Entity>) => {
67+
'use gpu';
68+
entity.$ = Entity({ pos: d.vec3f(0, 0, 0) });
69+
};
70+
71+
const main = () => {
72+
'use gpu';
73+
const entity = Entity({ pos: d.vec3f(1, 2, 3) });
74+
clearEntity(d.ref(entity));
75+
// entity.pos should be vec3f(0, 0, 0)
76+
return entity;
77+
};
78+
79+
// Works in JS
80+
expect(main().pos).toStrictEqual(d.vec3f(0, 0, 0));
81+
82+
// And on the GPU
83+
expect(asWgsl(main)).toMatchInlineSnapshot(`
84+
"struct Entity {
85+
pos: vec3f,
86+
}
87+
88+
fn clearEntity(entity: ptr<function, Entity>) {
89+
(*entity) = Entity(vec3f());
90+
}
91+
92+
fn main() -> Entity {
93+
var entity = Entity(vec3f(1, 2, 3));
94+
clearEntity((&entity));
95+
return entity;
96+
}"
97+
`);
98+
});
99+
100+
it('allows updating a number from another function', () => {
101+
const increment = (value: d.ref<number>) => {
102+
'use gpu';
103+
value.$ += 1;
104+
};
105+
106+
const main = () => {
107+
'use gpu';
108+
const value = d.ref(0);
109+
increment(value);
110+
return value.$;
111+
};
112+
113+
// Works in JS
114+
expect(main()).toBe(1);
115+
116+
// And on the GPU
117+
expect(asWgsl(main)).toMatchInlineSnapshot(`
118+
"fn increment(value: ptr<function, i32>) {
119+
(*value) += 1i;
120+
}
121+
122+
fn main() -> i32 {
123+
var value = 0;
124+
increment((&value));
125+
return value;
126+
}"
127+
`);
128+
});
129+
130+
it('rejects passing d.ref created from non-refs directly into functions', () => {
131+
const increment = (value: d.ref<number>) => {
132+
'use gpu';
133+
value.$ += 1;
134+
};
135+
136+
const main = () => {
137+
'use gpu';
138+
increment(d.ref(0));
139+
};
140+
141+
expect(() => asWgsl(main)).toThrowErrorMatchingInlineSnapshot(`
142+
[Error: Resolution of the following tree failed:
143+
- <root>
144+
- fn*:main
145+
- fn*:main(): d.ref() created with primitive types must be stored in a variable before use]
146+
`);
147+
});
148+
149+
it('fails when returning a ref', () => {
150+
const foo = () => {
151+
'use gpu';
152+
const value = d.ref(0);
153+
return value;
154+
};
155+
156+
const bar = () => {
157+
'use gpu';
158+
return d.ref(0);
159+
};
160+
161+
expect(() => asWgsl(foo)).toThrowErrorMatchingInlineSnapshot(`
162+
[Error: Resolution of the following tree failed:
163+
- <root>
164+
- fn*:foo
165+
- fn*:foo(): Cannot return references, returning 'value']
166+
`);
167+
168+
expect(() => asWgsl(bar)).toThrowErrorMatchingInlineSnapshot(`
169+
[Error: Resolution of the following tree failed:
170+
- <root>
171+
- fn*:bar
172+
- fn*:bar(): Cannot return references, returning '0']
173+
`);
174+
});
61175
});

packages/typegpu/tests/tgsl/conversion.test.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { UnknownData } from '../../src/data/dataTypes.ts';
1414
import { ResolutionCtxImpl } from '../../src/resolutionCtx.ts';
1515
import { namespace } from '../../src/core/resolve/namespace.ts';
1616
import wgslGenerator from '../../src/tgsl/wgslGenerator.ts';
17+
import { INTERNAL_createPtr } from '../../src/data/ptr.ts';
1718

1819
const ctx = new ResolutionCtxImpl({
1920
namespace: namespace({ names: 'strict' }),
@@ -30,8 +31,20 @@ afterAll(() => {
3031
});
3132

3233
describe('getBestConversion', () => {
33-
const ptrF32 = d.ptrPrivate(d.f32);
34-
const ptrI32 = d.ptrPrivate(d.i32);
34+
// d.ptrPrivate(d.f32)
35+
const ptrF32 = INTERNAL_createPtr(
36+
'private',
37+
d.f32,
38+
'read-write',
39+
/* implicit */ true,
40+
);
41+
// d.ptrPrivate(d.i32)
42+
const ptrI32 = INTERNAL_createPtr(
43+
'private',
44+
d.i32,
45+
'read-write',
46+
/* implicit */ true,
47+
);
3548

3649
it('returns result for identical types', () => {
3750
const res = getBestConversion([d.f32, d.f32]);
@@ -192,7 +205,7 @@ describe('convertToCommonType', () => {
192205
const snippetAbsInt = snip('1', abstractInt, /* ref */ 'runtime');
193206
const snippetPtrF32 = snip(
194207
'ptr_f32',
195-
d.ptrPrivate(d.f32),
208+
INTERNAL_createPtr('private', d.f32, 'read-write', /* implicit */ true),
196209
/* ref */ 'function',
197210
);
198211
const snippetUnknown = snip('?', UnknownData, /* ref */ 'runtime');

packages/typegpu/tests/tgsl/shellless.test.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ describe('shellless', () => {
186186
`);
187187
});
188188

189-
it('generates uniform pointer params when passing a fixed uniform directly to a function', ({ root }) => {
189+
it('generates uniform pointer params when passing a fixed uniform ref to a function', ({ root }) => {
190190
const posUniform = root.createUniform(d.vec3f);
191191

192192
const sumComponents = (vec: d.ref<d.v3f>) => {
@@ -197,16 +197,15 @@ describe('shellless', () => {
197197
const main = () => {
198198
'use gpu';
199199
sumComponents(d.ref(posUniform.$));
200-
// sumComponents(&posUniform);
201200
};
202201

203202
expect(asWgsl(main)).toMatchInlineSnapshot(`
204-
"@group(0) @binding(0) var<uniform> posUniform: vec3f;
205-
206-
fn sumComponents(vec: ptr<uniform, vec3f>) -> f32 {
203+
"fn sumComponents(vec: ptr<uniform, vec3f>) -> f32 {
207204
return (((*vec).x + (*vec).y) + (*vec).z);
208205
}
209206
207+
@group(0) @binding(0) var<uniform> posUniform: vec3f;
208+
210209
fn main() {
211210
sumComponents((&posUniform));
212211
}"

0 commit comments

Comments
 (0)