Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion packages/typegpu/src/tgsl/generationHelpers.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import { $internal, $resolve } from '../../src/shared/symbols.ts';
import { type AnyData, UnknownData } from '../data/dataTypes.ts';
import { abstractFloat, abstractInt, bool, f32, i32 } from '../data/numeric.ts';
import { isRef } from '../data/ref.ts';
import { isSnippet, snip, type Snippet } from '../data/snippet.ts';
import {
isSnippet,
type ResolvedSnippet,
snip,
type Snippet,
} from '../data/snippet.ts';
import {
type AnyWgslData,
type F32,
Expand All @@ -14,8 +20,10 @@ import {
type FunctionScopeLayer,
getOwnSnippet,
type ResolutionCtx,
type SelfResolvable,
} from '../types.ts';
import type { ShelllessRepository } from './shellless.ts';
import { stitch } from '../../src/core/resolve/stitch.ts';

export function numericLiteralToSnippet(value: number): Snippet {
if (value >= 2 ** 63 || value < -(2 ** 63)) {
Expand Down Expand Up @@ -127,3 +135,27 @@ export function coerceToSnippet(value: unknown): Snippet {

return snip(value, UnknownData, /* origin */ 'constant');
}

// defers the resolution of array expressions
export class ArrayExpression implements SelfResolvable {
readonly [$internal] = true;

constructor(
public readonly elementType: AnyWgslData,
public readonly type: AnyWgslData,
public readonly elements: Snippet[],
) {
}

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
const arrayType = `array<${
ctx.resolve(this.elementType).value
}, ${this.elements.length}>`;

return snip(
stitch`${arrayType}(${this.elements})`,
this.type,
/* origin */ 'runtime',
);
}
}
17 changes: 10 additions & 7 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
tryConvertSnippet,
} from './conversion.ts';
import {
ArrayExpression,
concretize,
type GenerationCtx,
numericLiteralToSnippet,
Expand Down Expand Up @@ -756,16 +757,18 @@ ${this.ctx.pre}}`;
elemType = concretize(values[0]?.dataType as wgsl.AnyWgslData);
}

const arrayType = `array<${
this.ctx.resolve(elemType).value
}, ${values.length}>`;
const arrayType = arrayOf[$internal].jsImpl(
elemType as wgsl.AnyWgslData,
values.length,
);

return snip(
stitch`${arrayType}(${values})`,
arrayOf[$internal].jsImpl(
new ArrayExpression(
elemType as wgsl.AnyWgslData,
values.length,
) as wgsl.AnyWgslData,
arrayType,
values,
),
arrayType,
/* origin */ 'runtime',
);
}
Expand Down
80 changes: 80 additions & 0 deletions packages/typegpu/tests/tgsl/wgslGenerator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import * as std from '../../src/std/index.ts';
import wgslGenerator from '../../src/tgsl/wgslGenerator.ts';
import { CodegenState } from '../../src/types.ts';
import { it } from '../utils/extendedIt.ts';
import { ArrayExpression } from '../../src/tgsl/generationHelpers.ts';

const { NodeTypeCatalog: NODE } = tinyest;

Expand Down Expand Up @@ -575,6 +576,12 @@ describe('wgslGenerator', () => {
expect(d.isWgslArray(res.dataType)).toBe(true);
expect((res.dataType as unknown as WgslArray).elementCount).toBe(3);
expect((res.dataType as unknown as WgslArray).elementType).toBe(d.u32);

// intermediate representation
expect(res.value instanceof ArrayExpression).toBe(true);
expect((res.value as unknown as ArrayExpression).type).toBe(res.dataType);
expect((res.value as unknown as ArrayExpression).elementType)
.toBe((res.dataType as unknown as WgslArray).elementType);
});
});

Expand All @@ -594,6 +601,48 @@ describe('wgslGenerator', () => {
return arr[1i].x;
}"
`);

const astInfo = getMetaData(
testFn[$internal].implementation as (...args: unknown[]) => unknown,
);

