22import os
33import onnx
44
5- from numpy import cumsum
5+ from numpy import cumsum , max , exp , sum
66from rknn .api .custom_op import get_node_attr
77
88class cstCumSum :
@@ -17,6 +17,19 @@ def compute(self, node, inputs):
1717 axis = get_node_attr (node , "axis" )
1818 return [cumsum (x , axis = axis )]
1919
20+ class cstSoftmax :
21+ op_type = 'cstSoftmax'
22+ def shape_infer (self , node , in_shapes , in_dtypes ):
23+ out_shapes = in_shapes .copy ()
24+ out_dtypes = in_dtypes .copy ()
25+ return out_shapes , out_dtypes
26+ def compute (self , node , inputs ):
27+ x = inputs [0 ]
28+ axis = get_node_attr (node , 'axis' )
29+ x_max = max (x , axis = axis , keepdims = True )
30+ tmp = exp (x - x_max )
31+ s = sum (tmp , axis = axis , keepdims = True )
32+
2033
2134# Function to modify the ONNX model
2235def modify_onnx_change_op_type (input_onnx_path , output_onnx_path , old_op_type , new_op_type ):
@@ -57,7 +70,7 @@ def ConvertModel(model_path='ViT-B-32__openai/textual/model.onnx', target_platfo
5770 rknn .config (target_platform = target_platform , dynamic_input = dynamic_input , disable_rules = ['fuse_matmul_softmax_matmul_to_sdpa' ])
5871
5972 modified_onnx_path = model_path .replace ('.onnx' , '_cstcumsum.onnx' )
60- modified = modify_onnx_change_op_type (model_path , modified_onnx_path , "CumSum" , "cstCumSum " )
73+ modified = modify_onnx_change_op_type (model_path , modified_onnx_path , "CumSum" , "cstSoftmax " )
6174 onnx_to_load = modified_onnx_path if modified else model_path
6275 if modified :
6376 ret = rknn .reg_custom_op (cstCumSum ())
0 commit comments