Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion packages/typegpu/src/core/function/tgpuFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import type {
InheritArgNames,
} from './fnTypes.ts';
import { stripTemplate } from './templateUtils.ts';
import type { TgpuBufferShorthand } from '../buffer/bufferShorthand.ts';

// ----------
// Public API
Expand Down Expand Up @@ -108,6 +109,29 @@ export type TgpuFn<ImplSchema extends AnyFn = (...args: any[]) => any> =
};
};

/**
* A function wrapper that allows providing slot and accessor overrides for shellless functions
*/
export interface TgpuGenericFn<T extends AnyFn> {
readonly [$internal]: {
inner: T;
};
readonly [$providing]?: Providing | undefined;
readonly resourceType: 'generic-function';

with<S>(slot: TgpuSlot<S>, value: Eventual<S>): TgpuGenericFn<T>;
with<S extends AnyData>(
accessor: TgpuAccessor<S>,
value:
| TgpuFn<() => S>
| TgpuBufferUsage<S>
| TgpuBufferShorthand<S>
| Infer<S>,
): TgpuGenericFn<T>;

(...args: Parameters<T>): ReturnType<T>;
}

export function fn<
Args extends AnyData[] | [],
>(argTypes: Args, returnType?: undefined): TgpuFnShell<Args, Void>;
Expand All @@ -117,10 +141,20 @@ export function fn<
Return extends AnyData,
>(argTypes: Args, returnType: Return): TgpuFnShell<Args, Return>;

export function fn<T extends AnyFn>(inner: T): TgpuGenericFn<T>;

