Skip to content

Commit 79252cb

Browse files
committed
More useful refs
1 parent 07ee6b7 commit 79252cb

File tree

11 files changed

+365
-48
lines changed

11 files changed

+365
-48
lines changed

apps/typegpu-docs/src/examples/simulation/gravity/compute.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
5959
if (isSmaller(currentId, otherId)) {
6060
const dir = std.normalize(current.position.sub(other.position));
6161
current.position = other.position.add(
62-
dir.mul(radiusOf(current) + radiusOf(other)),
62+
dir.mul(radiusOf(d.ref(current)) + radiusOf(d.ref(other))),
6363
);
6464
}
6565

@@ -118,7 +118,7 @@ export const computeGravityShader = tgpu['~unstable'].computeFn({
118118
}
119119

120120
const dist = std.max(
121-
radiusOf(current) + radiusOf(other),
121+
radiusOf(d.ref(current)) + radiusOf(d.ref(other)),
122122
std.distance(current.position, other.position),
123123
);
124124
const gravityForce = (current.mass * other.mass) / dist / dist;

apps/typegpu-docs/src/examples/simulation/gravity/render.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export const mainVertex = tgpu['~unstable'].vertexFn({
5656
const currentBody = renderLayout.$.celestialBodies[input.instanceIndex];
5757

5858
const worldPosition = currentBody.position.add(
59-
input.position.xyz.mul(radiusOf(currentBody)),
59+
input.position.xyz.mul(radiusOf(d.ref(currentBody))),
6060
);
6161

6262
const camera = cameraAccess.$;

packages/typegpu/src/core/function/shelllessImpl.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ export function createShelllessImpl(
4444
},
4545

4646
toString(): string {
47-
return `fn*:${getName(core) ?? '<unnamed>'}`;
47+
return `fn*:${getName(core) ?? '<unnamed>'}(${
48+
argTypes.map((t) => t.toString()).join(', ')
49+
})`;
4850
},
4951
};
5052
}

packages/typegpu/src/data/ptr.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { $internal } from '../shared/symbols.ts';
2+
import { Origin, OriginToPtrParams, originToPtrParams } from './snippet.ts';
23
import type { Access, AddressSpace, Ptr, StorableData } from './wgslTypes.ts';
34

45
export function ptrFn<T extends StorableData>(
@@ -58,3 +59,20 @@ export function INTERNAL_createPtr<
5859
toString: () => `ptr<${addressSpace}, ${inner}, ${access}>`,
5960
} as Ptr<TAddressSpace, TInner, TAccess>;
6061
}
62+
63+
export function createPtrFromOrigin(
64+
origin: Origin,
65+
innerDataType: StorableData,
66+
): Ptr | undefined {
67+
const ptrParams = originToPtrParams[origin as keyof OriginToPtrParams];
68+
69+
if (ptrParams) {
70+
return INTERNAL_createPtr(
71+
ptrParams.space,
72+
innerDataType,
73+
ptrParams.access,
74+
);
75+
}
76+
77+
return undefined;
78+
}

