Skip to content

Commit b95a07b

Browse files
authored
my issue>>>?????
1 parent 3249ebf commit b95a07b

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

build_rknn.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import onnx
44

5-
from numpy import cumsum
5+
from numpy import cumsum, max, exp, sum
66
from rknn.api.custom_op import get_node_attr
77

88
class cstCumSum:
@@ -17,6 +17,19 @@ def compute(self, node, inputs):
1717
axis = get_node_attr(node, "axis")
1818
return [cumsum(x, axis=axis)]
1919

20+
class cstSoftmax:
21+
op_type = 'cstSoftmax'
22+
def shape_infer(self, node, in_shapes, in_dtypes):
23+
out_shapes = in_shapes.copy()
24+
out_dtypes = in_dtypes.copy()
25+
return out_shapes, out_dtypes
26+
def compute(self, node, inputs):
27+
x = inputs[0]
28+
axis = get_node_attr(node, 'axis')
29+
x_max = max(x, axis=axis, keepdims=True)
30+
tmp = exp(x - x_max)
31+
s = sum(tmp, axis=axis, keepdims=True)
32+
2033

2134
# Function to modify the ONNX model
2235
def modify_onnx_change_op_type(input_onnx_path, output_onnx_path, old_op_type, new_op_type):
@@ -57,7 +70,7 @@ def ConvertModel(model_path='ViT-B-32__openai/textual/model.onnx', target_platfo
5770
rknn.config(target_platform=target_platform, dynamic_input=dynamic_input, disable_rules=['fuse_matmul_softmax_matmul_to_sdpa'])
5871

5972
modified_onnx_path = model_path.replace('.onnx', '_cstcumsum.onnx')
60-
modified = modify_onnx_change_op_type(model_path, modified_onnx_path, "CumSum", "cstCumSum")
73+
modified = modify_onnx_change_op_type(model_path, modified_onnx_path, "CumSum", "cstSoftmax")
6174
onnx_to_load = modified_onnx_path if modified else model_path
6275
if modified:
6376
ret = rknn.reg_custom_op(cstCumSum())

0 commit comments

Comments
 (0)