export function fn<
Args extends AnyData[] | [],
Return extends AnyData = Void,
>(argTypes: Args, returnType?: Return | undefined): TgpuFnShell<Args, Return> {
>(
argTypesOrCallback: Args | AnyFn,
returnType?: Return | undefined,
): TgpuFnShell<Args, Return> | TgpuGenericFn<AnyFn> {
if (typeof argTypesOrCallback === 'function') {
return createGenericFn(argTypesOrCallback, []);
}

const argTypes = argTypesOrCallback;
const shell: TgpuFnShellHeader<Args, Return> = {
[$internal]: true,
argTypes,
Expand All @@ -147,6 +181,13 @@ export function isTgpuFn<Args extends AnyData[] | [], Return extends AnyData>(
(value as TgpuFn<(...args: Args) => Return>)?.resourceType === 'function';
}

export function isGenericFn<Callback extends AnyFn>(
value: unknown | TgpuGenericFn<Callback>,
): value is TgpuGenericFn<Callback> {
return isMarkedInternal(value) &&
(value as TgpuGenericFn<Callback>)?.resourceType === 'generic-function';
}

// --------------
// Implementation
// --------------
Expand Down Expand Up @@ -328,3 +369,44 @@ function createBoundFunction<ImplSchema extends AnyFn>(

return fn;
}

function createGenericFn<T extends AnyFn>(
inner: T,
pairs: SlotValuePair[],
): TgpuGenericFn<T> {
type This = TgpuGenericFn<T>;

const fnBase = {
[$internal]: { inner },
resourceType: 'generic-function' as const,
[$providing]: pairs.length > 0 ? { inner, pairs } : undefined,

with(
slot: TgpuSlot<unknown> | TgpuAccessor,
value: unknown,
): TgpuGenericFn<T> {
return createGenericFn(inner, [
...pairs,
[isAccessor(slot) ? slot.slot : slot, value],
]);
},
};

const call = ((...args: Parameters<T>): ReturnType<T> => {
return inner(...args) as ReturnType<T>;
}) as T;

const genericFn = Object.assign(call, fnBase) as unknown as This;

Object.defineProperty(genericFn, 'toString', {
value() {
const fnLabel = getName(inner) ?? '<unnamed>';
if (pairs.length > 0) {
return `fn*:${fnLabel}[${pairs.map(stringifyPair).join(', ')}]`;
}
return `fn*:${fnLabel}`;
},
});

return genericFn;
}
6 changes: 5 additions & 1 deletion packages/typegpu/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ export type {
TgpuLayoutTexture,
TgpuLayoutUniform,
} from './tgpuBindGroupLayout.ts';
export type { TgpuFn, TgpuFnShell } from './core/function/tgpuFn.ts';
export type {
TgpuFn,
TgpuFnShell,
TgpuGenericFn,
} from './core/function/tgpuFn.ts';
export type { TgpuComptime } from './core/function/comptime.ts';
export type {
TgpuVertexFn,
Expand Down
4 changes: 4 additions & 0 deletions packages/typegpu/src/resolutionCtx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ export class ResolutionCtxImpl implements ResolutionCtx {
}

withSlots<T>(pairs: SlotValuePair<unknown>[], callback: () => T): T {
if (pairs.length === 0) {
return callback();
}

this._itemStateStack.pushSlotBindings(pairs);

try {
Expand Down
58 changes: 39 additions & 19 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import { invariant, ResolutionError, WgslTypeError } from '../errors.ts';
import { getName } from '../shared/meta.ts';
import { isMarkedInternal } from '../shared/symbols.ts';
import { safeStringify } from '../shared/stringify.ts';
import { $internal } from '../shared/symbols.ts';
import { $internal, $providing } from '../shared/symbols.ts';
import { pow } from '../std/numeric.ts';
import { add, div, mul, neg, sub } from '../std/operators.ts';
import { type FnArgsConversionHint, isKnownAtComptime } from '../types.ts';
Expand All @@ -44,6 +44,8 @@ import type { DualFn } from '../data/dualFn.ts';
import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts';
import { RefOperator } from '../data/ref.ts';
import { constant } from '../core/constant/tgpuConstant.ts';
import { isGenericFn } from '../core/function/tgpuFn.ts';
import type { SlotValuePair } from '../core/slot/slotTypes.ts';

const { NodeTypeCatalog: NODE } = tinyest;

Expand Down Expand Up @@ -549,26 +551,44 @@ ${this.ctx.pre}}`;
return callee.value.operator(callee.value.lhs, rhs);
}

if (!isMarkedInternal(callee.value)) {
const args = argNodes.map((arg) => this.expression(arg));
const shellless = this.ctx.shelllessRepo.get(
callee.value as (...args: never[]) => unknown,
args,
if (!isMarkedInternal(callee.value) || isGenericFn(callee.value)) {
const slotPairs: SlotValuePair[] = isGenericFn(callee.value)
? (callee.value[$providing]?.pairs ?? [])
: [];
const callback = isGenericFn(callee.value)
? callee.value[$internal].inner
: (callee.value as (...args: never[]) => unknown);

const shelllessCall = this.ctx.withSlots(
slotPairs,
(): Snippet | undefined => {
const args = argNodes.map((arg) => this.expression(arg));
const shellless = this.ctx.shelllessRepo.get(
callback,
args,
);
if (!shellless) {
return undefined;
}

const converted = args.map((s, idx) => {
const argType = shellless.argTypes[idx] as AnyData;
return tryConvertSnippet(s, argType, /* verbose */ false);
});

return this.ctx.withResetIndentLevel(() => {
const snippet = this.ctx.resolve(shellless);
return snip(
stitch`${snippet.value}(${converted})`,
snippet.dataType,
/* origin */ 'runtime',
);
});
},
);
if (shellless) {
const converted = args.map((s, idx) => {
const argType = shellless.argTypes[idx] as AnyData;
return tryConvertSnippet(s, argType, /* verbose */ false);
});

return this.ctx.withResetIndentLevel(() => {
const snippet = this.ctx.resolve(shellless);
return snip(
stitch`${snippet.value}(${converted})`,
snippet.dataType,
/* origin */ 'runtime',
);
});
if (shelllessCall) {
return shelllessCall;
}

throw new Error(
Expand Down
148 changes: 148 additions & 0 deletions packages/typegpu/tests/tgpuGenericFn.test.ts
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add more tests to see if the .with API works on these generic functions?

Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import { describe, expect } from 'vitest';
import * as d from '../src/data/index.ts';
import tgpu from '../src/index.ts';
import { it } from './utils/extendedIt.ts';

describe('TgpuGenericFn - shellless callback wrapper', () => {
it('generates only one definition when both original and wrapped function are used', () => {
const countAccess = tgpu['~unstable'].accessor(d.f32, 2);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
const original = getDouble();
const wrapped = getDouble4();
return original + wrapped;
};

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 4f;
}

fn main() -> f32 {
let original = getDouble();
let wrapped = getDouble();
return (original + wrapped);
}"
`);
});

it('works when only the wrapped function is used', () => {
const countAccess = tgpu['~unstable'].accessor(d.f32, 0);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
return getDouble4();
};

const wgsl = tgpu.resolve([main]);
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 0f;
}

fn main() -> f32 {
return getDouble();
}"
`);
});

it('does not duplicate the same function', () => {
const countAccess = tgpu['~unstable'].accessor(d.f32, 0);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
return getDouble4() + getDouble4();
};

const wgsl = tgpu.resolve([main]);
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 0f;
}

fn main() -> f32 {
return (getDouble() + getDouble());
}"
`);
});

it('supports .with for slot bindings on generic functions', () => {
const multiplier = tgpu.slot(2).$name('multiplier');

const scale = () => {
'use gpu';
return d.f32(multiplier.$) * d.f32(2);
};

const scaleGeneric = tgpu.fn(scale);
const scaleBy3 = scaleGeneric.with(multiplier, 3);
const scaleBy4 = scaleGeneric.with(multiplier, 4);

const main = () => {
'use gpu';
return scaleBy3() + scaleBy4();
};

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn scale() -> f32 {
return 6f;
}

fn scale_1() -> f32 {
return 8f;
}

fn main() -> f32 {
return (scale() + scale_1());
}"
`);
});

it('supports .with for accessor bindings on generic functions', () => {
const valueAccess = tgpu['~unstable'].accessor(d.f32);

const getValue = () => {
'use gpu';
return valueAccess.$;
};

const getValueGeneric = tgpu.fn(getValue).with(valueAccess, 2);

const main = () => {
'use gpu';
return getValueGeneric();
};

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getValue() -> f32 {
return 2f;
}

fn main() -> f32 {
return getValue();
}"
`);
});
});