Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b0e36da
TgpuGenericFn
lursz Jan 9, 2026
218f8a1
🦕
lursz Jan 9, 2026
51c8fdf
still not passing...
lursz Jan 9, 2026
3d71e24
🦕
lursz Jan 9, 2026
e717fae
Proposed way to proceed with tgpu.fn (#2030)
iwoplaza Jan 9, 2026
77446c2
isTgpuGenericFn
lursz Jan 9, 2026
5a92b03
isTgpuGenericFn_vol2
lursz Jan 9, 2026
225e920
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 12, 2026
eb735ca
🦕
lursz Jan 12, 2026
3df2489
test corrections
lursz Jan 12, 2026
7618c0c
shelllessCall fix - used to return undefined
lursz Jan 12, 2026
c906a4c
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 12, 2026
c58e141
tests tests tests
lursz Jan 12, 2026
67b21c6
Merge remote-tracking branch 'refs/remotes/origin/feat/tgpufn-callbac…
lursz Jan 12, 2026
92625c3
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 12, 2026
34f8c42
more tests
lursz Jan 12, 2026
a741df7
Merge remote-tracking branch 'refs/remotes/origin/feat/tgpufn-callbac…
lursz Jan 12, 2026
d0a9e21
🦕🦕🦕🦕🦕🦕
lursz Jan 12, 2026
f5610d3
🦕
lursz Jan 12, 2026
836d7f6
pr fixes, mostly generics
lursz Jan 15, 2026
87eacef
more tests
lursz Jan 15, 2026
e635728
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 15, 2026
3e6be56
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 18, 2026
d5f77c5
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 22, 2026
c93ae97
merge main
lursz Jan 23, 2026
73fe652
🦕🦕
lursz Jan 23, 2026
869e57e
Merge branch 'main' into feat/tgpufn-callback
lursz Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion packages/typegpu/src/core/function/tgpuFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import type {
InheritArgNames,
} from './fnTypes.ts';
import { stripTemplate } from './templateUtils.ts';
import type { TgpuBufferShorthand } from '../buffer/bufferShorthand.ts';

// ----------
// Public API
Expand Down Expand Up @@ -108,6 +109,28 @@ export type TgpuFn<ImplSchema extends AnyFn = (...args: any[]) => any> =
};
};

/**
* A function wrapper that allows providing slot and accessor overrides for shellless functions
*/
export interface TgpuGenericFn<Callback extends AnyFn> {
readonly [$internal]: true;
readonly [$providing]?: Providing | undefined;
readonly resourceType: 'generic-function';
readonly callback: Callback;

with<T>(slot: TgpuSlot<T>, value: Eventual<T>): TgpuGenericFn<Callback>;
with<T extends AnyData>(
accessor: TgpuAccessor<T>,
value:
| TgpuFn<() => T>
| TgpuBufferUsage<T>
| TgpuBufferShorthand<T>
| Infer<T>,
): TgpuGenericFn<Callback>;

(...args: Parameters<Callback>): ReturnType<Callback>;
}

export function fn<
Args extends AnyData[] | [],
>(argTypes: Args, returnType?: undefined): TgpuFnShell<Args, Void>;
Expand All @@ -117,10 +140,22 @@ export function fn<
Return extends AnyData,
>(argTypes: Args, returnType: Return): TgpuFnShell<Args, Return>;

export function fn<Callback extends AnyFn>(
callback: Callback,
): TgpuGenericFn<Callback>;