packages/typegpu/src/data/ref.ts

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
import { stitch } from '../core/resolve/stitch.ts';
2-
import { invariant, WgslTypeError } from '../errors.ts';
2+
import { invariant } from '../errors.ts';
33
import { inCodegenMode } from '../execMode.ts';
44
import { setName } from '../shared/meta.ts';
55
import { $internal, $isRef, $ownSnippet, $resolve } from '../shared/symbols.ts';
66
import type { ResolutionCtx, SelfResolvable } from '../types.ts';
77
import { UnknownData } from './dataTypes.ts';
88
import type { DualFn } from './dualFn.ts';
9-
import { INTERNAL_createPtr } from './ptr.ts';
10-
import {
11-
isEphemeralSnippet,
12-
type OriginToPtrParams,
13-
originToPtrParams,
14-
type ResolvedSnippet,
15-
snip,
16-
type Snippet,
17-
} from './snippet.ts';
9+
import { createPtrFromOrigin } from './ptr.ts';
10+
import { type ResolvedSnippet, snip, type Snippet } from './snippet.ts';
1811
import {
1912
isNaturallyEphemeral,
2013
isPtr,
@@ -29,17 +22,25 @@ import {
2922
export interface ref<T> {
3023
readonly [$internal]: unknown;
3124
readonly [$isRef]: true;
25+
26+
/**
27+
* Derefences the reference, and gives access to the underlying value.
28+
*
29+
* @example ```ts
30+
* const boid = Boid({ pos: d.vec3f(3, 2, 1) });
31+
* const posRef = d.ref(boid.pos);
32+
*
33+
* // Actually updates `boid.pos`
34+
* posRef.$ = d.vec3f(1, 2, 3);
35+
* console.log(boid.pos); // Output: vec3f(1, 2, 3)
36+
* ```
37+
*/
3238
$: T;
3339
}
3440

3541
// TODO: Restrict calls to this function only from within TypeGPU functions
3642
export const ref: DualFn<<T>(value: T) => ref<T>> = (() => {
3743
const gpuImpl = (value: Snippet) => {
38-
if (!isEphemeralSnippet(value)) {
39-
throw new WgslTypeError(
40-
`Can't create refs from references. Copy the value first by wrapping it in its schema.`,
41-
);
42-
}
4344
return snip(new RefOnGPU(value), UnknownData, /* origin */ 'runtime');
4445
};
4546

@@ -71,11 +72,11 @@ export const ref: DualFn<<T>(value: T) => ref<T>> = (() => {
7172
// --------------
7273

7374
class refImpl<T> implements ref<T> {
74-
readonly #value: T | string;
75+
#value: T;
7576
readonly [$internal]: true;
7677
readonly [$isRef]: true;
7778

78-
constructor(value: T | string) {
79+
constructor(value: T) {
7980
this.#value = value;
8081
this[$internal] = true;
8182
this[$isRef] = true;
@@ -84,20 +85,49 @@ class refImpl<T> implements ref<T> {
8485
get $(): T {
8586
return this.#value as T;
8687
}
88+
89+
set $(value: T) {
90+
if (value && typeof value === 'object') {
91+
// Setting an object means updating the properties of the original object.
92+
// e.g.: foo.$ = Boid();
93+
for (const key of Object.keys(value) as (keyof T)[]) {
94+
this.#value[key] = value[key];
95+
}
96+
} else {
97+
this.#value = value;
98+
}
99+
}
87100
}
88101

89102
export class RefOnGPU {
90-
readonly snippet: Snippet;
91103
readonly [$internal]: true;
92104

105+
readonly snippet: Snippet;
106+
/**
107+
* Pointer params only exist if the ref was created from a reference (buttery-butter).
108+
*/
109+
readonly ptrType: Ptr | undefined;
110+
93111
constructor(snippet: Snippet) {
94-
this.snippet = snippet;
95112
this[$internal] = true;
113+
this.snippet = snippet;
114+
this.ptrType = createPtrFromOrigin(
115+
snippet.origin,
116+
snippet.dataType as StorableData,
117+
);
96118
}
97119

98120
toString(): string {
99121
return `ref:${this.snippet.value}`;
100122
}
123+
124+
[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
125+
invariant(
126+
!!this.ptrType,
127+
'RefOnGPU must have a pointer type when resolved',
128+
);
129+
return snip(stitch`(&${this.snippet})`, this.ptrType, this.snippet.origin);
130+
}
101131
}
102132

103133
export class RefOperator implements SelfResolvable {
@@ -109,20 +139,18 @@ export class RefOperator implements SelfResolvable {
109139
this[$internal] = true;
110140
this.snippet = snippet;
111141

112-
const ptrParams =
113-
originToPtrParams[this.snippet.origin as keyof OriginToPtrParams];
142+
const ptrType = createPtrFromOrigin(
143+
snippet.origin,
144+
snippet.dataType as StorableData,
145+
);
114146

115-
if (!ptrParams) {
147+
if (!ptrType) {
116148
throw new Error(
117149
`Cannot take a reference of a value with origin ${this.snippet.origin}`,
118150
);
119151
}
120152

121-
this.#ptrType = INTERNAL_createPtr(
122-
ptrParams.space,
123-
this.snippet.dataType as StorableData,
124-
ptrParams.access,
125-
);
153+
this.#ptrType = ptrType;
126154
}
127155

128156
get [$ownSnippet](): Snippet {

packages/typegpu/src/tgsl/shellless.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ import {
33
type ShelllessImpl,
44
} from '../core/function/shelllessImpl.ts';
55
import type { AnyData } from '../data/dataTypes.ts';
6+
import { RefOnGPU } from '../data/ref.ts';
67
import type { Snippet } from '../data/snippet.ts';
78
import { isPtr } from '../data/wgslTypes.ts';
9+
import { WgslTypeError } from '../errors.ts';
810
import { getResolutionCtx } from '../execMode.ts';
911
import { getMetaData, getName } from '../shared/meta.ts';
1012
import { concretize } from './generationHelpers.ts';
@@ -47,7 +49,22 @@ export class ShelllessRepository {
4749
);
4850
}
4951

50-
const argTypes = (argSnippets ?? []).map((s) => {
52+
const argTypes = (argSnippets ?? []).map((s, index) => {
53+
if (s.value instanceof RefOnGPU) {
54+
if (!s.value.ptrType) {
55+
throw new WgslTypeError(
56+
`d.ref() created with primitive types must be stored in a variable before use`,
57+
);
58+
}
59+
return s.value.ptrType;
60+
}
61+
62+
if (s.dataType.type === 'unknown') {
63+
throw new Error(
64+
`Passed illegal value ${s.value} as the #${index} argument to ${meta.name}(...)`,
65+
);
66+
}
67+
5168
let type = concretize(s.dataType as AnyData);
5269

5370
if (s.origin === 'constant-ref') {

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,11 @@ ${this.ctx.pre}else ${alternate}`;
782782

783783
if (eq.value instanceof RefOnGPU) {
784784
// We're assigning a newly created `d.ref()`
785+
if (eq.value.ptrType) {
786+
throw new WgslTypeError(
787+
`Cannot store d.ref() in a variable if it references another value. Copy the value passed into d.ref() instead.`,
788+
);
789+
}
785790
const refSnippet = eq.value.snippet;
786791
const varName = this.refVariable(
787792
rawId,

packages/typegpu/tests/constant.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ describe('tgpu.const', () => {
6464
[Error: Resolution of the following tree failed:
6565
- <root>
6666
- fn*:fn2
67-
- fn*:fn2: Cannot pass constant references as function arguments. Explicitly copy them by wrapping them in a schema: 'vec3f(...)']
67+
- fn*:fn2(): Cannot pass constant references as function arguments. Explicitly copy them by wrapping them in a schema: 'vec3f(...)']
6868
`);
6969
});
7070

@@ -84,7 +84,7 @@ describe('tgpu.const', () => {
8484
[Error: Resolution of the following tree failed:
8585
- <root>
8686
- fn*:fn
87-
- fn*:fn: 'boid.pos = vec3f()' is invalid, because boid.pos is a constant.]
87+
- fn*:fn(): 'boid.pos = vec3f()' is invalid, because boid.pos is a constant.]
8888
`);
8989

9090
// Since we freeze the object, we cannot mutate when running the function in JS either

0 commit comments

Comments
 (0)