-
-
Notifications
You must be signed in to change notification settings - Fork 47
feat: tgpu.fn(callback) for providing slots and accessors #2029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 19 commits
b0e36da
218f8a1
51c8fdf
3d71e24
e717fae
77446c2
5a92b03
225e920
eb735ca
3df2489
7618c0c
c906a4c
c58e141
67b21c6
92625c3
34f8c42
a741df7
d0a9e21
f5610d3
836d7f6
87eacef
e635728
3e6be56
d5f77c5
c93ae97
73fe652
869e57e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,6 +41,7 @@ import type { | |
| InheritArgNames, | ||
| } from './fnTypes.ts'; | ||
| import { stripTemplate } from './templateUtils.ts'; | ||
| import type { TgpuBufferShorthand } from '../buffer/bufferShorthand.ts'; | ||
|
|
||
| // ---------- | ||
| // Public API | ||
|
|
@@ -108,6 +109,28 @@ export type TgpuFn<ImplSchema extends AnyFn = (...args: any[]) => any> = | |
| }; | ||
| }; | ||
|
|
||
| /** | ||
| * A function wrapper that allows providing slot and accessor overrides for shellless functions | ||
| */ | ||
| export interface TgpuGenericFn<Callback extends AnyFn> { | ||
| readonly [$internal]: true; | ||
| readonly [$providing]?: Providing | undefined; | ||
| readonly resourceType: 'generic-function'; | ||
| readonly callback: Callback; | ||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| with<T>(slot: TgpuSlot<T>, value: Eventual<T>): TgpuGenericFn<Callback>; | ||
| with<T extends AnyData>( | ||
| accessor: TgpuAccessor<T>, | ||
| value: | ||
| | TgpuFn<() => T> | ||
| | TgpuBufferUsage<T> | ||
| | TgpuBufferShorthand<T> | ||
| | Infer<T>, | ||
| ): TgpuGenericFn<Callback>; | ||
|
|
||
| (...args: Parameters<Callback>): ReturnType<Callback>; | ||
| } | ||
|
|
||
| export function fn< | ||
| Args extends AnyData[] | [], | ||
| >(argTypes: Args, returnType?: undefined): TgpuFnShell<Args, Void>; | ||
|
|
@@ -117,10 +140,22 @@ export function fn< | |
| Return extends AnyData, | ||
| >(argTypes: Args, returnType: Return): TgpuFnShell<Args, Return>; | ||
|
|
||
| export function fn<Callback extends AnyFn>( | ||
| callback: Callback, | ||
| ): TgpuGenericFn<Callback>; | ||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| export function fn< | ||
| Args extends AnyData[] | [], | ||
| Return extends AnyData = Void, | ||
| >(argTypes: Args, returnType?: Return | undefined): TgpuFnShell<Args, Return> { | ||
| >( | ||
| argTypesOrCallback: Args | AnyFn, | ||
| returnType?: Return | undefined, | ||
| ): TgpuFnShell<Args, Return> | TgpuGenericFn<AnyFn> { | ||
| if (typeof argTypesOrCallback === 'function') { | ||
| return createGenericFn(argTypesOrCallback, []); | ||
| } | ||
|
|
||
| const argTypes = argTypesOrCallback; | ||
| const shell: TgpuFnShellHeader<Args, Return> = { | ||
| [$internal]: true, | ||
| argTypes, | ||
|
|
@@ -147,6 +182,13 @@ export function isTgpuFn<Args extends AnyData[] | [], Return extends AnyData>( | |
| (value as TgpuFn<(...args: Args) => Return>)?.resourceType === 'function'; | ||
| } | ||
|
|
||
| export function isGenericFn<Callback extends AnyFn>( | ||
| value: unknown | TgpuGenericFn<Callback>, | ||
| ): value is TgpuGenericFn<Callback> { | ||
| return isMarkedInternal(value) && | ||
| (value as TgpuGenericFn<Callback>)?.resourceType === 'generic-function'; | ||
| } | ||
|
|
||
| // -------------- | ||
| // Implementation | ||
| // -------------- | ||
|
|
@@ -328,3 +370,50 @@ function createBoundFunction<ImplSchema extends AnyFn>( | |
|
|
||
| return fn; | ||
| } | ||
|
|
||
| function createGenericFn<Callback extends AnyFn>( | ||
| callback: Callback, | ||
| pairs: SlotValuePair[], | ||
| ): TgpuGenericFn<Callback> { | ||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| type This = TgpuGenericFn<Callback>; | ||
|
|
||
| const fnBase = { | ||
| [$internal]: true as const, | ||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| resourceType: 'generic-function' as const, | ||
| callback, | ||
| [$providing]: pairs.length > 0 | ||
| ? { | ||
| inner: callback, | ||
| pairs, | ||
| } | ||
| : undefined, | ||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| with( | ||
| slot: TgpuSlot<unknown> | TgpuAccessor, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mutable accessor is missing. Also, is there a particular reason, why there isn't |
||
| value: unknown, | ||
| ): TgpuGenericFn<Callback> { | ||
| return createGenericFn(callback, [ | ||
| ...pairs, | ||
| [isAccessor(slot) ? slot.slot : slot, value], | ||
| ]); | ||
| }, | ||
| }; | ||
|
|
||
| const call = ((...args: Parameters<Callback>): ReturnType<Callback> => { | ||
| return callback(...args) as ReturnType<Callback>; | ||
| }) as Callback; | ||
|
|
||
| const genericFn = Object.assign(call, fnBase) as unknown as This; | ||
|
|
||
| Object.defineProperty(genericFn, 'toString', { | ||
| value() { | ||
| const fnLabel = getName(callback) ?? '<unnamed>'; | ||
| if (pairs.length > 0) { | ||
| return `fn*:${fnLabel}[${pairs.map(stringifyPair).join(', ')}]`; | ||
| } | ||
| return `fn*:${fnLabel}`; | ||
| }, | ||
| }); | ||
|
|
||
| return genericFn; | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -23,7 +23,7 @@ import { invariant, ResolutionError, WgslTypeError } from '../errors.ts'; | |||||
| import { getName } from '../shared/meta.ts'; | ||||||
| import { isMarkedInternal } from '../shared/symbols.ts'; | ||||||
| import { safeStringify } from '../shared/stringify.ts'; | ||||||
| import { $internal } from '../shared/symbols.ts'; | ||||||
| import { $internal, $providing } from '../shared/symbols.ts'; | ||||||
| import { pow } from '../std/numeric.ts'; | ||||||
| import { add, div, mul, neg, sub } from '../std/operators.ts'; | ||||||
| import { type FnArgsConversionHint, isKnownAtComptime } from '../types.ts'; | ||||||
|
|
@@ -44,6 +44,8 @@ import type { DualFn } from '../data/dualFn.ts'; | |||||
| import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts'; | ||||||
| import { RefOperator } from '../data/ref.ts'; | ||||||
| import { constant } from '../core/constant/tgpuConstant.ts'; | ||||||
| import { isGenericFn } from '../core/function/tgpuFn.ts'; | ||||||
| import type { SlotValuePair } from '../core/slot/slotTypes.ts'; | ||||||
|
|
||||||
| const { NodeTypeCatalog: NODE } = tinyest; | ||||||
|
|
||||||
|
|
@@ -549,26 +551,44 @@ ${this.ctx.pre}}`; | |||||
| return callee.value.operator(callee.value.lhs, rhs); | ||||||
| } | ||||||
|
|
||||||
| if (!isMarkedInternal(callee.value)) { | ||||||
| const args = argNodes.map((arg) => this.expression(arg)); | ||||||
| const shellless = this.ctx.shelllessRepo.get( | ||||||
| callee.value as (...args: never[]) => unknown, | ||||||
| args, | ||||||
| if (!isMarkedInternal(callee.value) || isGenericFn(callee.value)) { | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you save the result of |
||||||
| const slotPairs: SlotValuePair[] = isGenericFn(callee.value) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| ? (callee.value[$providing]?.pairs ?? []) | ||||||
| : []; | ||||||
| const callback = isGenericFn(callee.value) | ||||||
| ? callee.value.callback | ||||||
lursz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| : (callee.value as (...args: never[]) => unknown); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| const shelllessCall = this.ctx.withSlots( | ||||||
| slotPairs, | ||||||
| (): Snippet | undefined => { | ||||||
| const args = argNodes.map((arg) => this.expression(arg)); | ||||||
| const shellless = this.ctx.shelllessRepo.get( | ||||||
| callback, | ||||||
| args, | ||||||
| ); | ||||||
| if (!shellless) { | ||||||
| return undefined; | ||||||
| } | ||||||
|
|
||||||
| const converted = args.map((s, idx) => { | ||||||
| const argType = shellless.argTypes[idx] as AnyData; | ||||||
| return tryConvertSnippet(s, argType, /* verbose */ false); | ||||||
| }); | ||||||
|
|
||||||
| return this.ctx.withResetIndentLevel(() => { | ||||||
| const snippet = this.ctx.resolve(shellless); | ||||||
| return snip( | ||||||
| stitch`${snippet.value}(${converted})`, | ||||||
| snippet.dataType, | ||||||
| /* origin */ 'runtime', | ||||||
| ); | ||||||
| }); | ||||||
| }, | ||||||
| ); | ||||||
| if (shellless) { | ||||||
| const converted = args.map((s, idx) => { | ||||||
| const argType = shellless.argTypes[idx] as AnyData; | ||||||
| return tryConvertSnippet(s, argType, /* verbose */ false); | ||||||
| }); | ||||||
|
|
||||||
| return this.ctx.withResetIndentLevel(() => { | ||||||
| const snippet = this.ctx.resolve(shellless); | ||||||
| return snip( | ||||||
| stitch`${snippet.value}(${converted})`, | ||||||
| snippet.dataType, | ||||||
| /* origin */ 'runtime', | ||||||
| ); | ||||||
| }); | ||||||
| if (shelllessCall) { | ||||||
| return shelllessCall; | ||||||
| } | ||||||
|
|
||||||
| throw new Error( | ||||||
|
|
||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add more tests to see if the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this api could use a paragraph in either functions or slots docs |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,90 @@ | ||||
| import { describe, expect } from 'vitest'; | ||||
| import * as d from '../src/data/index.ts'; | ||||
| import tgpu from '../src/index.ts'; | ||||
| import { it } from './utils/extendedIt.ts'; | ||||
|
|
||||
| describe('TgpuGenericFn - shellless callback wrapper', () => { | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it('can be called in js', () => {
const getValue = () => {
'use gpu';
return 2;
};
const getValueGeneric = tgpu.fn(getValue);
expect(getValueGeneric()).toBe(2);
}); |
||||
| it('generates only one definition when both original and wrapped function are used', () => { | ||||
| const countAccess = tgpu['~unstable'].accessor(d.f32, 2); | ||||
|
|
||||
| const getDouble = () => { | ||||
| 'use gpu'; | ||||
| return countAccess.$ * 2; | ||||
| }; | ||||
|
|
||||
| const getDouble4 = tgpu.fn(getDouble); | ||||
|
|
||||
| const main = () => { | ||||
| 'use gpu'; | ||||
| const original = getDouble(); | ||||
| const wrapped = getDouble4(); | ||||
| return original + wrapped; | ||||
| }; | ||||
|
|
||||
| expect(tgpu.resolve([main])).toMatchInlineSnapshot(` | ||||
| "fn getDouble() -> f32 { | ||||
| return 4f; | ||||
| } | ||||
|
|
||||
| fn main() -> f32 { | ||||
| let original = getDouble(); | ||||
| let wrapped = getDouble(); | ||||
| return (original + wrapped); | ||||
| }" | ||||
| `); | ||||
| }); | ||||
|
|
||||
| it('works when only the wrapped function is used', () => { | ||||
| const countAccess = tgpu['~unstable'].accessor(d.f32, 0); | ||||
|
|
||||
| const getDouble = () => { | ||||
| 'use gpu'; | ||||
| return countAccess.$ * 2; | ||||
| }; | ||||
|
|
||||
| const getDouble4 = tgpu.fn(getDouble); | ||||
|
|
||||
| const main = () => { | ||||
| 'use gpu'; | ||||
| return getDouble4(); | ||||
| }; | ||||
|
|
||||
| const wgsl = tgpu.resolve([main]); | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| expect(tgpu.resolve([main])).toMatchInlineSnapshot(` | ||||
| "fn getDouble() -> f32 { | ||||
| return 0f; | ||||
| } | ||||
|
|
||||
| fn main() -> f32 { | ||||
| return getDouble(); | ||||
| }" | ||||
| `); | ||||
| }); | ||||
|
|
||||
| it('does not duplicate the same function', () => { | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think, we've seen that test before |
||||
| const countAccess = tgpu['~unstable'].accessor(d.f32, 0); | ||||
|
|
||||
| const getDouble = () => { | ||||
| 'use gpu'; | ||||
| return countAccess.$ * 2; | ||||
| }; | ||||
|
|
||||
| const getDouble4 = tgpu.fn(getDouble); | ||||
|
|
||||
| const main = () => { | ||||
| 'use gpu'; | ||||
| return getDouble4() + getDouble4(); | ||||
| }; | ||||
|
|
||||
| const wgsl = tgpu.resolve([main]); | ||||
| expect(tgpu.resolve([main])).toMatchInlineSnapshot(` | ||||
| "fn getDouble() -> f32 { | ||||
| return 0f; | ||||
| } | ||||
|
|
||||
| fn main() -> f32 { | ||||
| return (getDouble() + getDouble()); | ||||
| }" | ||||
| `); | ||||
| }); | ||||
| }); | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this intended behavior? it('complex', () => {
const valueAccess = tgpu['~unstable'].accessor(d.f32);
const slot = tgpu.slot(5);
const getValue = () => {
'use gpu';
return valueAccess.$ * slot.$;
};
const getValueGeneric = tgpu.fn(getValue).with(valueAccess, 2).with(
valueAccess,
4,
).with(slot, 7).with(slot, 5);
const getValueGenericAgain = tgpu.fn(getValue).with(valueAccess, 4);
const getValueShell = tgpu.fn([], d.f32)(getValue).with(
valueAccess,
4,
);
const getValueShell2 = tgpu.fn([], d.f32)(getValue).with(
valueAccess,
4,
);
const main = () => {
'use gpu';
return getValueShell() + getValueGeneric() + getValueGenericAgain() +
getValueShell2();
};
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getValue() -> f32 {
return 20f;
}
fn getValue_1() -> f32 {
return 20f;
}
fn getValue_2() -> f32 {
return 20f;
}
fn main() -> f32 {
return (((getValue() + getValue_1()) + getValue_1()) + getValue_2());
}"
`);
}); |
||||
Uh oh!
There was an error while loading. Please reload this page.