@@ -414,18 +414,56 @@ fn get_device_override() -> Option<WgpuDevice> {
414414 } )
415415}
416416
417+ #[ cfg( all( target_arch = "wasm32" , target_feature = "atomics" ) ) ]
417418#[ derive( Debug , Clone ) ]
418419pub struct ThreadLocalChannel {
419420 device : WgpuDevice ,
420421}
421422
423+ #[ cfg( all( target_arch = "wasm32" , target_feature = "atomics" ) ) ]
424+ impl ThreadLocalChannel {
425+ fn make_server ( device : & WgpuDevice ) -> Rc < RefCell < Server > > {
426+ let setup = future:: block_on ( create_setup_for_device :: < AutoGraphicsApi , WgslCompiler > (
427+ device,
428+ ) ) ;
429+
430+ let limits = setup. device . limits ( ) ;
431+ let mem_props = MemoryDeviceProperties {
432+ max_page_size : limits. max_storage_buffer_binding_size as u64 ,
433+ alignment : WgpuStorage :: ALIGNMENT
434+ . max ( limits. min_storage_buffer_offset_alignment as u64 ) ,
435+ } ;
436+
437+ let options = RuntimeOptions :: default ( ) ;
438+ let memory_management = {
439+ let mem_props = mem_props. clone ( ) ;
440+ let config = options. memory_config ;
441+ let storage = WgpuStorage :: new ( setup. device . clone ( ) ) ;
442+ MemoryManagement :: from_configuration ( storage, mem_props, config)
443+ } ;
444+ let server = crate :: compute:: WgpuServer :: new (
445+ memory_management,
446+ setup. device . clone ( ) ,
447+ setup. queue ,
448+ options. tasks_max ,
449+ ) ;
450+
451+ Rc :: new ( RefCell :: new ( server) )
452+ }
453+ }
454+
455+ #[ cfg( all( target_arch = "wasm32" , target_feature = "atomics" ) ) ]
422456impl ComputeChannel < Server > for ThreadLocalChannel {
423457 fn read (
424458 & self ,
425459 binding : cubecl_core:: server:: Binding ,
426460 ) -> impl std:: future:: Future < Output = Vec < u8 > > {
427461 LOCAL_RUNTIME . with ( |runtime| {
428- let server = runtime. borrow ( ) [ & self . device ] . clone ( ) ;
462+ let server = runtime
463+ . borrow_mut ( )
464+ . entry ( self . device . clone ( ) )
465+ . or_insert_with ( || Self :: make_server ( & self . device ) )
466+ . clone ( ) ;
429467 async move { server. borrow_mut ( ) . read ( binding) . await }
430468 } )
431469 }
@@ -435,18 +473,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
435473 binding : cubecl_core:: server:: Binding ,
436474 ) -> cubecl_runtime:: storage:: BindingResource < Server > {
437475 LOCAL_RUNTIME . with ( |runtime| {
438- runtime. borrow ( ) [ & self . device ]
476+ runtime
477+ . borrow_mut ( )
478+ . entry ( self . device . clone ( ) )
479+ . or_insert_with ( || Self :: make_server ( & self . device ) )
439480 . borrow_mut ( )
440481 . get_resource ( binding)
441482 } )
442483 }
443484
444485 fn create ( & self , data : & [ u8 ] ) -> cubecl_core:: server:: Handle {
445- LOCAL_RUNTIME . with ( |runtime| runtime. borrow ( ) [ & self . device ] . borrow_mut ( ) . create ( data) )
486+ LOCAL_RUNTIME . with ( |runtime| {
487+ runtime
488+ . borrow_mut ( )
489+ . entry ( self . device . clone ( ) )
490+ . or_insert_with ( || Self :: make_server ( & self . device ) )
491+ . borrow_mut ( )
492+ . create ( data)
493+ } )
446494 }
447495
448496 fn empty ( & self , size : usize ) -> cubecl_core:: server:: Handle {
449- LOCAL_RUNTIME . with ( |runtime| runtime. borrow ( ) [ & self . device ] . borrow_mut ( ) . empty ( size) )
497+ LOCAL_RUNTIME . with ( |runtime| {
498+ runtime
499+ . borrow_mut ( )
500+ . entry ( self . device . clone ( ) )
501+ . or_insert_with ( || Self :: make_server ( & self . device ) )
502+ . borrow_mut ( )
503+ . empty ( size)
504+ } )
450505 }
451506
452507 unsafe fn execute (
@@ -457,45 +512,76 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
457512 mode : cubecl_core:: ExecutionMode ,
458513 ) {
459514 LOCAL_RUNTIME . with ( |runtime| {
460- let runtime = runtime. borrow ( ) ;
461- let mut server = runtime[ & self . device ] . borrow_mut ( ) ;
515+ let mut runtime = runtime. borrow_mut ( ) ;
516+ let mut server = runtime
517+ . entry ( self . device . clone ( ) )
518+ . or_insert_with ( || Self :: make_server ( & self . device ) )
519+ . borrow_mut ( ) ;
462520 unsafe { server. execute ( kernel, count, bindings, mode) }
463521 } )
464522 }
465523
466524 fn flush ( & self ) {
467- LOCAL_RUNTIME . with ( |runtime| runtime. borrow ( ) [ & self . device ] . borrow_mut ( ) . flush ( ) )
525+ LOCAL_RUNTIME . with ( |runtime| {
526+ runtime
527+ . borrow_mut ( )
528+ . entry ( self . device . clone ( ) )
529+ . or_insert_with ( || Self :: make_server ( & self . device ) )
530+ . borrow_mut ( )
531+ . flush ( )
532+ } )
468533 }
469534
470535 fn sync ( & self ) -> impl std:: future:: Future < Output = ( ) > {
471536 LOCAL_RUNTIME . with ( |runtime| {
472- let server = runtime. borrow ( ) [ & self . device ] . clone ( ) ;
537+ let server = runtime
538+ . borrow_mut ( )
539+ . entry ( self . device . clone ( ) )
540+ . or_insert_with ( || Self :: make_server ( & self . device ) )
541+ . clone ( ) ;
473542 async move { server. borrow_mut ( ) . sync ( ) . await }
474543 } )
475544 }
476545
477546 fn sync_elapsed ( & self ) -> impl std:: future:: Future < Output = cubecl_runtime:: TimestampsResult > {
478547 LOCAL_RUNTIME . with ( |runtime| {
479- let server = runtime. borrow ( ) [ & self . device ] . clone ( ) ;
548+ let server = runtime
549+ . borrow_mut ( )
550+ . entry ( self . device . clone ( ) )
551+ . or_insert_with ( || Self :: make_server ( & self . device ) )
552+ . clone ( ) ;
480553 async move { server. borrow_mut ( ) . sync_elapsed ( ) . await }
481554 } )
482555 }
483556
484557 fn memory_usage ( & self ) -> cubecl_runtime:: memory_management:: MemoryUsage {
485- LOCAL_RUNTIME . with ( |runtime| runtime. borrow ( ) [ & self . device ] . borrow_mut ( ) . memory_usage ( ) )
558+ LOCAL_RUNTIME . with ( |runtime| {
559+ runtime
560+ . borrow_mut ( )
561+ . entry ( self . device . clone ( ) )
562+ . or_insert_with ( || Self :: make_server ( & self . device ) )
563+ . borrow_mut ( )
564+ . memory_usage ( )
565+ } )
486566 }
487567
488568 fn enable_timestamps ( & self ) {
489569 LOCAL_RUNTIME . with ( |runtime| {
490- runtime. borrow ( ) [ & self . device ]
570+ runtime
571+ . borrow_mut ( )
572+ . entry ( self . device . clone ( ) )
573+ . or_insert_with ( || Self :: make_server ( & self . device ) )
491574 . borrow_mut ( )
492575 . enable_timestamps ( )
493576 } )
494577 }
495578
496579 fn disable_timestamps ( & self ) {
497580 LOCAL_RUNTIME . with ( |runtime| {
498- runtime. borrow ( ) [ & self . device ]
581+ runtime
582+ . borrow_mut ( )
583+ . entry ( self . device . clone ( ) )
584+ . or_insert_with ( || Self :: make_server ( & self . device ) )
499585 . borrow_mut ( )
500586 . disable_timestamps ( )
501587 } )
0 commit comments