File tree Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Original file line number Diff line number Diff line change 2424if  is_torch_npu_available ():
2525    import  torch_npu 
2626
27- ACTIVATION_FUNCTIONS  =  {
28-     "swish" : nn .SiLU () ,
29-     "silu" : nn .SiLU () ,
30-     "mish" : nn .Mish () ,
31-     "gelu" : nn .GELU () ,
32-     "relu" : nn .ReLU () ,
27+ ACT2CLS  =  {
28+     "swish" : nn .SiLU ,
29+     "silu" : nn .SiLU ,
30+     "mish" : nn .Mish ,
31+     "gelu" : nn .GELU ,
32+     "relu" : nn .ReLU ,
3333}
3434
3535
@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
4444    """ 
4545
4646    act_fn  =  act_fn .lower ()
47-     if  act_fn  in  ACTIVATION_FUNCTIONS :
48-         return  ACTIVATION_FUNCTIONS [act_fn ]
47+     if  act_fn  in  ACT2CLS :
48+         return  ACT2CLS [act_fn ]() 
4949    else :
50-         raise  ValueError (f"Unsupported  activation function:  { act_fn }  )
50+         raise  ValueError (f"activation function { act_fn }  not found in ACT2FN mapping  { list ( ACT2CLS . keys ()) }  )
5151
5252
5353class  FP32SiLU (nn .Module ):
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments