Skip to content

Commit 37999ac

Browse files
committed
feat: Inspect vector type in shader function
1 parent e0b5cc7 commit 37999ac

File tree

7 files changed

+140
-85
lines changed

7 files changed

+140
-85
lines changed

packages/tinyest-for-wgsl/src/parsers.ts

Lines changed: 18 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -24,73 +24,6 @@ function isDeclared(ctx: Context, name: string) {
2424
return ctx.stack.some((scope) => scope.declaredNames.includes(name));
2525
}
2626

27-
const BINARY_OP_MAP = {
28-
'==': '==',
29-
'!=': '!=',
30-
'===': '==',
31-
'!==': '!=',
32-
'<': '<',
33-
'<=': '<=',
34-
'>': '>',
35-
'>=': '>=',
36-
'<<': '<<',
37-
'>>': '>>',
38-
get '>>>'(): never {
39-
throw new Error('The `>>>` operator is unsupported in TGSL.');
40-
},
41-
'+': '+',
42-
'-': '-',
43-
'*': '*',
44-
'/': '/',
45-
'%': '%',
46-
'|': '|',
47-
'^': '^',
48-
'&': '&',
49-
get in(): never {
50-
throw new Error('The `in` operator is unsupported in TGSL.');
51-
},
52-
get instanceof(): never {
53-
throw new Error('The `instanceof` operator is unsupported in TGSL.');
54-
},
55-
'**': '**',
56-
get '|>'(): never {
57-
throw new Error('The `|>` operator is unsupported in TGSL.');
58-
},
59-
} as const;
60-
61-
const LOGICAL_OP_MAP = {
62-
'||': '||',
63-
'&&': '&&',
64-
get '??'(): never {
65-
throw new Error('The `??` operator is unsupported in TGSL.');
66-
},
67-
} as const;
68-
69-
const ASSIGNMENT_OP_MAP = {
70-
'=': '=',
71-
'+=': '+=',
72-
'-=': '-=',
73-
'*=': '*=',
74-
'/=': '/=',
75-
'%=': '%=',
76-
'<<=': '<<=',
77-
'>>=': '>>=',
78-
get '>>>='(): never {
79-
throw new Error('The `>>>=` operator is unsupported in TGSL.');
80-
},
81-
'|=': '|=',
82-
'^=': '^=',
83-
'&=': '&=',
84-
get '**='(): never {
85-
throw new Error('The `**=` operator is unsupported in TGSL.');
86-
},
87-
'||=': '||=',
88-
'&&=': '&&=',
89-
get '??='(): never {
90-
throw new Error('The `??=` operator is unsupported in TGSL.');
91-
},
92-
} as const;
93-
9427
const Transpilers: Partial<
9528
{
9629
[Type in JsNode['type']]: (
@@ -144,24 +77,36 @@ const Transpilers: Partial<
14477
},
14578

14679
BinaryExpression(ctx, node) {
147-
const wgslOp = BINARY_OP_MAP[node.operator];
14880
const left = transpile(ctx, node.left) as tinyest.Expression;
14981
const right = transpile(ctx, node.right) as tinyest.Expression;
150-
return [NODE.binaryExpr, left, wgslOp, right];
82+
return [
83+
NODE.binaryExpr,
84+
left,
85+
node.operator as tinyest.BinaryOperator,
86+
right,
87+
];
15188
},
15289

15390
LogicalExpression(ctx, node) {
154-
const wgslOp = LOGICAL_OP_MAP[node.operator];
15591
const left = transpile(ctx, node.left) as tinyest.Expression;
15692
const right = transpile(ctx, node.right) as tinyest.Expression;
157-
return [NODE.logicalExpr, left, wgslOp, right];
93+
return [
94+
NODE.logicalExpr,
95+
left,
96+
node.operator as tinyest.LogicalOperator,
97+
right,
98+
];
15899
},
159100

160101
AssignmentExpression(ctx, node) {
161-
const wgslOp = ASSIGNMENT_OP_MAP[node.operator as acorn.AssignmentOperator];
162102
const left = transpile(ctx, node.left) as tinyest.Expression;
163103
const right = transpile(ctx, node.right) as tinyest.Expression;
164-
return [NODE.assignmentExpr, left, wgslOp, right];
104+
return [
105+
NODE.assignmentExpr,
106+
left,
107+
node.operator as tinyest.AssignmentOperator,
108+
right,
109+
];
165110
},
166111

167112
UnaryExpression(ctx, node) {

packages/tinyest/src/nodes.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,15 @@ export type Statement =
118118
export type BinaryOperator =
119119
| '=='
120120
| '!='
121+
| '==='
122+
| '!=='
121123
| '<'
122124
| '<='
123125
| '>'
124126
| '>='
125127
| '<<'
126128
| '>>'
129+
| '>>>'
127130
| '+'
128131
| '-'
129132
| '*'
@@ -132,6 +135,8 @@ export type BinaryOperator =
132135
| '|'
133136
| '^'
134137
| '&'
138+
| 'in'
139+
| 'instanceof'
135140
| '**';
136141

137142
export type BinaryExpression = readonly [
@@ -164,7 +169,7 @@ export type AssignmentExpression = readonly [
164169
rhs: Expression,
165170
];
166171

167-
export type LogicalOperator = '&&' | '||';
172+
export type LogicalOperator = '&&' | '||' | '??';
168173

169174
export type LogicalExpression = readonly [
170175
type: NodeTypeCatalog['logicalExpr'],

packages/typegpu/src/tgsl/generationHelpers.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,12 @@ export function accessProp(
229229
return accessProp(derefed, propName);
230230
}
231231

232+
if (isVec(target.dataType)) {
233+
if (propName === 'kind') {
234+
return snip(target.dataType.type, UnknownData, 'constant');
235+
}
236+
}
237+
232238
const propLength = propName.length;
233239
if (
234240
isVec(target.dataType) &&

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import { safeStringify } from '../shared/stringify.ts';
2626
import { $internal } from '../shared/symbols.ts';
2727
import { pow } from '../std/numeric.ts';
2828
import { add, div, mul, neg, sub } from '../std/operators.ts';
29-
import type { FnArgsConversionHint } from '../types.ts';
29+
import { type FnArgsConversionHint, isKnownAtComptime } from '../types.ts';
3030
import {
3131
convertStructValues,
3232
convertToCommonType,
@@ -50,6 +50,8 @@ const { NodeTypeCatalog: NODE } = tinyest;
5050
const parenthesizedOps = [
5151
'==',
5252
'!=',
53+
'===',
54+
'!==',
5355
'<',
5456
'<=',
5557
'>',
@@ -68,7 +70,48 @@ const parenthesizedOps = [
6870
'||',
6971
];
7072

71-
const binaryLogicalOps = ['&&', '||', '==', '!=', '<', '<=', '>', '>='];
73+
const binaryLogicalOps = [
74+
'&&',
75+
'||',
76+
'==',
77+
'!=',
78+
'===',
79+
'!==',
80+
'<',
81+
'<=',
82+
'>',
83+
'>=',
84+
];
85+
86+
const OP_MAP = {
87+
//
88+
// binary
89+
//
90+
'===': '==',
91+
'!==': '!=',
92+
get '>>>'(): never {
93+
throw new Error('The `>>>` operator is unsupported in TypeGPU functions.');
94+
},
95+
get in(): never {
96+
throw new Error('The `in` operator is unsupported in TypeGPU functions.');
97+
},
98+
get instanceof(): never {
99+
throw new Error(
100+
'The `instanceof` operator is unsupported in TypeGPU functions.',
101+
);
102+
},
103+
get '|>'(): never {
104+
throw new Error('The `|>` operator is unsupported in TypeGPU functions.');
105+
},
106+
//
107+
// logical
108+
//
109+
'||': '||',
110+
'&&': '&&',
111+
get '??'(): never {
112+
throw new Error('The `??` operator is unsupported in TypeGPU functions.');
113+
},
114+
} as Record<string, string>;
72115

73116
type Operator =
74117
| tinyest.BinaryOperator
@@ -250,8 +293,24 @@ ${this.ctx.pre}}`;
250293
);
251294
}
252295

296+
if (op === '==') {
297+
throw new Error('Please use the === operator instead of ==');
298+
}
299+
300+
if (
301+
op === '===' && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)
302+
) {
303+
return snip(
304+
lhsExpr.value === rhsExpr.value,
305+
bool,
306+
/* ref */ 'constant',
307+
);
308+
}
309+
253310
if (lhsExpr.dataType.type === 'unknown') {
254-
throw new WgslTypeError(`Left-hand side of '${op}' is of unknown type`);
311+
throw new WgslTypeError(
312+
`Left-hand side of '${op}' is of unknown type`,
313+
);
255314
}
256315

257316
if (rhsExpr.dataType.type === 'unknown') {
@@ -315,8 +374,8 @@ ${this.ctx.pre}}`;
315374

316375
return snip(
317376
parenthesizedOps.includes(op)
318-
? `(${lhsStr} ${op} ${rhsStr})`
319-
: `${lhsStr} ${op} ${rhsStr}`,
377+
? `(${lhsStr} ${OP_MAP[op] ?? op} ${rhsStr})`
378+
: `${lhsStr} ${OP_MAP[op] ?? op} ${rhsStr}`,
320379
type,
321380
// Result of an operation, so not a reference to anything
322381
/* ref */ 'runtime',
@@ -670,7 +729,7 @@ ${this.ctx.pre}}`;
670729
}
671730

672731
if (expression[0] === NODE.stringLiteral) {
673-
return snip(expression[1], UnknownData, /* ref */ 'runtime'); // arbitrary ref
732+
return snip(expression[1], UnknownData, /* origin */ 'constant');
674733
}
675734

676735
if (expression[0] === NODE.preUpdate) {

packages/typegpu/src/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ export function getOwnSnippet(value: unknown): Snippet | undefined {
318318
}
319319

320320
export function isKnownAtComptime(snippet: Snippet): boolean {
321-
return typeof snippet.value !== 'string' &&
321+
return (typeof snippet.value !== 'string' ||
322+
snippet.dataType.type === 'unknown') &&
322323
getOwnSnippet(snippet.value) === undefined;
323324
}
324325

packages/typegpu/tests/examples/individual/tgsl-parsing-test.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ describe('tgsl parsing test example', () => {
3636
3737
fn logicalExpressionTests_1() -> bool {
3838
var s = true;
39-
s = (s && (true == true));
40-
s = (s && (true == true));
41-
s = (s && (true == true));
42-
s = (s && (false == false));
39+
s = (s && true);
40+
s = (s && true);
41+
s = (s && true);
42+
s = (s && true);
4343
s = (s && true);
4444
s = (s && !false);
4545
s = (s && true);

packages/typegpu/tests/vector.test.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { readData, writeData } from '../src/data/dataIO.ts';
44
import * as d from '../src/data/index.ts';
55
import { sizeOf } from '../src/data/sizeOf.ts';
66
import tgpu from '../src/index.ts';
7+
import * as std from '../src/std/index.ts';
78
import { asWgsl } from './utils/parseResolved.ts';
89

910
describe('constructors', () => {
@@ -972,3 +973,41 @@ describe('v4b', () => {
972973
});
973974
});
974975
});
976+
977+
describe('type predicates', () => {
978+
it('prunes branches', () => {
979+
const ceil = (input: d.v3f | d.v3i): d.v3i => {
980+
'use gpu';
981+
if (input.kind === 'vec3f') {
982+
return d.vec3i(std.ceil(input));
983+
} else {
984+
return input;
985+
}
986+
};
987+
988+
const main = () => {
989+
'use gpu';
990+
const foo = ceil(d.vec3f(1, 2, 3));
991+
const bar = ceil(d.vec3i(1, 2, 3));
992+
};
993+
994+
expect(asWgsl(main)).toMatchInlineSnapshot(`
995+
"fn ceil(input: vec3f) -> vec3i {
996+
{
997+
return vec3i(ceil(input));
998+
}
999+
}
1000+
1001+
fn ceil_1(input: vec3i) -> vec3i {
1002+
{
1003+
return input;
1004+
}
1005+
}
1006+
1007+
fn main() {
1008+
var foo = ceil(vec3f(1, 2, 3));
1009+
var bar = ceil_1(vec3i(1, 2, 3));
1010+
}"
1011+
`);
1012+
});
1013+
});

0 commit comments

Comments
 (0)