1- import type { DualFn } from '../../data/dualFn .ts' ;
1+ import type { AnyData } from '../../data/dataTypes .ts' ;
22import { type MapValueToSnippet , snip } from '../../data/snippet.ts' ;
3- import { getResolutionCtx , inCodegenMode } from '../../execMode.ts' ;
4- import { isKnownAtComptime , NormalState } from '../../types.ts' ;
53import { setName } from '../../shared/meta.ts' ;
6- import { $internal } from '../../shared/symbols.ts' ;
4+ import { $gpuCallable } from '../../shared/symbols.ts' ;
75import { tryConvertSnippet } from '../../tgsl/conversion.ts' ;
8- import type { AnyData } from '../../data/dataTypes.ts' ;
6+ import {
7+ type DualFn ,
8+ isKnownAtComptime ,
9+ NormalState ,
10+ type ResolutionCtx ,
11+ } from '../../types.ts' ;
912
1013type MapValueToDataType < T > = { [ K in keyof T ] : AnyData } ;
14+ type AnyFn = ( ...args : never [ ] ) => unknown ;
1115
12- interface DualImplOptions < T extends ( ... args : never [ ] ) => unknown > {
16+ interface DualImplOptions < T extends AnyFn > {
1317 readonly name : string | undefined ;
1418 readonly normalImpl : T | string ;
15- readonly codegenImpl : ( ...args : MapValueToSnippet < Parameters < T > > ) => string ;
19+ readonly codegenImpl : (
20+ ctx : ResolutionCtx ,
21+ args : MapValueToSnippet < Parameters < T > > ,
22+ ) => string ;
1623 readonly signature :
1724 | { argTypes : AnyData [ ] ; returnType : AnyData }
1825 | ( (
@@ -34,90 +41,83 @@ export class MissingCpuImplError extends Error {
3441 }
3542}
3643
37- export function dualImpl < T extends ( ... args : never [ ] ) => unknown > (
44+ export function dualImpl < T extends AnyFn > (
3845 options : DualImplOptions < T > ,
3946) : DualFn < T > {
40- const gpuImpl = ( ...args : MapValueToSnippet < Parameters < T > > ) => {
41- // biome-ignore lint/style/noNonNullAssertion: it's there
42- const ctx = getResolutionCtx ( ) ! ;
43- const { argTypes, returnType } = typeof options . signature === 'function'
44- ? options . signature (
45- ...args . map ( ( s ) => {
46- // Dereference implicit pointers
47- if ( s . dataType . type === 'ptr' && s . dataType . implicit ) {
48- return s . dataType . inner ;
49- }
50- return s . dataType ;
51- } ) as MapValueToDataType < Parameters < T > > ,
52- )
53- : options . signature ;
54-
55- const argSnippets = args as MapValueToSnippet < Parameters < T > > ;
56- const converted = argSnippets . map ( ( s , idx ) => {
57- const argType = argTypes [ idx ] ;
58- if ( ! argType ) {
59- throw new Error ( 'Function called with invalid arguments' ) ;
60- }
61- return tryConvertSnippet ( s , argType , ! options . ignoreImplicitCastWarning ) ;
62- } ) as MapValueToSnippet < Parameters < T > > ;
63-
64- if (
65- ! options . noComptime &&
66- converted . every ( ( s ) => isKnownAtComptime ( s ) ) &&
67- typeof options . normalImpl === 'function'
68- ) {
69- ctx . pushMode ( new NormalState ( ) ) ;
70- try {
71- return snip (
72- options . normalImpl ( ...converted . map ( ( s ) => s . value ) as never [ ] ) ,
73- returnType ,
74- // Functions give up ownership of their return value
75- /* origin */ 'constant' ,
76- ) ;
77- } catch ( e ) {
78- // cpuImpl may in some cases be present but implemented only partially.
79- // In that case, if the MissingCpuImplError is thrown, we fallback to codegenImpl.
80- // If it is any other error, we just rethrow.
81- if ( ! ( e instanceof MissingCpuImplError ) ) {
82- throw e ;
83- }
84- } finally {
85- ctx . popMode ( 'normal' ) ;
86- }
87- }
88-
89- return snip (
90- options . codegenImpl ( ...converted ) ,
91- returnType ,
92- // Functions give up ownership of their return value
93- /* origin */ 'runtime' ,
94- ) ;
95- } ;
96-
9747 const impl = ( ( ...args : Parameters < T > ) => {
98- if ( inCodegenMode ( ) ) {
99- return gpuImpl ( ...args as MapValueToSnippet < Parameters < T > > ) ;
100- }
10148 if ( typeof options . normalImpl === 'string' ) {
10249 throw new MissingCpuImplError ( options . normalImpl ) ;
10350 }
10451 return options . normalImpl ( ...args ) ;
105- } ) as T ;
52+ } ) as DualFn < T > ;
10653
10754 setName ( impl , options . name ) ;
10855 impl . toString = ( ) => options . name ?? '<unknown>' ;
109- Object . defineProperty ( impl , $internal , {
110- value : {
111- jsImpl : options . normalImpl ,
112- gpuImpl,
113- get strictSignature ( ) {
114- return typeof options . signature !== 'function'
115- ? options . signature
116- : undefined ;
117- } ,
118- argConversionHint : 'keep' ,
56+ impl [ $gpuCallable ] = {
57+ get strictSignature ( ) {
58+ return typeof options . signature !== 'function'
59+ ? options . signature
60+ : undefined ;
61+ } ,
62+ call ( ctx , args ) {
63+ const { argTypes, returnType } = typeof options . signature === 'function'
64+ ? options . signature (
65+ ...args . map ( ( s ) => {
66+ // Dereference implicit pointers
67+ if ( s . dataType . type === 'ptr' && s . dataType . implicit ) {
68+ return s . dataType . inner ;
69+ }
70+ return s . dataType ;
71+ } ) as MapValueToDataType < Parameters < T > > ,
72+ )
73+ : options . signature ;
74+
75+ const converted = args . map ( ( s , idx ) => {
76+ const argType = argTypes [ idx ] ;
77+ if ( ! argType ) {
78+ throw new Error ( 'Function called with invalid arguments' ) ;
79+ }
80+ return tryConvertSnippet (
81+ ctx ,
82+ s ,
83+ argType ,
84+ ! options . ignoreImplicitCastWarning ,
85+ ) ;
86+ } ) as MapValueToSnippet < Parameters < T > > ;
87+
88+ if (
89+ ! options . noComptime &&
90+ converted . every ( ( s ) => isKnownAtComptime ( s ) ) &&
91+ typeof options . normalImpl === 'function'
92+ ) {
93+ ctx . pushMode ( new NormalState ( ) ) ;
94+ try {
95+ return snip (
96+ options . normalImpl ( ...converted . map ( ( s ) => s . value ) as never [ ] ) ,
97+ returnType ,
98+ // Functions give up ownership of their return value
99+ /* origin */ 'constant' ,
100+ ) ;
101+ } catch ( e ) {
102+ // cpuImpl may in some cases be present but implemented only partially.
103+ // In that case, if the MissingCpuImplError is thrown, we fallback to codegenImpl.
104+ // If it is any other error, we just rethrow.
105+ if ( ! ( e instanceof MissingCpuImplError ) ) {
106+ throw e ;
107+ }
108+ } finally {
109+ ctx . popMode ( 'normal' ) ;
110+ }
111+ }
112+
113+ return snip (
114+ options . codegenImpl ( ctx , converted ) ,
115+ returnType ,
116+ // Functions give up ownership of their return value
117+ /* origin */ 'runtime' ,
118+ ) ;
119119 } ,
120- } ) ;
120+ } ;
121121
122- return impl as DualFn < T > ;
122+ return impl ;
123123}
0 commit comments