if (!astInfo) {
throw new Error('Expected prebuilt AST to be present');
}

expect(JSON.stringify(astInfo.ast?.body)).toMatchInlineSnapshot(
`"[0,[[13,"arr",[100,[[6,[7,"d","vec2u"],[[5,"1"],[5,"2"]]],[6,[7,"d","vec2u"],[[5,"3"],[5,"4"]]],[6,[7,"std","min"],[[6,[7,"d","vec2u"],[[5,"5"],[5,"8"]]],[6,[7,"d","vec2u"],[[5,"7"],[5,"6"]]]]]]]],[10,[7,[8,"arr",[5,"1"]],"x"]]]]"`,
);

provideCtx(ctx, () => {
ctx[$internal].itemStateStack.pushFunctionScope(
'normal',
[],
{},
d.u32,
(astInfo.externals as () => Record<string, unknown>)() ?? {},
);

// Check for: const arr = [1, 2, 3]
// ^ this should be an array<u32, 3>
wgslGenerator.initGenerator(ctx);
const res = wgslGenerator.expression(
// deno-fmt-ignore: it's better that way
(
astInfo.ast?.body[1][0] as tinyest.Const
)[2] as unknown as tinyest.Expression,
);

expect(d.isWgslArray(res.dataType)).toBe(true);
expect((res.dataType as unknown as WgslArray).elementCount).toBe(3);
expect((res.dataType as unknown as WgslArray).elementType).toBe(d.vec2u);

// intermediate representation
expect(res.value instanceof ArrayExpression).toBe(true);
expect((res.value as unknown as ArrayExpression).type).toBe(res.dataType);
expect((res.value as unknown as ArrayExpression).elementType)
.toBe((res.dataType as unknown as WgslArray).elementType);
});
});

it('does not autocast lhs of an assignment', () => {
Expand Down Expand Up @@ -670,6 +719,12 @@ describe('wgslGenerator', () => {
expect(d.isWgslArray(res.dataType)).toBe(true);
expect((res.dataType as unknown as WgslArray).elementCount).toBe(2);
expect((res.dataType as unknown as WgslArray).elementType).toBe(TestStruct);

// intermediate representation
expect(res.value instanceof ArrayExpression).toBe(true);
expect((res.value as unknown as ArrayExpression).type).toBe(res.dataType);
expect((res.value as unknown as ArrayExpression).elementType)
.toBe((res.dataType as unknown as WgslArray).elementType);
});

it('generates correct code for array expressions with derived elements', () => {
Expand All @@ -696,6 +751,31 @@ describe('wgslGenerator', () => {
expect(JSON.stringify(astInfo.ast?.body)).toMatchInlineSnapshot(
`"[0,[[13,"arr",[100,[[7,"derivedV2f","$"],[6,[7,"std","mul"],[[7,"derivedV2f","$"],[6,[7,"d","vec2f"],[[5,"2"],[5,"2"]]]]]]]],[10,[7,[8,"arr",[5,"1"]],"y"]]]]"`,
);

const res = provideCtx(ctx, () => {
ctx[$internal].itemStateStack.pushFunctionScope(
'normal',
[],
{},
d.f32,
(astInfo.externals as () => Record<string, unknown>)() ?? {},
);

wgslGenerator.initGenerator(ctx);
return wgslGenerator.expression(
(astInfo.ast?.body[1][0] as tinyest.Const)[2] as tinyest.Expression,
);
});

expect(d.isWgslArray(res.dataType)).toBe(true);
expect((res.dataType as unknown as WgslArray).elementCount).toBe(2);
expect((res.dataType as unknown as WgslArray).elementType).toBe(d.vec2f);

// intermediate representation
expect(res.value instanceof ArrayExpression).toBe(true);
expect((res.value as unknown as ArrayExpression).type).toBe(res.dataType);
expect((res.value as unknown as ArrayExpression).elementType)
.toBe((res.dataType as unknown as WgslArray).elementType);
});

it('allows for member access on values returned from function calls', () => {
Expand Down