export function fn<
Args extends AnyData[] | [],
Return extends AnyData = Void,
>(argTypes: Args, returnType?: Return | undefined): TgpuFnShell<Args, Return> {
>(
argTypesOrCallback: Args | AnyFn,
returnType?: Return | undefined,
): TgpuFnShell<Args, Return> | TgpuGenericFn<AnyFn> {
if (typeof argTypesOrCallback === 'function') {
return createGenericFn(argTypesOrCallback, []);
}

const argTypes = argTypesOrCallback;
const shell: TgpuFnShellHeader<Args, Return> = {
[$internal]: true,
argTypes,
Expand All @@ -147,6 +182,13 @@ export function isTgpuFn<Args extends AnyData[] | [], Return extends AnyData>(
(value as TgpuFn<(...args: Args) => Return>)?.resourceType === 'function';
}

export function isGenericFn<Callback extends AnyFn>(
value: unknown | TgpuGenericFn<Callback>,
): value is TgpuGenericFn<Callback> {
return isMarkedInternal(value) &&
(value as TgpuGenericFn<Callback>)?.resourceType === 'generic-function';
}

// --------------
// Implementation
// --------------
Expand Down Expand Up @@ -328,3 +370,50 @@ function createBoundFunction<ImplSchema extends AnyFn>(

return fn;
}

function createGenericFn<Callback extends AnyFn>(
callback: Callback,
pairs: SlotValuePair[],
): TgpuGenericFn<Callback> {
type This = TgpuGenericFn<Callback>;

const fnBase = {
[$internal]: true as const,
resourceType: 'generic-function' as const,
callback,
[$providing]: pairs.length > 0
? {
inner: callback,
pairs,
}
: undefined,

with(
slot: TgpuSlot<unknown> | TgpuAccessor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutable accessor is missing. Also, is there a particular reason, why there isn't comptime wrapper like in createBoundFunction?

value: unknown,
): TgpuGenericFn<Callback> {
return createGenericFn(callback, [
...pairs,
[isAccessor(slot) ? slot.slot : slot, value],
]);
},
};

const call = ((...args: Parameters<Callback>): ReturnType<Callback> => {
return callback(...args) as ReturnType<Callback>;
}) as Callback;

const genericFn = Object.assign(call, fnBase) as unknown as This;

Object.defineProperty(genericFn, 'toString', {
value() {
const fnLabel = getName(callback) ?? '<unnamed>';
if (pairs.length > 0) {
return `fn*:${fnLabel}[${pairs.map(stringifyPair).join(', ')}]`;
}
return `fn*:${fnLabel}`;
},
});

return genericFn;
}
6 changes: 5 additions & 1 deletion packages/typegpu/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ export type {
TgpuLayoutTexture,
TgpuLayoutUniform,
} from './tgpuBindGroupLayout.ts';
export type { TgpuFn, TgpuFnShell } from './core/function/tgpuFn.ts';
export type {
TgpuFn,
TgpuFnShell,
TgpuGenericFn,
} from './core/function/tgpuFn.ts';
export type { TgpuComptime } from './core/function/comptime.ts';
export type {
TgpuVertexFn,
Expand Down
4 changes: 4 additions & 0 deletions packages/typegpu/src/resolutionCtx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ export class ResolutionCtxImpl implements ResolutionCtx {
}

withSlots<T>(pairs: SlotValuePair<unknown>[], callback: () => T): T {
if (pairs.length === 0) {
return callback();
}

this._itemStateStack.pushSlotBindings(pairs);

try {
Expand Down
58 changes: 39 additions & 19 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import { invariant, ResolutionError, WgslTypeError } from '../errors.ts';
import { getName } from '../shared/meta.ts';
import { isMarkedInternal } from '../shared/symbols.ts';
import { safeStringify } from '../shared/stringify.ts';
import { $internal } from '../shared/symbols.ts';
import { $internal, $providing } from '../shared/symbols.ts';
import { pow } from '../std/numeric.ts';
import { add, div, mul, neg, sub } from '../std/operators.ts';
import { type FnArgsConversionHint, isKnownAtComptime } from '../types.ts';
Expand All @@ -44,6 +44,8 @@ import type { DualFn } from '../data/dualFn.ts';
import { createPtrFromOrigin, implicitFrom, ptrFn } from '../data/ptr.ts';
import { RefOperator } from '../data/ref.ts';
import { constant } from '../core/constant/tgpuConstant.ts';
import { isGenericFn } from '../core/function/tgpuFn.ts';
import type { SlotValuePair } from '../core/slot/slotTypes.ts';

const { NodeTypeCatalog: NODE } = tinyest;

Expand Down Expand Up @@ -549,26 +551,44 @@ ${this.ctx.pre}}`;
return callee.value.operator(callee.value.lhs, rhs);
}

if (!isMarkedInternal(callee.value)) {
const args = argNodes.map((arg) => this.expression(arg));
const shellless = this.ctx.shelllessRepo.get(
callee.value as (...args: never[]) => unknown,
args,
if (!isMarkedInternal(callee.value) || isGenericFn(callee.value)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you save the result of isGenericFn to some variable? You are calling it 3 times, which means there will be 3 jumps (unless JS caches the calls)

const slotPairs: SlotValuePair[] = isGenericFn(callee.value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const slotPairs: SlotValuePair[] = isGenericFn(callee.value)
const slotPairs = isGenericFn(callee.value)

? (callee.value[$providing]?.pairs ?? [])
: [];
const callback = isGenericFn(callee.value)
? callee.value.callback
: (callee.value as (...args: never[]) => unknown);
Copy link
Collaborator

@cieplypolar cieplypolar Jan 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
: (callee.value as (...args: never[]) => unknown);
: (callee.value as AnyFn);


const shelllessCall = this.ctx.withSlots(
slotPairs,
(): Snippet | undefined => {
const args = argNodes.map((arg) => this.expression(arg));
const shellless = this.ctx.shelllessRepo.get(
callback,
args,
);
if (!shellless) {
return undefined;
}

const converted = args.map((s, idx) => {
const argType = shellless.argTypes[idx] as AnyData;
return tryConvertSnippet(s, argType, /* verbose */ false);
});

return this.ctx.withResetIndentLevel(() => {
const snippet = this.ctx.resolve(shellless);
return snip(
stitch`${snippet.value}(${converted})`,
snippet.dataType,
/* origin */ 'runtime',
);
});
},
);
if (shellless) {
const converted = args.map((s, idx) => {
const argType = shellless.argTypes[idx] as AnyData;
return tryConvertSnippet(s, argType, /* verbose */ false);
});

return this.ctx.withResetIndentLevel(() => {
const snippet = this.ctx.resolve(shellless);
return snip(
stitch`${snippet.value}(${converted})`,
snippet.dataType,
/* origin */ 'runtime',
);
});
if (shelllessCall) {
return shelllessCall;
}

throw new Error(
Expand Down
90 changes: 90 additions & 0 deletions packages/typegpu/tests/tgpuGenericFn.test.ts
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add more tests to see if the .with API works on these generic functions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this api could use a paragraph in either functions or slots docs

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { describe, expect } from 'vitest';
import * as d from '../src/data/index.ts';
import tgpu from '../src/index.ts';
import { it } from './utils/extendedIt.ts';

describe('TgpuGenericFn - shellless callback wrapper', () => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it('can be called in js', () => {
  const getValue = () => {
    'use gpu';
    return 2;
  };

  const getValueGeneric = tgpu.fn(getValue);

  expect(getValueGeneric()).toBe(2);
});

it('generates only one definition when both original and wrapped function are used', () => {
const countAccess = tgpu['~unstable'].accessor(d.f32, 2);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
const original = getDouble();
const wrapped = getDouble4();
return original + wrapped;
};

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 4f;
}

fn main() -> f32 {
let original = getDouble();
let wrapped = getDouble();
return (original + wrapped);
}"
`);
});

it('works when only the wrapped function is used', () => {
const countAccess = tgpu['~unstable'].accessor(d.f32, 0);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
return getDouble4();
};

const wgsl = tgpu.resolve([main]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const wgsl = tgpu.resolve([main]);

expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 0f;
}

fn main() -> f32 {
return getDouble();
}"
`);
});

it('does not duplicate the same function', () => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, we've seen that test before

const countAccess = tgpu['~unstable'].accessor(d.f32, 0);

const getDouble = () => {
'use gpu';
return countAccess.$ * 2;
};

const getDouble4 = tgpu.fn(getDouble);

const main = () => {
'use gpu';
return getDouble4() + getDouble4();
};

const wgsl = tgpu.resolve([main]);
expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
"fn getDouble() -> f32 {
return 0f;
}

fn main() -> f32 {
return (getDouble() + getDouble());
}"
`);
});
});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intended behavior?

it('complex', () => {
  const valueAccess = tgpu['~unstable'].accessor(d.f32);
  const slot = tgpu.slot(5);

  const getValue = () => {
    'use gpu';
    return valueAccess.$ * slot.$;
  };

  const getValueGeneric = tgpu.fn(getValue).with(valueAccess, 2).with(
    valueAccess,
    4,
  ).with(slot, 7).with(slot, 5);
  const getValueGenericAgain = tgpu.fn(getValue).with(valueAccess, 4);
  const getValueShell = tgpu.fn([], d.f32)(getValue).with(
    valueAccess,
    4,
  );
  const getValueShell2 = tgpu.fn([], d.f32)(getValue).with(
    valueAccess,
    4,
  );

  const main = () => {
    'use gpu';
    return getValueShell() + getValueGeneric() + getValueGenericAgain() +
      getValueShell2();
  };

  expect(tgpu.resolve([main])).toMatchInlineSnapshot(`
    "fn getValue() -> f32 {
      return 20f;
    }

    fn getValue_1() -> f32 {
      return 20f;
    }

    fn getValue_2() -> f32 {
      return 20f;
    }

    fn main() -> f32 {
      return (((getValue() + getValue_1()) + getValue_1()) + getValue_2());
    }"
  `);
});