Skip to content

Commit 2ed5955

Browse files
committed
Update Gravity code
1 parent 66b89cf commit 2ed5955

File tree

6 files changed

+97
-137
lines changed

6 files changed

+97
-137
lines changed

apps/typegpu-docs/src/examples/simulation/gravity/compute.ts

Lines changed: 39 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@ const { none, bounce, merge } = collisionBehaviors;
99

1010
// tiebreaker function for merges and bounces
1111
const isSmaller = tgpu.fn([d.u32, d.u32], d.bool)((currentId, otherId) => {
12-
if (
13-
computeLayout.$.inState[currentId].mass <
14-
computeLayout.$.inState[otherId].mass
15-
) {
12+
const current = computeLayout.$.inState[currentId];
13+
const other = computeLayout.$.inState[otherId];
14+
15+
if (current.mass < other.mass) {
1616
return true;
1717
}
18-
if (
19-
computeLayout.$.inState[currentId].mass ===
20-
computeLayout.$.inState[otherId].mass
21-
) {
18+
19+
if (current.mass === other.mass) {
2220
return currentId < otherId;
2321
}
22+
2423
return false;
2524
});
2625

@@ -29,35 +28,18 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
2928
workgroupSize: [1],
3029
})((input) => {
3130
const currentId = input.gid.x;
32-
// TODO: replace it with struct copy when Chromium is fixed
33-
const current = CelestialBody({
34-
position: computeLayout.$.inState[currentId].position,
35-
velocity: computeLayout.$.inState[currentId].velocity,
36-
mass: computeLayout.$.inState[currentId].mass,
37-
collisionBehavior: computeLayout.$.inState[currentId].collisionBehavior,
38-
textureIndex: computeLayout.$.inState[currentId].textureIndex,
39-
radiusMultiplier: computeLayout.$.inState[currentId].radiusMultiplier,
40-
ambientLightFactor: computeLayout.$.inState[currentId].ambientLightFactor,
41-
destroyed: computeLayout.$.inState[currentId].destroyed,
42-
});
31+
const current = CelestialBody(computeLayout.$.inState[currentId]);
4332

4433
if (current.destroyed === 0) {
45-
for (let i = 0; i < computeLayout.$.celestialBodiesCount; i++) {
46-
const otherId = d.u32(i);
47-
// TODO: replace it with struct copy when Chromium is fixed
48-
const other = CelestialBody({
49-
position: computeLayout.$.inState[otherId].position,
50-
velocity: computeLayout.$.inState[otherId].velocity,
51-
mass: computeLayout.$.inState[otherId].mass,
52-
collisionBehavior: computeLayout.$.inState[otherId].collisionBehavior,
53-
textureIndex: computeLayout.$.inState[otherId].textureIndex,
54-
radiusMultiplier: computeLayout.$.inState[otherId].radiusMultiplier,
55-
ambientLightFactor: computeLayout.$.inState[otherId].ambientLightFactor,
56-
destroyed: computeLayout.$.inState[otherId].destroyed,
57-
});
34+
for (
35+
let otherId = d.u32(0);
36+
otherId < d.u32(computeLayout.$.celestialBodiesCount);
37+
otherId++
38+
) {
39+
const other = computeLayout.$.inState[otherId];
5840
// no collision occurs...
5941
if (
60-
d.u32(i) === input.gid.x || // ...with itself
42+
otherId === currentId || // ...with itself
6143
other.destroyed === 1 || // ...when other is destroyed
6244
current.collisionBehavior === none || // ...when current behavior is none
6345
other.collisionBehavior === none || // ...when other behavior is none
@@ -75,30 +57,20 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
7557
// bounce occurs
7658
// push the smaller object outside
7759
if (isSmaller(currentId, otherId)) {
78-
current.position = std.add(
79-
other.position,
80-
std.mul(
81-
radiusOf(current) + radiusOf(other),
82-
std.normalize(std.sub(current.position, other.position)),
83-
),
84-
);
60+
const dir = std.normalize(current.position.sub(other.position));
61+
current.position = other.position
62+
.add(dir.mul(radiusOf(current) + radiusOf(other)));
8563
}
64+
8665
// bounce with tiny damping
87-
current.velocity = std.mul(
88-
0.99,
89-
std.sub(
90-
current.velocity,
91-
std.mul(
92-
(((2 * other.mass) / (current.mass + other.mass)) *
93-
std.dot(
94-
std.sub(current.velocity, other.velocity),
95-
std.sub(current.position, other.position),
96-
)) /
97-
std.pow(std.distance(current.position, other.position), 2),
98-
std.sub(current.position, other.position),
99-
),
100-
),
101-
);
66+
const posDiff = current.position.sub(other.position);
67+
const velDiff = current.velocity.sub(other.velocity);
68+
const posDiffFactor =
69+
(((2 * other.mass) / (current.mass + other.mass)) *
70+
std.dot(velDiff, posDiff)) / std.dot(posDiff, posDiff);
71+
72+
current.velocity = current.velocity
73+
.sub(posDiff.mul(posDiffFactor)).mul(0.99);
10274
} else {
10375
// merge occurs
10476
const isCurrentAbsorbed = current.collisionBehavior === bounce ||
@@ -112,16 +84,16 @@ export const computeCollisionsShader = tgpu['~unstable'].computeFn({
11284
const m1 = current.mass;
11385
const m2 = other.mass;
11486
current.velocity = std.add(
115-
std.mul(m1 / (m1 + m2), current.velocity),
116-
std.mul(m2 / (m1 + m2), other.velocity),
87+
current.velocity.mul(m1 / (m1 + m2)),
88+
other.velocity.mul(m2 / (m1 + m2)),
11789
);
11890
current.mass = m1 + m2;
11991
}
12092
}
12193
}
12294
}
12395

124-
computeLayout.$.outState[input.gid.x] = CelestialBody(current);
96+
computeLayout.$.outState[currentId] = CelestialBody(current);
12597
});
12698

12799
export const computeGravityShader = tgpu['~unstable'].computeFn({
@@ -130,23 +102,17 @@ export const computeGravityShader = tgpu['~unstable'].computeFn({
130102
})((input) => {
131103
const dt = timeAccess.$.passed * timeAccess.$.multiplier;
132104
const currentId = input.gid.x;
133-
// TODO: replace it with struct copy when Chromium is fixed
134-
const current = CelestialBody({
135-
position: computeLayout.$.inState[currentId].position,
136-
velocity: computeLayout.$.inState[currentId].velocity,
137-
mass: computeLayout.$.inState[currentId].mass,
138-
collisionBehavior: computeLayout.$.inState[currentId].collisionBehavior,
139-
textureIndex: computeLayout.$.inState[currentId].textureIndex,
140-
radiusMultiplier: computeLayout.$.inState[currentId].radiusMultiplier,
141-
ambientLightFactor: computeLayout.$.inState[currentId].ambientLightFactor,
142-
destroyed: computeLayout.$.inState[currentId].destroyed,
143-
});
105+
const current = CelestialBody(computeLayout.$.inState[currentId]);
144106

145107
if (current.destroyed === 0) {
146-
for (let i = 0; i < computeLayout.$.celestialBodiesCount; i++) {
147-
const other = computeLayout.$.inState[i];
108+
for (
109+
let otherId = d.u32(0);
110+
otherId < d.u32(computeLayout.$.celestialBodiesCount);
111+
otherId++
112+
) {
113+
const other = computeLayout.$.inState[otherId];
148114

149-
if (d.u32(i) === input.gid.x || other.destroyed === 1) {
115+
if (otherId === currentId || other.destroyed === 1) {
150116
continue;
151117
}
152118

@@ -165,5 +131,5 @@ export const computeGravityShader = tgpu['~unstable'].computeFn({
165131
current.position = current.position.add(current.velocity.mul(dt));
166132
}
167133

168-
computeLayout.$.outState[input.gid.x] = CelestialBody(current);
134+
computeLayout.$.outState[currentId] = CelestialBody(current);
169135
});

apps/typegpu-docs/src/examples/simulation/gravity/helpers.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import { load } from '@loaders.gl/core';
22
import { OBJLoader } from '@loaders.gl/obj';
3-
import { tgpu, type TgpuRoot } from 'typegpu';
3+
import type { TgpuRoot } from 'typegpu';
44
import * as d from 'typegpu/data';
5-
import * as std from 'typegpu/std';
65
import { sphereTextureNames } from './enums.ts';
7-
import { CelestialBody, renderVertexLayout, SkyBoxVertex } from './schemas.ts';
6+
import {
7+
type CelestialBody,
8+
renderVertexLayout,
9+
SkyBoxVertex,
10+
} from './schemas.ts';
811

912
function vert(
1013
position: [number, number, number],
@@ -163,6 +166,7 @@ export async function loadSphereTextures(root: TgpuRoot) {
163166
return texture;
164167
}
165168

166-
export const radiusOf = tgpu.fn([CelestialBody], d.f32)((body) =>
167-
std.pow((body.mass * 0.75) / Math.PI, 0.333) * body.radiusMultiplier
168-
);
169+
export const radiusOf = (body: CelestialBody): number => {
170+
'use gpu';
171+
return (((body.mass * 0.75) / Math.PI) ** 0.333) * body.radiusMultiplier;
172+
};

apps/typegpu-docs/src/examples/simulation/gravity/render.ts

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import * as std from 'typegpu/std';
44
import { radiusOf } from './helpers.ts';
55
import {
66
cameraAccess,
7-
CelestialBody,
87
filteringSamplerSlot,
98
lightSourceAccess,
109
renderBindGroupLayout as renderLayout,
@@ -54,32 +53,17 @@ export const mainVertex = tgpu['~unstable'].vertexFn({
5453
},
5554
out: VertexOutput,
5655
})((input) => {
57-
// TODO: replace it with struct copy when Chromium is fixed
58-
const currentBody = CelestialBody({
59-
position: renderLayout.$.celestialBodies[input.instanceIndex].position,
60-
velocity: renderLayout.$.celestialBodies[input.instanceIndex].velocity,
61-
mass: renderLayout.$.celestialBodies[input.instanceIndex].mass,
62-
collisionBehavior:
63-
renderLayout.$.celestialBodies[input.instanceIndex].collisionBehavior,
64-
textureIndex:
65-
renderLayout.$.celestialBodies[input.instanceIndex].textureIndex,
66-
radiusMultiplier:
67-
renderLayout.$.celestialBodies[input.instanceIndex].radiusMultiplier,
68-
ambientLightFactor:
69-
renderLayout.$.celestialBodies[input.instanceIndex].ambientLightFactor,
70-
destroyed: renderLayout.$.celestialBodies[input.instanceIndex].destroyed,
71-
});
56+
const currentBody = renderLayout.$.celestialBodies[input.instanceIndex];
7257

73-
const worldPosition = std.add(
74-
std.mul(radiusOf(currentBody), input.position.xyz),
75-
currentBody.position,
58+
const worldPosition = currentBody.position.add(
59+
input.position.xyz.mul(radiusOf(currentBody)),
7660
);
7761

7862
const camera = cameraAccess.$;
79-
const positionOnCanvas = std.mul(
80-
camera.projection,
81-
std.mul(camera.view, d.vec4f(worldPosition, 1)),
82-
);
63+
const positionOnCanvas = camera.projection
64+
.mul(camera.view)
65+
.mul(d.vec4f(worldPosition, 1));
66+
8367
return {
8468
position: positionOnCanvas,
8569
uv: input.uv,
@@ -107,22 +91,16 @@ export const mainFragment = tgpu['~unstable'].fragmentFn({
10791
input.sphereTextureIndex,
10892
).xyz;
10993

110-
const ambient = std.mul(
111-
input.ambientLightFactor,
112-
std.mul(textureColor, lightColor),
113-
);
94+
const ambient = textureColor.mul(lightColor).mul(input.ambientLightFactor);
11495

11596
const normal = input.normals;
11697
const lightDirection = std.normalize(
117-
std.sub(lightSourceAccess.$, input.worldPosition),
98+
lightSourceAccess.$.sub(input.worldPosition),
11899
);
119100
const cosTheta = std.dot(normal, lightDirection);
120-
const diffuse = std.mul(
121-
std.max(0, cosTheta),
122-
std.mul(textureColor, lightColor),
123-
);
101+
const diffuse = textureColor.mul(lightColor).mul(std.max(0, cosTheta));
124102

125-
const litColor = std.add(ambient, diffuse);
103+
const litColor = ambient.add(diffuse);
126104

127105
return d.vec4f(litColor.xyz, 1);
128106
});

apps/typegpu-docs/src/examples/simulation/gravity/schemas.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import tgpu, { type TgpuSampler } from 'typegpu';
22
import * as d from 'typegpu/data';
33
import { Camera } from './setup-orbit-camera.ts';
44

5+
export type CelestialBody = d.Infer<typeof CelestialBody>;
56
export const CelestialBody = d.struct({
67
destroyed: d.u32, // boolean
78
position: d.vec3f,

packages/typegpu/src/data/ref.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { stitch } from '../core/resolve/stitch.ts';
2-
import { invariant } from '../errors.ts';
2+
import { invariant, WgslTypeError } from '../errors.ts';
33
import { inCodegenMode } from '../execMode.ts';
44
import { setName } from '../shared/meta.ts';
55
import { $internal, $isRef, $ownSnippet, $resolve } from '../shared/symbols.ts';
@@ -8,6 +8,7 @@ import { UnknownData } from './dataTypes.ts';
88
import type { DualFn } from './dualFn.ts';
99
import { INTERNAL_createPtr } from './ptr.ts';
1010
import {
11+
isEphemeralSnippet,
1112
type OriginToPtrParams,
1213
originToPtrParams,
1314
type ResolvedSnippet,
@@ -34,6 +35,11 @@ export interface ref<T> {
3435
// TODO: Restrict calls to this function only from within TypeGPU functions
3536
export const ref: DualFn<<T>(value: T) => ref<T>> = (() => {
3637
const gpuImpl = (value: Snippet) => {
38+
if (!isEphemeralSnippet(value)) {
39+
throw new WgslTypeError(
40+
`Can't create refs from references. Copy the value first by wrapping it in its schema.`,
41+
);
42+
}
3743
return snip(new RefOnGPU(value), UnknownData, /* origin */ 'runtime');
3844
};
3945

0 commit comments

Comments
 (0)