4
4
from .base_emitter import BaseEmitter
5
5
6
6
_types = {
7
+ TensorProto .DOUBLE : "DOUBLE" ,
7
8
TensorProto .FLOAT : "FLOAT" ,
8
9
TensorProto .FLOAT16 : "FLOAT16" ,
9
10
TensorProto .INT64 : "INT64" ,
10
11
TensorProto .INT32 : "INT32" ,
12
+ TensorProto .INT16 : "INT16" ,
13
+ TensorProto .UINT64 : "UINT64" ,
14
+ TensorProto .UINT32 : "UINT32" ,
15
+ TensorProto .UINT16 : "UINT16" ,
16
+ TensorProto .STRING : "STRING" ,
17
+ TensorProto .BOOL : "BOOL" ,
11
18
}
12
19
13
20
@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
20
27
Converts event into proper code.
21
28
"""
22
29
30
+ def __init__ (self , make_model_function : str = "" ):
31
+ super ().__init__ ()
32
+ self .make_model_function = make_model_function
33
+
23
34
def join (self , rows : List [str ], single_line : bool = False ) -> str :
24
35
"Join the rows"
25
36
assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
29
40
30
41
def _emit_start (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
31
42
self .opsets = kwargs .get ("opsets" , {})
43
+ self .ir_version = kwargs .get ("ir_version" , None )
32
44
return []
33
45
34
46
def _emit_to_onnx_model (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
43
55
)
44
56
rows = [
45
57
"" ,
46
- f"g = GraphBuilder({ self .opsets } )" ,
58
+ (
59
+ f"g = GraphBuilder({ self .opsets } , ir_version={ self .ir_version } )"
60
+ if self .ir_version
61
+ else f"GraphBuilder({ self .opsets } )"
62
+ ),
47
63
* inputs ,
48
64
f"{ self .name } ({ inps } )" ,
49
65
* outputs ,
50
66
"model = g.to_onnx()" ,
51
67
]
68
+ if self .make_model_function :
69
+ rows = [
70
+ "" ,
71
+ "" ,
72
+ f'def { self .make_model_function } () -> "ModelProto":' ,
73
+ * [" " + _ for _ in rows [1 :]],
74
+ " return model" ,
75
+ "" ,
76
+ "" ,
77
+ f"model = { self .make_model_function } ()" ,
78
+ ]
52
79
return rows
53
80
54
81
def _emit_begin_graph (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78
105
name = kwargs ["name" ]
79
106
itype = kwargs .get ("elem_type" , 0 )
80
107
shape = kwargs .get ("shape" , None )
108
+ name = self ._clean_result_name (name )
81
109
if itype == 0 :
82
- inp = "X"
110
+ inp = name or "X"
83
111
else :
84
112
if shape is None :
85
- inp = f'X : "{ _itype_to_string (itype )} "'
113
+ inp = f'{ name } : "{ _itype_to_string (itype )} "'
86
114
else :
87
- inp = f'X: "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
115
+ inp = (
116
+ f'{ name } : "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
117
+ )
88
118
self .inputs_full .append (inp )
89
119
self .inputs .append (name )
90
120
self .inputs_full_ .append ((name , _itype_to_string (itype ), shape ))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113
143
114
144
def _emit_output (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
115
145
name = kwargs ["name" ]
146
+ name = self ._clean_result_name (name )
116
147
itype = kwargs .get ("elem_type" , 0 )
117
148
shape = kwargs .get ("shape" , None )
118
149
self .outputs .append (name )
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126
157
if kwargs .get ("domain" , "" ) != "" :
127
158
domain = kwargs ["domain" ]
128
159
op_type = f"{ domain } .{ op_type } "
160
+ else :
161
+ domain = ""
129
162
atts = kwargs .get ("atts" , {})
130
163
args = []
131
164
for k , v in atts .items ():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134
167
raise NotImplementedError ("Graph attribute not supported yet." )
135
168
args .append (f"{ k } ={ vatt } " )
136
169
137
- outs = ", " .join (outputs )
138
- inps = ", " .join (inputs )
170
+ outs = ", " .join (map (self ._clean_result_name , outputs ))
171
+ inps = ", " .join (map (self ._clean_result_name , inputs ))
172
+ op_type = self ._emit_node_type (op_type , domain )
173
+ sdomain = "" if not domain else f", domain={ domain !r} "
139
174
if args :
140
175
sargs = ", " .join (args )
141
- row = f" { outs } = op.{ op_type } ({ inps } , { sargs } )"
176
+ if inps :
177
+ row = f" { outs } = op.{ op_type } ({ inps } , { sargs } { sdomain } )"
178
+ else :
179
+ row = f" { outs } = op.{ op_type } ({ sargs } { sdomain } )"
142
180
else :
143
- row = f" { outs } = op.{ op_type } ({ inps } )"
181
+ row = f" { outs } = op.{ op_type } ({ inps } { sdomain } )"
144
182
return [row ]
183
+
184
+ def _clean_result_name (self , name ):
185
+ return name
186
+
187
+ def _emit_node_type (self , op_type , domain ):
188
+ return op_type
0 commit comments