Skip to content

Commit 5e1610c

Browse files
authored
back to where we start
1 parent b95a07b commit 5e1610c

File tree

1 file changed

+5
-45
lines changed

1 file changed

+5
-45
lines changed

build_rknn.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from numpy import cumsum, max, exp, sum
66
from rknn.api.custom_op import get_node_attr
77

8-
class cstCumSum:
8+
class CumSum:
99
# Just CumSum with a different name so it wont conflict
10-
op_type = "cstCumSum"
10+
op_type = "CumSum"
1111

1212
def shape_infer(self, node, in_shapes, in_dtypes):
1313
return in_shapes.copy(), in_dtypes.copy()
@@ -17,44 +17,6 @@ def compute(self, node, inputs):
1717
axis = get_node_attr(node, "axis")
1818
return [cumsum(x, axis=axis)]
1919

20-
class cstSoftmax:
21-
op_type = 'cstSoftmax'
22-
def shape_infer(self, node, in_shapes, in_dtypes):
23-
out_shapes = in_shapes.copy()
24-
out_dtypes = in_dtypes.copy()
25-
return out_shapes, out_dtypes
26-
def compute(self, node, inputs):
27-
x = inputs[0]
28-
axis = get_node_attr(node, 'axis')
29-
x_max = max(x, axis=axis, keepdims=True)
30-
tmp = exp(x - x_max)
31-
s = sum(tmp, axis=axis, keepdims=True)
32-
33-
34-
# Function to modify the ONNX model
35-
def modify_onnx_change_op_type(input_onnx_path, output_onnx_path, old_op_type, new_op_type):
36-
print(f"Loading ONNX model: {input_onnx_path}")
37-
model = onnx.load(input_onnx_path)
38-
nodes_changed = 0
39-
40-
for node in model.graph.node:
41-
if node.op_type == old_op_type:
42-
print(f" Found node '{node.name}' with op_type '{old_op_type}'. Changing to '{new_op_type}'.")
43-
# Create a new node with the new op_type, keeping everything else
44-
node.op_type = new_op_type
45-
nodes_changed += 1
46-
47-
if nodes_changed > 0:
48-
print(f"Saving modified ONNX model to: {output_onnx_path}")
49-
onnx.save(model, output_onnx_path)
50-
del model
51-
modified_model = onnx.load(output_onnx_path) # idk if it loads
52-
print(f"Successfully changed {nodes_changed} nodes from '{old_op_type}' to '{new_op_type}'.")
53-
return True # Indicate modification happened
54-
else:
55-
print(f"No nodes with op_type '{old_op_type}' found. No modifications made.")
56-
return False # Indicate no modification happened
57-
5820

5921
parser = argparse.ArgumentParser("RKNN model converting")
6022
parser.add_argument("model", help="Directory of the model that will be exported to RKNN ex:ViT-B-32__openai.", type=str)
@@ -69,11 +31,9 @@ def ConvertModel(model_path='ViT-B-32__openai/textual/model.onnx', target_platfo
6931

7032
rknn.config(target_platform=target_platform, dynamic_input=dynamic_input, disable_rules=['fuse_matmul_softmax_matmul_to_sdpa'])
7133

72-
modified_onnx_path = model_path.replace('.onnx', '_cstcumsum.onnx')
73-
modified = modify_onnx_change_op_type(model_path, modified_onnx_path, "CumSum", "cstSoftmax")
74-
onnx_to_load = modified_onnx_path if modified else model_path
75-
if modified:
76-
ret = rknn.reg_custom_op(cstCumSum())
34+
onnx_to_load = model_path
35+
if 1:
36+
ret = rknn.reg_custom_op(CumSum())
7737

7838
if ret != 0:
7939
raise RuntimeError("Register Custom OP failed!")

0 commit comments

Comments
 (0)