Skip to content

Commit 74d1291

Browse files
committed
Fix for referencing implicit pointers
1 parent 6003bc3 commit 74d1291

File tree

10 files changed

+115
-80
lines changed

10 files changed

+115
-80
lines changed

apps/typegpu-docs/src/examples/simulation/gravity/compute.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
4343
current.collisionBehavior === none || // ...when current behavior is none
4444
other.collisionBehavior === none || // ...when other behavior is none
4545
std.distance(current.position, other.position) >=
46-
radiusOf(d.ref(current)) + radiusOf(d.ref(other)) // ...when other is too far away
46+
radiusOf(current) + radiusOf(other) // ...when other is too far away
4747
) {
4848
// no collision occurs...
4949
continue;
@@ -59,7 +59,7 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
5959
if (isSmaller(currentId, otherId)) {
6060
const dir = std.normalize(current.position.sub(other.position));
6161
current.position = other.position.add(
62-
dir.mul(radiusOf(d.ref(current)) + radiusOf(d.ref(other))),
62+
dir.mul(radiusOf(current) + radiusOf(other)),
6363
);
6464
}
6565

@@ -118,7 +118,7 @@ export const computeGravityShader = tgpu['~unstable'].computeFn({
118118
}
119119

120120
const dist = std.max(
121-
radiusOf(d.ref(current)) + radiusOf(d.ref(other)),
121+
radiusOf(current) + radiusOf(other),
122122
std.distance(current.position, other.position),
123123
);
124124
const gravityForce = (current.mass * other.mass) / dist / dist;

apps/typegpu-docs/src/examples/simulation/gravity/helpers.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ export async function loadSphereTextures(root: TgpuRoot) {
166166
return texture;
167167
}
168168

169-
export const radiusOf = (body: d.ref<CelestialBody>): number => {
169+
export const radiusOf = (body: CelestialBody): number => {
170170
'use gpu';
171-
return (((body.$.mass * 0.75) / Math.PI) ** 0.333) * body.$.radiusMultiplier;
171+
return (((body.mass * 0.75) / Math.PI) ** 0.333) * body.radiusMultiplier;
172172
};

apps/typegpu-docs/src/examples/simulation/gravity/render.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export const mainVertex = tgpu['~unstable'].vertexFn({
5656
const currentBody = renderLayout.$.celestialBodies[input.instanceIndex];
5757

5858
const worldPosition = currentBody.position.add(
59-
input.position.xyz.mul(radiusOf(d.ref(currentBody))),
59+
input.position.xyz.mul(radiusOf(currentBody)),
6060
);
6161

6262
const camera = cameraAccess.$;

apps/typegpu-docs/src/examples/tests/tgsl-parsing-test/pointers.ts

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ const modifyStructFn = tgpu.fn([d.ptrFn(SimpleStruct)])((ptr) => {
1616
ptr.$.vec.x += 1;
1717
});
1818

19-
const privateNum = tgpu.privateVar(d.u32);
20-
const modifyNumPrivate = tgpu.fn([d.ptrPrivate(d.u32)])((ptr) => {
21-
ptr.$ += 1;
22-
});
23-
2419
const privateVec = tgpu.privateVar(d.vec2f);
2520
const modifyVecPrivate = tgpu.fn([d.ptrPrivate(d.vec2f)])((ptr) => {
2621
ptr.$.x += 1;
@@ -49,9 +44,6 @@ export const pointersTest = tgpu.fn([], d.bool)(() => {
4944
s = s && std.allEq(myStruct.$.vec, d.vec2f(1, 0));
5045

5146
// private pointers
52-
modifyNumPrivate(d.ref(privateNum.$));
53-
s = s && (privateNum.$ === 1);
54-
5547
modifyVecPrivate(d.ref(privateVec.$));
5648
s = s && std.allEq(privateVec.$, d.vec2f(1, 0));
5749

packages/typegpu/src/data/ptr.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,21 @@ export function createPtrFromOrigin(
8080

8181
return undefined;
8282
}
83+
84+
export function implicitFrom(ptr: Ptr): Ptr {
85+
return INTERNAL_createPtr(
86+
ptr.addressSpace,
87+
ptr.inner,
88+
ptr.access,
89+
/* implicit */ true,
90+
);
91+
}
92+
93+
export function explicitFrom(ptr: Ptr): Ptr {
94+
return INTERNAL_createPtr(
95+
ptr.addressSpace,
96+
ptr.inner,
97+
ptr.access,
98+
/* implicit */ false,
99+
);
100+
}

packages/typegpu/src/data/ref.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { $internal, $ownSnippet, $resolve } from '../shared/symbols.ts';
66
import type { ResolutionCtx, SelfResolvable } from '../types.ts';
77
import { UnknownData } from './dataTypes.ts';
88
import type { DualFn } from './dualFn.ts';
9-
import { createPtrFromOrigin } from './ptr.ts';
9+
import { createPtrFromOrigin, explicitFrom } from './ptr.ts';
1010
import { type ResolvedSnippet, snip, type Snippet } from './snippet.ts';
1111
import {
1212
isNaturallyEphemeral,
@@ -52,6 +52,12 @@ export const ref: DualFn<<T>(value: T) => ref<T>> = (() => {
5252
);
5353
}
5454

55+
if (value.dataType.type === 'ptr') {
56+
// This can happen if we take a reference of an *implicit* pointer, one
57+
// made by assigning a reference to a `const`.
58+
return snip(value.value, explicitFrom(value.dataType), value.origin);
59+
}
60+
5561
/**
5662
* Pointer type only exists if the ref was created from a reference (buttery-butter).
5763
*

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import {
4141
} from './generationHelpers.ts';
4242
import type { ShaderGenerator } from './shaderGenerator.ts';
4343
import type { DualFn } from '../data/dualFn.ts';
44-
import { INTERNAL_createPtr, ptrFn } from '../data/ptr.ts';
44+
import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts';
4545
import { RefOperator } from '../data/ref.ts';
4646
import { constant } from '../core/constant/tgpuConstant.ts';
4747

@@ -869,18 +869,21 @@ ${this.ctx.pre}else ${alternate}`;
869869
} else {
870870
varType = 'let';
871871
if (!wgsl.isPtr(dataType)) {
872-
dataType = ptrFn(concretize(dataType) as wgsl.StorableData);
872+
const ptrType = createPtrFromOrigin(
873+
eq.origin,
874+
concretize(dataType) as wgsl.StorableData,
875+
);
876+
invariant(
877+
ptrType !== undefined,
878+
`Creating pointer type from origin ${eq.origin}`,
879+
);
880+
dataType = ptrType;
873881
}
874882

875883
if (!(eq.value instanceof RefOperator)) {
876884
// If what we're assigning is something preceded by `&`, then it's a value
877885
// created using `d.ref()`. Otherwise, it's an implicit pointer
878-
dataType = INTERNAL_createPtr(
879-
dataType.addressSpace,
880-
dataType.inner,
881-
dataType.access,
882-
/* implicit */ true,
883-
);
886+
dataType = implicitFrom(dataType);
884887
}
885888
}
886889
} else {
@@ -932,13 +935,11 @@ ${this.ctx.pre}else ${alternate}`;
932935
const [_, init, condition, update, body] = statement;
933936

934937
const [initStatement, conditionExpr, updateStatement] = this.ctx
935-
.withResetIndentLevel(
936-
() => [
937-
init ? this.statement(init) : undefined,
938-
condition ? this.typedExpression(condition, bool) : undefined,
939-
update ? this.statement(update) : undefined,
940-
],
941-
);
938+
.withResetIndentLevel(() => [
939+
init ? this.statement(init) : undefined,
940+
condition ? this.typedExpression(condition, bool) : undefined,
941+
update ? this.statement(update) : undefined,
942+
]);
942943

943944
const initStr = initStatement ? initStatement.slice(0, -1) : '';
944945
const updateStr = updateStatement ? updateStatement.slice(0, -1) : '';

packages/typegpu/tests/examples/individual/gravity.test.ts

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,11 @@ describe('gravity example', () => {
4242
4343
@group(0) @binding(0) var<uniform> celestialBodiesCount_3: i32;
4444
45-
fn radiusOf_4(body: ptr<function, CelestialBody_2>) -> f32 {
46-
return (pow((((*body).mass * 0.75f) / 3.141592653589793f), 0.333f) * (*body).radiusMultiplier);
45+
fn radiusOf_4(body: CelestialBody_2) -> f32 {
46+
return (pow(((body.mass * 0.75f) / 3.141592653589793f), 0.333f) * body.radiusMultiplier);
4747
}
4848
49-
fn radiusOf_5(body: ptr<storage, ptr<function, CelestialBody_2>, read>) -> f32 {
50-
return (pow((((*(*body)).mass * 0.75f) / 3.141592653589793f), 0.333f) * (*(*body)).radiusMultiplier);
51-
}
52-
53-
fn isSmaller_6(currentId: u32, otherId: u32) -> bool {
49+
fn isSmaller_5(currentId: u32, otherId: u32) -> bool {
5450
let current = (&inState_1[currentId]);
5551
let other = (&inState_1[otherId]);
5652
if (((*current).mass < (*other).mass)) {
@@ -62,33 +58,33 @@ describe('gravity example', () => {
6258
return false;
6359
}
6460
65-
@group(0) @binding(2) var<storage, read_write> outState_7: array<CelestialBody_2>;
61+
@group(0) @binding(2) var<storage, read_write> outState_6: array<CelestialBody_2>;
6662
67-
struct computeCollisionsShader_Input_8 {
63+
struct computeCollisionsShader_Input_7 {
6864
@builtin(global_invocation_id) gid: vec3u,
6965
}
7066
71-
@compute @workgroup_size(1) fn computeCollisionsShader_0(input: computeCollisionsShader_Input_8) {
67+
@compute @workgroup_size(1) fn computeCollisionsShader_0(input: computeCollisionsShader_Input_7) {
7268
let currentId = input.gid.x;
7369
var current = inState_1[currentId];
7470
if ((current.destroyed == 0u)) {
7571
for (var otherId = 0u; (otherId < u32(celestialBodiesCount_3)); otherId++) {
7672
let other = (&inState_1[otherId]);
77-
if ((((((otherId == currentId) || ((*other).destroyed == 1u)) || (current.collisionBehavior == 0u)) || ((*other).collisionBehavior == 0u)) || (distance(current.position, (*other).position) >= (radiusOf_4((&current)) + radiusOf_5((&other)))))) {
73+
if ((((((otherId == currentId) || ((*other).destroyed == 1u)) || (current.collisionBehavior == 0u)) || ((*other).collisionBehavior == 0u)) || (distance(current.position, (*other).position) >= (radiusOf_4(current) + radiusOf_4((*other)))))) {
7874
continue;
7975
}
8076
if (((current.collisionBehavior == 1u) && ((*other).collisionBehavior == 1u))) {
81-
if (isSmaller_6(currentId, otherId)) {
77+
if (isSmaller_5(currentId, otherId)) {
8278
var dir = normalize((current.position - (*other).position));
83-
current.position = ((*other).position + (dir * (radiusOf_4((&current)) + radiusOf_5((&other)))));
79+
current.position = ((*other).position + (dir * (radiusOf_4(current) + radiusOf_4((*other)))));
8480
}
8581
var posDiff = (current.position - (*other).position);
8682
var velDiff = (current.velocity - (*other).velocity);
8783
let posDiffFactor = ((((2f * (*other).mass) / (current.mass + (*other).mass)) * dot(velDiff, posDiff)) / dot(posDiff, posDiff));
8884
current.velocity = ((current.velocity - (posDiff * posDiffFactor)) * 0.99);
8985
}
9086
else {
91-
let isCurrentAbsorbed = ((current.collisionBehavior == 1u) || ((current.collisionBehavior == 2u) && isSmaller_6(currentId, otherId)));
87+
let isCurrentAbsorbed = ((current.collisionBehavior == 1u) || ((current.collisionBehavior == 2u) && isSmaller_5(currentId, otherId)));
9288
if (isCurrentAbsorbed) {
9389
current.destroyed = 1u;
9490
}
@@ -101,7 +97,7 @@ describe('gravity example', () => {
10197
}
10298
}
10399
}
104-
outState_7[currentId] = current;
100+
outState_6[currentId] = current;
105101
}
106102
107103
struct Time_2 {
@@ -126,21 +122,17 @@ describe('gravity example', () => {
126122
127123
@group(1) @binding(0) var<uniform> celestialBodiesCount_5: i32;
128124
129-
fn radiusOf_6(body: ptr<function, CelestialBody_4>) -> f32 {
130-
return (pow((((*body).mass * 0.75f) / 3.141592653589793f), 0.333f) * (*body).radiusMultiplier);
125+
fn radiusOf_6(body: CelestialBody_4) -> f32 {
126+
return (pow(((body.mass * 0.75f) / 3.141592653589793f), 0.333f) * body.radiusMultiplier);
131127
}
132128
133-
fn radiusOf_7(body: ptr<storage, ptr<function, CelestialBody_4>, read>) -> f32 {
134-
return (pow((((*(*body)).mass * 0.75f) / 3.141592653589793f), 0.333f) * (*(*body)).radiusMultiplier);
135-
}
136-
137-
@group(1) @binding(2) var<storage, read_write> outState_8: array<CelestialBody_4>;
129+
@group(1) @binding(2) var<storage, read_write> outState_7: array<CelestialBody_4>;
138130
139-
struct computeGravityShader_Input_9 {
131+
struct computeGravityShader_Input_8 {
140132
@builtin(global_invocation_id) gid: vec3u,
141133
}
142134
143-
@compute @workgroup_size(1) fn computeGravityShader_0(input: computeGravityShader_Input_9) {
135+
@compute @workgroup_size(1) fn computeGravityShader_0(input: computeGravityShader_Input_8) {
144136
let dt = (time_1.passed * time_1.multiplier);
145137
let currentId = input.gid.x;
146138
var current = inState_3[currentId];
@@ -150,14 +142,14 @@ describe('gravity example', () => {
150142
if (((otherId == currentId) || ((*other).destroyed == 1u))) {
151143
continue;
152144
}
153-
let dist = max((radiusOf_6((&current)) + radiusOf_7((&other))), distance(current.position, (*other).position));
145+
let dist = max((radiusOf_6(current) + radiusOf_6((*other))), distance(current.position, (*other).position));
154146
let gravityForce = (((current.mass * (*other).mass) / dist) / dist);
155147
var direction = normalize(((*other).position - current.position));
156148
current.velocity = (current.velocity + (direction * ((gravityForce / current.mass) * dt)));
157149
}
158150
current.position = (current.position + (current.velocity * dt));
159151
}
160-
outState_8[currentId] = current;
152+
outState_7[currentId] = current;
161153
}
162154
163155
struct Camera_2 {
@@ -209,8 +201,8 @@ describe('gravity example', () => {
209201
210202
@group(1) @binding(1) var<storage, read> celestialBodies_1: array<CelestialBody_2>;
211203
212-
fn radiusOf_3(body: ptr<storage, ptr<function, CelestialBody_2>, read>) -> f32 {
213-
return (pow((((*(*body)).mass * 0.75f) / 3.141592653589793f), 0.333f) * (*(*body)).radiusMultiplier);
204+
fn radiusOf_3(body: CelestialBody_2) -> f32 {
205+
return (pow(((body.mass * 0.75f) / 3.141592653589793f), 0.333f) * body.radiusMultiplier);
214206
}
215207
216208
struct Camera_5 {
@@ -241,7 +233,7 @@ describe('gravity example', () => {
241233
242234
@vertex fn mainVertex_0(input: mainVertex_Input_7) -> mainVertex_Output_6 {
243235
let currentBody = (&celestialBodies_1[input.instanceIndex]);
244-
var worldPosition = ((*currentBody).position + (input.position.xyz * radiusOf_3((&currentBody))));
236+
var worldPosition = ((*currentBody).position + (input.position.xyz * radiusOf_3((*currentBody))));
245237
let camera = (&camera_4);
246238
var positionOnCanvas = (((*camera).projection * (*camera).view) * vec4f(worldPosition, 1f));
247239
return mainVertex_Output_6(positionOnCanvas, input.uv, input.normal, worldPosition, (*currentBody).textureIndex, (*currentBody).destroyed, (*currentBody).ambientLightFactor);

packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -178,24 +178,18 @@ describe('tgsl parsing test example', () => {
178178
(*ptr).vec.x += 1f;
179179
}
180180
181-
var<private> privateNum_16: u32;
182-
183-
fn modifyNumPrivate_17(ptr: ptr<private, u32>) {
184-
(*ptr) += 1u;
185-
}
186-
187-
var<private> privateVec_18: vec2f;
188-
189-
fn modifyVecPrivate_19(ptr: ptr<private, vec2f>) {
181+
fn modifyVecPrivate_16(ptr: ptr<private, vec2f>) {
190182
(*ptr).x += 1f;
191183
}
192184
193-
var<private> privateStruct_20: SimpleStruct_14;
185+
var<private> privateVec_17: vec2f;
194186
195-
fn modifyStructPrivate_21(ptr: ptr<private, SimpleStruct_14>) {
187+
fn modifyStructPrivate_18(ptr: ptr<private, SimpleStruct_14>) {
196188
(*ptr).vec.x += 1f;
197189
}
198190
191+
var<private> privateStruct_19: SimpleStruct_14;
192+
199193
fn pointersTest_11() -> bool {
200194
var s = true;
201195
var num = 0u;
@@ -207,16 +201,14 @@ describe('tgsl parsing test example', () => {
207201
var myStruct = SimpleStruct_14();
208202
modifyStructFn_15((&myStruct));
209203
s = (s && all(myStruct.vec == vec2f(1, 0)));
210-
modifyNumPrivate_17(privateNum_16);
211-
s = (s && (privateNum_16 == 1u));
212-
modifyVecPrivate_19(privateVec_18);
213-
s = (s && all(privateVec_18 == vec2f(1, 0)));
214-
modifyStructPrivate_21(privateStruct_20);
215-
s = (s && all(privateStruct_20.vec == vec2f(1, 0)));
204+
modifyVecPrivate_16((&privateVec_17));
205+
s = (s && all(privateVec_17 == vec2f(1, 0)));
206+
modifyStructPrivate_18((&privateStruct_19));
207+
s = (s && all(privateStruct_19.vec == vec2f(1, 0)));
216208
return s;
217209
}
218210
219-
@group(0) @binding(0) var<storage, read_write> result_22: i32;
211+
@group(0) @binding(0) var<storage, read_write> result_20: i32;
220212
221213
@compute @workgroup_size(1) fn computeRunTests_0() {
222214
var s = true;
@@ -226,10 +218,10 @@ describe('tgsl parsing test example', () => {
226218
s = (s && arrayAndStructConstructorsTest_8());
227219
s = (s && pointersTest_11());
228220
if (s) {
229-
result_22 = 1i;
221+
result_20 = 1i;
230222
}
231223
else {
232-
result_22 = 0i;
224+
result_20 = 0i;
233225
}
234226
}"
235227
`);

0 commit comments

Comments
 (0)