diff --git a/packages/typegpu/src/core/function/tgpuFn.ts b/packages/typegpu/src/core/function/tgpuFn.ts index 636efeaf46..ccf254d14f 100644 --- a/packages/typegpu/src/core/function/tgpuFn.ts +++ b/packages/typegpu/src/core/function/tgpuFn.ts @@ -41,6 +41,7 @@ import type { InheritArgNames, } from './fnTypes.ts'; import { stripTemplate } from './templateUtils.ts'; +import type { TgpuBufferShorthand } from '../buffer/bufferShorthand.ts'; // ---------- // Public API @@ -108,6 +109,29 @@ export type TgpuFn any> = }; }; +/** + * A function wrapper that allows providing slot and accessor overrides for shellless functions + */ +export interface TgpuGenericFn { + readonly [$internal]: { + inner: T; + }; + readonly [$providing]?: Providing | undefined; + readonly resourceType: 'generic-function'; + + with(slot: TgpuSlot, value: Eventual): TgpuGenericFn; + with( + accessor: TgpuAccessor, + value: + | TgpuFn<() => S> + | TgpuBufferUsage + | TgpuBufferShorthand + | Infer, + ): TgpuGenericFn; + + (...args: Parameters): ReturnType; +} + export function fn< Args extends AnyData[] | [], >(argTypes: Args, returnType?: undefined): TgpuFnShell; @@ -117,10 +141,20 @@ export function fn< Return extends AnyData, >(argTypes: Args, returnType: Return): TgpuFnShell; +export function fn(inner: T): TgpuGenericFn; + export function fn< Args extends AnyData[] | [], Return extends AnyData = Void, ->(argTypes: Args, returnType?: Return | undefined): TgpuFnShell { +>( + argTypesOrCallback: Args | AnyFn, + returnType?: Return | undefined, +): TgpuFnShell | TgpuGenericFn { + if (typeof argTypesOrCallback === 'function') { + return createGenericFn(argTypesOrCallback, []); + } + + const argTypes = argTypesOrCallback; const shell: TgpuFnShellHeader = { [$internal]: true, argTypes, @@ -147,6 +181,13 @@ export function isTgpuFn( (value as TgpuFn<(...args: Args) => Return>)?.resourceType === 'function'; } +export function isGenericFn( + value: unknown | TgpuGenericFn, +): value is TgpuGenericFn { + return isMarkedInternal(value) && + (value as TgpuGenericFn)?.resourceType === 'generic-function'; +} + // -------------- // Implementation // -------------- @@ -328,3 +369,44 @@ function createBoundFunction( return fn; } + +function createGenericFn( + inner: T, + pairs: SlotValuePair[], +): TgpuGenericFn { + type This = TgpuGenericFn; + + const fnBase = { + [$internal]: { inner }, + resourceType: 'generic-function' as const, + [$providing]: pairs.length > 0 ? { inner, pairs } : undefined, + + with( + slot: TgpuSlot | TgpuAccessor, + value: unknown, + ): TgpuGenericFn { + return createGenericFn(inner, [ + ...pairs, + [isAccessor(slot) ? slot.slot : slot, value], + ]); + }, + }; + + const call = ((...args: Parameters): ReturnType => { + return inner(...args) as ReturnType; + }) as T; + + const genericFn = Object.assign(call, fnBase) as unknown as This; + + Object.defineProperty(genericFn, 'toString', { + value() { + const fnLabel = getName(inner) ?? ''; + if (pairs.length > 0) { + return `fn*:${fnLabel}[${pairs.map(stringifyPair).join(', ')}]`; + } + return `fn*:${fnLabel}`; + }, + }); + + return genericFn; +} diff --git a/packages/typegpu/src/index.ts b/packages/typegpu/src/index.ts index 92a2b15d7f..98ebe9f3a1 100644 --- a/packages/typegpu/src/index.ts +++ b/packages/typegpu/src/index.ts @@ -174,7 +174,11 @@ export type { TgpuLayoutTexture, TgpuLayoutUniform, } from './tgpuBindGroupLayout.ts'; -export type { TgpuFn, TgpuFnShell } from './core/function/tgpuFn.ts'; +export type { + TgpuFn, + TgpuFnShell, + TgpuGenericFn, +} from './core/function/tgpuFn.ts'; export type { TgpuComptime } from './core/function/comptime.ts'; export type { TgpuVertexFn, diff --git a/packages/typegpu/src/resolutionCtx.ts b/packages/typegpu/src/resolutionCtx.ts index 4b13aa8afb..5887bdc73d 100644 --- a/packages/typegpu/src/resolutionCtx.ts +++ b/packages/typegpu/src/resolutionCtx.ts @@ -607,6 +607,10 @@ export class ResolutionCtxImpl implements ResolutionCtx { } withSlots(pairs: SlotValuePair[], callback: () => T): T { + if (pairs.length === 0) { + return callback(); + } + this._itemStateStack.pushSlotBindings(pairs); try { diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 2c95b96d9e..fd2cd86a7f 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -23,7 +23,7 @@ import { invariant, ResolutionError, WgslTypeError } from '../errors.ts'; import { getName } from '../shared/meta.ts'; import { isMarkedInternal } from '../shared/symbols.ts'; import { safeStringify } from '../shared/stringify.ts'; -import { $internal } from '../shared/symbols.ts'; +import { $internal, $providing } from '../shared/symbols.ts'; import { pow } from '../std/numeric.ts'; import { add, div, mul, neg, sub } from '../std/operators.ts'; import { type FnArgsConversionHint, isKnownAtComptime } from '../types.ts'; @@ -44,6 +44,8 @@ import type { DualFn } from '../data/dualFn.ts'; import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts'; import { RefOperator } from '../data/ref.ts'; import { constant } from '../core/constant/tgpuConstant.ts'; +import { isGenericFn } from '../core/function/tgpuFn.ts'; +import type { SlotValuePair } from '../core/slot/slotTypes.ts'; const { NodeTypeCatalog: NODE } = tinyest; @@ -549,26 +551,44 @@ ${this.ctx.pre}}`; return callee.value.operator(callee.value.lhs, rhs); } - if (!isMarkedInternal(callee.value)) { - const args = argNodes.map((arg) => this.expression(arg)); - const shellless = this.ctx.shelllessRepo.get( - callee.value as (...args: never[]) => unknown, - args, + if (!isMarkedInternal(callee.value) || isGenericFn(callee.value)) { + const slotPairs: SlotValuePair[] = isGenericFn(callee.value) + ? (callee.value[$providing]?.pairs ?? []) + : []; + const callback = isGenericFn(callee.value) + ? callee.value[$internal].inner + : (callee.value as (...args: never[]) => unknown); + + const shelllessCall = this.ctx.withSlots( + slotPairs, + (): Snippet | undefined => { + const args = argNodes.map((arg) => this.expression(arg)); + const shellless = this.ctx.shelllessRepo.get( + callback, + args, + ); + if (!shellless) { + return undefined; + } + + const converted = args.map((s, idx) => { + const argType = shellless.argTypes[idx] as AnyData; + return tryConvertSnippet(s, argType, /* verbose */ false); + }); + + return this.ctx.withResetIndentLevel(() => { + const snippet = this.ctx.resolve(shellless); + return snip( + stitch`${snippet.value}(${converted})`, + snippet.dataType, + /* origin */ 'runtime', + ); + }); + }, ); - if (shellless) { - const converted = args.map((s, idx) => { - const argType = shellless.argTypes[idx] as AnyData; - return tryConvertSnippet(s, argType, /* verbose */ false); - }); - return this.ctx.withResetIndentLevel(() => { - const snippet = this.ctx.resolve(shellless); - return snip( - stitch`${snippet.value}(${converted})`, - snippet.dataType, - /* origin */ 'runtime', - ); - }); + if (shelllessCall) { + return shelllessCall; } throw new Error( diff --git a/packages/typegpu/tests/tgpuGenericFn.test.ts b/packages/typegpu/tests/tgpuGenericFn.test.ts new file mode 100644 index 0000000000..05b1fa821d --- /dev/null +++ b/packages/typegpu/tests/tgpuGenericFn.test.ts @@ -0,0 +1,148 @@ +import { describe, expect } from 'vitest'; +import * as d from '../src/data/index.ts'; +import tgpu from '../src/index.ts'; +import { it } from './utils/extendedIt.ts'; + +describe('TgpuGenericFn - shellless callback wrapper', () => { + it('generates only one definition when both original and wrapped function are used', () => { + const countAccess = tgpu['~unstable'].accessor(d.f32, 2); + + const getDouble = () => { + 'use gpu'; + return countAccess.$ * 2; + }; + + const getDouble4 = tgpu.fn(getDouble); + + const main = () => { + 'use gpu'; + const original = getDouble(); + const wrapped = getDouble4(); + return original + wrapped; + }; + + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn getDouble() -> f32 { + return 4f; + } + + fn main() -> f32 { + let original = getDouble(); + let wrapped = getDouble(); + return (original + wrapped); + }" + `); + }); + + it('works when only the wrapped function is used', () => { + const countAccess = tgpu['~unstable'].accessor(d.f32, 0); + + const getDouble = () => { + 'use gpu'; + return countAccess.$ * 2; + }; + + const getDouble4 = tgpu.fn(getDouble); + + const main = () => { + 'use gpu'; + return getDouble4(); + }; + + const wgsl = tgpu.resolve([main]); + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn getDouble() -> f32 { + return 0f; + } + + fn main() -> f32 { + return getDouble(); + }" + `); + }); + + it('does not duplicate the same function', () => { + const countAccess = tgpu['~unstable'].accessor(d.f32, 0); + + const getDouble = () => { + 'use gpu'; + return countAccess.$ * 2; + }; + + const getDouble4 = tgpu.fn(getDouble); + + const main = () => { + 'use gpu'; + return getDouble4() + getDouble4(); + }; + + const wgsl = tgpu.resolve([main]); + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn getDouble() -> f32 { + return 0f; + } + + fn main() -> f32 { + return (getDouble() + getDouble()); + }" + `); + }); + + it('supports .with for slot bindings on generic functions', () => { + const multiplier = tgpu.slot(2).$name('multiplier'); + + const scale = () => { + 'use gpu'; + return d.f32(multiplier.$) * d.f32(2); + }; + + const scaleGeneric = tgpu.fn(scale); + const scaleBy3 = scaleGeneric.with(multiplier, 3); + const scaleBy4 = scaleGeneric.with(multiplier, 4); + + const main = () => { + 'use gpu'; + return scaleBy3() + scaleBy4(); + }; + + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn scale() -> f32 { + return 6f; + } + + fn scale_1() -> f32 { + return 8f; + } + + fn main() -> f32 { + return (scale() + scale_1()); + }" + `); + }); + + it('supports .with for accessor bindings on generic functions', () => { + const valueAccess = tgpu['~unstable'].accessor(d.f32); + + const getValue = () => { + 'use gpu'; + return valueAccess.$; + }; + + const getValueGeneric = tgpu.fn(getValue).with(valueAccess, 2); + + const main = () => { + 'use gpu'; + return getValueGeneric(); + }; + + expect(tgpu.resolve([main])).toMatchInlineSnapshot(` + "fn getValue() -> f32 { + return 2f; + } + + fn main() -> f32 { + return getValue(); + }" + `); + }); +});