Skip to content

Commit 166a088

Browse files
authored
feat: More predictable dual-impl behavior (#2085)
1 parent 53f4ff8 commit 166a088

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+685
-661
lines changed

packages/typegpu/src/core/function/comptime.ts

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1-
import type { DualFn } from '../../data/dualFn.ts';
2-
import type { MapValueToSnippet } from '../../data/snippet.ts';
31
import { WgslTypeError } from '../../errors.ts';
4-
import { getResolutionCtx } from '../../execMode.ts';
52
import { setName, type TgpuNamable } from '../../shared/meta.ts';
6-
import { $getNameForward, $internal } from '../../shared/symbols.ts';
3+
import {
4+
$getNameForward,
5+
$gpuCallable,
6+
$internal,
7+
} from '../../shared/symbols.ts';
78
import { coerceToSnippet } from '../../tgsl/generationHelpers.ts';
8-
import { isKnownAtComptime, NormalState } from '../../types.ts';
9+
import { type DualFn, isKnownAtComptime } from '../../types.ts';
910

10-
export type TgpuComptime<
11-
T extends (...args: never[]) => unknown = (...args: never[]) => unknown,
12-
> =
11+
type AnyFn = (...args: never[]) => unknown;
12+
13+
export type TgpuComptime<T extends AnyFn = AnyFn> =
1314
& DualFn<T>
1415
& TgpuNamable
15-
& { [$getNameForward]: unknown; [$internal]: { isComptime: true } };
16+
& {
17+
[$getNameForward]: unknown;
18+
[$internal]: { isComptime: true };
19+
};
1620

1721
export function isComptimeFn(value: unknown): value is TgpuComptime {
1822
return !!(value as TgpuComptime)?.[$internal]?.isComptime;
@@ -45,48 +49,32 @@ export function isComptimeFn(value: unknown): value is TgpuComptime {
4549
export function comptime<T extends (...args: never[]) => unknown>(
4650
func: T,
4751
): TgpuComptime<T> {
48-
const gpuImpl = (...args: MapValueToSnippet<Parameters<T>>) => {
49-
const argSnippets = args as MapValueToSnippet<Parameters<T>>;
50-
51-
if (!argSnippets.every((s) => isKnownAtComptime(s))) {
52-
throw new WgslTypeError(
53-
`Called comptime function with runtime-known values: ${
54-
argSnippets.filter((s) => !isKnownAtComptime(s)).map((s) =>
55-
`'${s.value}'`
56-
).join(', ')
57-
}`,
58-
);
59-
}
60-
61-
return coerceToSnippet(func(...argSnippets.map((s) => s.value) as never[]));
62-
};
63-
6452
const impl = ((...args: Parameters<T>) => {
65-
const ctx = getResolutionCtx();
66-
if (ctx?.mode.type === 'codegen') {
67-
ctx.pushMode(new NormalState());
68-
try {
69-
return gpuImpl(...args as MapValueToSnippet<Parameters<T>>);
70-
} finally {
71-
ctx.popMode('normal');
72-
}
73-
}
7453
return func(...args);
7554
}) as TgpuComptime<T>;
7655

7756
impl.toString = () => 'comptime';
7857
impl[$getNameForward] = func;
58+
impl[$gpuCallable] = {
59+
call(_ctx, args) {
60+
if (!args.every((s) => isKnownAtComptime(s))) {
61+
throw new WgslTypeError(
62+
`Called comptime function with runtime-known values: ${
63+
args.filter((s) => !isKnownAtComptime(s)).map((s) => `'${s.value}'`)
64+
.join(', ')
65+
}`,
66+
);
67+
}
68+
69+
return coerceToSnippet(func(...args.map((s) => s.value) as never[]));
70+
},
71+
};
7972
impl.$name = (label: string) => {
8073
setName(func, label);
8174
return impl;
8275
};
8376
Object.defineProperty(impl, $internal, {
84-
value: {
85-
isComptime: true,
86-
jsImpl: func,
87-
gpuImpl,
88-
argConversionHint: 'keep',
89-
},
77+
value: { isComptime: true },
9078
});
9179

9280
return impl as TgpuComptime<T>;
Lines changed: 81 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
1-
import type { DualFn } from '../../data/dualFn.ts';
1+
import type { AnyData } from '../../data/dataTypes.ts';
22
import { type MapValueToSnippet, snip } from '../../data/snippet.ts';
3-
import { getResolutionCtx, inCodegenMode } from '../../execMode.ts';
4-
import { isKnownAtComptime, NormalState } from '../../types.ts';
53
import { setName } from '../../shared/meta.ts';
6-
import { $internal } from '../../shared/symbols.ts';
4+
import { $gpuCallable } from '../../shared/symbols.ts';
75
import { tryConvertSnippet } from '../../tgsl/conversion.ts';
8-
import type { AnyData } from '../../data/dataTypes.ts';
6+
import {
7+
type DualFn,
8+
isKnownAtComptime,
9+
NormalState,
10+
type ResolutionCtx,
11+
} from '../../types.ts';
912

1013
type MapValueToDataType<T> = { [K in keyof T]: AnyData };
14+
type AnyFn = (...args: never[]) => unknown;
1115

12-
interface DualImplOptions<T extends (...args: never[]) => unknown> {
16+
interface DualImplOptions<T extends AnyFn> {
1317
readonly name: string | undefined;
1418
readonly normalImpl: T | string;
15-
readonly codegenImpl: (...args: MapValueToSnippet<Parameters<T>>) => string;
19+
readonly codegenImpl: (
20+
ctx: ResolutionCtx,
21+
args: MapValueToSnippet<Parameters<T>>,
22+
) => string;
1623
readonly signature:
1724
| { argTypes: AnyData[]; returnType: AnyData }
1825
| ((
@@ -34,90 +41,83 @@ export class MissingCpuImplError extends Error {
3441
}
3542
}
3643

37-
export function dualImpl<T extends (...args: never[]) => unknown>(
44+
export function dualImpl<T extends AnyFn>(
3845
options: DualImplOptions<T>,
3946
): DualFn<T> {
40-
const gpuImpl = (...args: MapValueToSnippet<Parameters<T>>) => {
41-
// biome-ignore lint/style/noNonNullAssertion: it's there
42-
const ctx = getResolutionCtx()!;
43-
const { argTypes, returnType } = typeof options.signature === 'function'
44-
? options.signature(
45-
...args.map((s) => {
46-
// Dereference implicit pointers
47-
if (s.dataType.type === 'ptr' && s.dataType.implicit) {
48-
return s.dataType.inner;
49-
}
50-
return s.dataType;
51-
}) as MapValueToDataType<Parameters<T>>,
52-
)
53-
: options.signature;
54-
55-
const argSnippets = args as MapValueToSnippet<Parameters<T>>;
56-
const converted = argSnippets.map((s, idx) => {
57-
const argType = argTypes[idx];
58-
if (!argType) {
59-
throw new Error('Function called with invalid arguments');
60-
}
61-
return tryConvertSnippet(s, argType, !options.ignoreImplicitCastWarning);
62-
}) as MapValueToSnippet<Parameters<T>>;
63-
64-
if (
65-
!options.noComptime &&
66-
converted.every((s) => isKnownAtComptime(s)) &&
67-
typeof options.normalImpl === 'function'
68-
) {
69-
ctx.pushMode(new NormalState());
70-
try {
71-
return snip(
72-
options.normalImpl(...converted.map((s) => s.value) as never[]),
73-
returnType,
74-
// Functions give up ownership of their return value
75-
/* origin */ 'constant',
76-
);
77-
} catch (e) {
78-
// cpuImpl may in some cases be present but implemented only partially.
79-
// In that case, if the MissingCpuImplError is thrown, we fallback to codegenImpl.
80-
// If it is any other error, we just rethrow.
81-
if (!(e instanceof MissingCpuImplError)) {
82-
throw e;
83-
}
84-
} finally {
85-
ctx.popMode('normal');
86-
}
87-
}
88-
89-
return snip(
90-
options.codegenImpl(...converted),
91-
returnType,
92-
// Functions give up ownership of their return value
93-
/* origin */ 'runtime',
94-
);
95-
};
96-
9747
const impl = ((...args: Parameters<T>) => {
98-
if (inCodegenMode()) {
99-
return gpuImpl(...args as MapValueToSnippet<Parameters<T>>);
100-
}
10148
if (typeof options.normalImpl === 'string') {
10249
throw new MissingCpuImplError(options.normalImpl);
10350
}
10451
return options.normalImpl(...args);
105-
}) as T;
52+
}) as DualFn<T>;
10653

10754
setName(impl, options.name);
10855
impl.toString = () => options.name ?? '<unknown>';
109-
Object.defineProperty(impl, $internal, {
110-
value: {
111-
jsImpl: options.normalImpl,
112-
gpuImpl,
113-
get strictSignature() {
114-
return typeof options.signature !== 'function'
115-
? options.signature
116-
: undefined;
117-
},
118-
argConversionHint: 'keep',
56+
impl[$gpuCallable] = {
57+
get strictSignature() {
58+
return typeof options.signature !== 'function'
59+
? options.signature
60+
: undefined;
61+
},
62+
call(ctx, args) {
63+
const { argTypes, returnType } = typeof options.signature === 'function'
64+
? options.signature(
65+
...args.map((s) => {
66+
// Dereference implicit pointers
67+
if (s.dataType.type === 'ptr' && s.dataType.implicit) {
68+
return s.dataType.inner;
69+
}
70+
return s.dataType;
71+
}) as MapValueToDataType<Parameters<T>>,
72+
)
73+
: options.signature;
74+
75+
const converted = args.map((s, idx) => {
76+
const argType = argTypes[idx];
77+
if (!argType) {
78+
throw new Error('Function called with invalid arguments');
79+
}
80+
return tryConvertSnippet(
81+
ctx,
82+
s,
83+
argType,
84+
!options.ignoreImplicitCastWarning,
85+
);
86+
}) as MapValueToSnippet<Parameters<T>>;
87+
88+
if (
89+
!options.noComptime &&
90+
converted.every((s) => isKnownAtComptime(s)) &&
91+
typeof options.normalImpl === 'function'
92+
) {
93+
ctx.pushMode(new NormalState());
94+
try {
95+
return snip(
96+
options.normalImpl(...converted.map((s) => s.value) as never[]),
97+
returnType,
98+
// Functions give up ownership of their return value
99+
/* origin */ 'constant',
100+
);
101+
} catch (e) {
102+
// cpuImpl may in some cases be present but implemented only partially.
103+
// In that case, if the MissingCpuImplError is thrown, we fallback to codegenImpl.
104+
// If it is any other error, we just rethrow.
105+
if (!(e instanceof MissingCpuImplError)) {
106+
throw e;
107+
}
108+
} finally {
109+
ctx.popMode('normal');
110+
}
111+
}
112+
113+
return snip(
114+
options.codegenImpl(ctx, converted),
115+
returnType,
116+
// Functions give up ownership of their return value
117+
/* origin */ 'runtime',
118+
);
119119
},
120-
});
120+
};
121121

122-
return impl as DualFn<T>;
122+
return impl;
123123
}

0 commit comments

Comments
 (0)