@@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
9393 if ( typeof navigator === 'undefined' || ! navigator . gpu ) {
9494 throw new Error ( 'WebGPU is not supported in current environment' ) ;
9595 }
96- const powerPreference = env . webgpu ?. powerPreference ;
97- if ( powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance' ) {
98- throw new Error ( `Invalid powerPreference setting: "${ powerPreference } "` ) ;
99- }
100- const forceFallbackAdapter = env . webgpu ?. forceFallbackAdapter ;
101- if ( forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean' ) {
102- throw new Error ( `Invalid forceFallbackAdapter setting: "${ forceFallbackAdapter } "` ) ;
103- }
104- const adapter = await navigator . gpu . requestAdapter ( { powerPreference, forceFallbackAdapter} ) ;
96+
97+ let adapter = env . webgpu . adapter as GPUAdapter | null ;
10598 if ( ! adapter ) {
106- throw new Error (
107- 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.' ) ;
99+ // if adapter is not set, request a new adapter.
100+ const powerPreference = env . webgpu . powerPreference ;
101+ if ( powerPreference !== undefined && powerPreference !== 'low-power' &&
102+ powerPreference !== 'high-performance' ) {
103+ throw new Error ( `Invalid powerPreference setting: "${ powerPreference } "` ) ;
104+ }
105+ const forceFallbackAdapter = env . webgpu . forceFallbackAdapter ;
106+ if ( forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean' ) {
107+ throw new Error ( `Invalid forceFallbackAdapter setting: "${ forceFallbackAdapter } "` ) ;
108+ }
109+ adapter = await navigator . gpu . requestAdapter ( { powerPreference, forceFallbackAdapter} ) ;
110+ if ( ! adapter ) {
111+ throw new Error (
112+ 'Failed to get GPU adapter. ' +
113+ 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.' ) ;
114+ }
115+ } else {
116+ // if adapter is set, validate it.
117+ if ( typeof adapter . limits !== 'object' || typeof adapter . features !== 'object' ||
118+ typeof adapter . requestDevice !== 'function' ) {
119+ throw new Error ( 'Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.' ) ;
120+ }
108121 }
109122
110123 if ( ! env . wasm . simd ) {
0 commit comments