Skip to content

Commit df08050

Browse files
authored
Make tf flow functions usable for onnx (#431)
1 parent ea4e491 commit df08050

File tree

1 file changed

+130
-45
lines changed

1 file changed

+130
-45
lines changed

returnn/flow.py

Lines changed: 130 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,55 @@
11
from sisyphus import tk
22
from sisyphus.delayed_ops import DelayedJoin
3-
from typing import Optional, List, Union
3+
from typing import Optional, List, Union, Dict
44

55
from i6_core import rasr
66
from i6_core.returnn.training import Checkpoint
77

88

9+
def make_precomputed_hybrid_feature_flow(
10+
backend: str,
11+
rasr_config: rasr.RasrConfig,
12+
rasr_post_config: Optional[rasr.RasrConfig] = None,
13+
fwd_input_name: str = "fwd-input",
14+
) -> rasr.FlowNetwork:
15+
"""
16+
Create the feature flow for a simple TF/ONNX network that predicts frame-wise outputs, to be used
17+
in combination with the `nn-precomputed-hybrid` feature-scorer setting in RASR.
18+
19+
The resulting flow is a trivial (for ONNX, the "tf" is replaced by "onnx"):
20+
21+
<link from="<tf_fwd_input_name>" to="tf-fwd:input"/>
22+
<node name="tf-fwd" id="$(id)" filter="tensorflow-forward"/>
23+
<link from="tf-fwd:log-posteriors" to="network:features"/>
24+
25+
:param backend: "tf" or "onnx"
26+
:param rasr_config: rasr config for the forward node
27+
:param rasr_post_config: rasr post config (not hashed) for the forward node
28+
:param fwd_input_name: naming for the tf network input, usually no need to be changed
29+
:return: tensorflow-/onnx-forward node flow with output link and related config
30+
"""
31+
32+
# flow (model scoring done in tf/onnx flow node)
33+
flow = rasr.FlowNetwork()
34+
flow.add_input(fwd_input_name)
35+
flow.add_output("features")
36+
flow.add_param("id")
37+
38+
assert backend in ("tf", "onnx"), f"backend not supported: {backend}"
39+
node_filter = {"tf": "tensorflow-forward", "onnx": "onnx-forward"}[backend]
40+
fwd_node = flow.add_node(node_filter, f"{backend}-fwd", {"id": "$(id)"})
41+
flow.link(f"network:{fwd_input_name}", fwd_node + ":input")
42+
flow.link(fwd_node + ":log-posteriors", "network:features")
43+
44+
flow.config = rasr.RasrConfig()
45+
flow.config[fwd_node] = rasr_config
46+
if rasr_post_config is not None:
47+
flow.post_config = rasr.RasrConfig()
48+
flow.post_config[fwd_node] = rasr_post_config
49+
50+
return flow
51+
52+
953
def make_precomputed_hybrid_tf_feature_flow(
1054
tf_graph: tk.Path,
1155
tf_checkpoint: Checkpoint,
@@ -15,14 +59,8 @@ def make_precomputed_hybrid_tf_feature_flow(
1559
tf_fwd_input_name: str = "tf-fwd-input",
1660
) -> rasr.FlowNetwork:
1761
"""
18-
Create the feature flow for a simple TF network that predicts frame-wise outputs, to be used
19-
in combination with the `nn-precomputed-hybrid` feature-scorer setting in RASR.
20-
21-
The resulting flow is a trivial:
22-
23-
<link from="<tf_fwd_input_name>" to="tf-fwd:input"/>
24-
<node name="tf-fwd" id="$(id)" filter="tensorflow-forward"/>
25-
<link from="tf-fwd:log-posteriors" to="network:features"/>
62+
Create the feature flow for a simple TF network that predicts frame-wise outputs,
63+
see make_precomputed_hybrid_feature_flow.
2664
2765
With the config settings:
2866
@@ -41,7 +79,6 @@ def make_precomputed_hybrid_tf_feature_flow(
4179
param-name = <output_type>
4280
tensor-name = <output_tensor_name>/output_batch_major
4381
44-
4582
:param tf_graph: usually the output of a CompileTFGraphJob
4683
:param tf_checkpoint: the checkpoint to load the model from, e.g. from a ReturnnTrainingJob or similar
4784
:param extern_data_name: name of the extern data entry to feed the features to
@@ -52,66 +89,114 @@ def make_precomputed_hybrid_tf_feature_flow(
5289
:return: tensorflow-forward node flow with output link and related config
5390
"""
5491

55-
# tf flow (model scoring done in tf flow node) #
56-
tf_flow = rasr.FlowNetwork()
57-
tf_flow.add_input(tf_fwd_input_name)
58-
tf_flow.add_output("features")
59-
tf_flow.add_param("id")
60-
61-
tf_fwd = tf_flow.add_node("tensorflow-forward", "tf-fwd", {"id": "$(id)"})
62-
tf_flow.link(f"network:{tf_fwd_input_name}", tf_fwd + ":input")
63-
tf_flow.link(tf_fwd + ":log-posteriors", "network:features")
64-
65-
tf_flow.config = rasr.RasrConfig()
66-
tf_flow.config[tf_fwd].input_map.info_0.param_name = "input"
67-
tf_flow.config[
68-
tf_fwd
69-
].input_map.info_0.tensor_name = f"extern_data/placeholders/{extern_data_name}/{extern_data_name}"
70-
tf_flow.config[tf_fwd].input_map.info_0.seq_length_tensor_name = (
92+
rasr_config = rasr.RasrConfig()
93+
rasr_config.input_map.info_0.param_name = "input"
94+
rasr_config.input_map.info_0.tensor_name = f"extern_data/placeholders/{extern_data_name}/{extern_data_name}"
95+
rasr_config.input_map.info_0.seq_length_tensor_name = (
7196
f"extern_data/placeholders/" f"{extern_data_name}/{extern_data_name}_dim0_size"
7297
)
7398

74-
tf_flow.config[tf_fwd].output_map.info_0.param_name = "log-posteriors"
75-
tf_flow.config[tf_fwd].output_map.info_0.tensor_name = f"{output_layer_name}/output_batch_major"
99+
rasr_config.output_map.info_0.param_name = "log-posteriors"
100+
rasr_config.output_map.info_0.tensor_name = f"{output_layer_name}/output_batch_major"
76101

77-
tf_flow.config[tf_fwd].loader.type = "meta"
78-
tf_flow.config[tf_fwd].loader.meta_graph_file = tf_graph
79-
tf_flow.config[tf_fwd].loader.saved_model_file = tf_checkpoint
102+
rasr_config.loader.type = "meta"
103+
rasr_config.loader.meta_graph_file = tf_graph
104+
rasr_config.loader.saved_model_file = tf_checkpoint
80105
if native_ops is not None:
81106
if isinstance(native_ops, list):
82-
tf_flow.config[tf_fwd].loader.required_libraries = DelayedJoin(native_ops, ";")
107+
rasr_config.loader.required_libraries = DelayedJoin(native_ops, ";")
83108
else:
84-
tf_flow.config[tf_fwd].loader.required_libraries = native_ops
109+
rasr_config.loader.required_libraries = native_ops
110+
return make_precomputed_hybrid_feature_flow(
111+
backend="tf",
112+
rasr_config=rasr_config,
113+
fwd_input_name=tf_fwd_input_name,
114+
)
85115

86-
return tf_flow
87116

117+
def make_precomputed_hybrid_onnx_feature_flow(
118+
onnx_model: tk.Path,
119+
io_map: Dict[str, str],
120+
onnx_fwd_input_name: str = "fwd-input",
121+
cpu: int = 1,
122+
) -> rasr.FlowNetwork:
123+
"""
124+
Create the feature flow for a simple ONNX network that predicts frame-wise outputs,
125+
see make_precomputed_hybrid_feature_flow.
88126
89-
def add_tf_flow_to_base_flow(
127+
With the config settings:
128+
129+
[flf-lattice-tool.network.recognizer.feature-extraction.onnx-fwd.io-map]
130+
features = data
131+
features-size = data_len
132+
output = classes
133+
134+
[flf-lattice-tool.network.recognizer.feature-extraction.onnx-fwd.session]
135+
file = <onnx_file>
136+
inter-op-num-threads = 2
137+
intra-op-num-threads = 2
138+
139+
:param onnx_model: usually the output of a OnnxExportJob
140+
:param io_map: e.g. {"features": "data", "output": "classes"}
141+
:param onnx_fwd_input_name: naming for the onnx network input, usually no need to be changed
142+
:param cpu: number of CPUs to use
143+
:return: onnx-forward node flow with output link and related config
144+
"""
145+
146+
rasr_config = rasr.RasrConfig()
147+
rasr_post_config = rasr.RasrConfig()
148+
for k, v in io_map.items():
149+
rasr_config.io_map[k] = v
150+
151+
rasr_config.session.file = onnx_model
152+
rasr_post_config.session.inter_op_num_threads = cpu
153+
rasr_post_config.session.intra_op_num_threads = cpu
154+
155+
return make_precomputed_hybrid_feature_flow(
156+
backend="onnx",
157+
rasr_config=rasr_config,
158+
rasr_post_config=rasr_post_config,
159+
fwd_input_name=onnx_fwd_input_name,
160+
)
161+
162+
163+
def add_fwd_flow_to_base_flow(
90164
base_flow: rasr.FlowNetwork,
91-
tf_flow: rasr.FlowNetwork,
92-
tf_fwd_input_name: str = "tf-fwd-input",
165+
fwd_flow: rasr.FlowNetwork,
166+
fwd_input_name: str = "fwd-input",
93167
) -> rasr.FlowNetwork:
94168
"""
95-
Integrate tf-fwd node into the regular flow network, passing the features into the input of the tf-flow net.
169+
Integrate tf- or onnx-fwd node into a regular flow network, passing the features to the input of the forwarding net.
96170
97171
:param FlowNetwork base_flow:
98-
:param FlowNetwork tf_flow:
99-
:param str tf_fwd_input_name: see: get_tf_flow()
172+
:param FlowNetwork fwd_flow:
173+
:param str fwd_input_name: see: make_precomputed_hybrid_feature_flow()
100174
:rtype: Combined FlowNetwork
101175
"""
102-
assert len(base_flow.outputs) == 1, "Not implemented otherwise" # see hard coded tf-fwd input
176+
assert len(base_flow.outputs) == 1, "Not implemented otherwise" # see hard coded fwd input
103177
base_output = list(base_flow.outputs)[0]
104178

105-
input_name = tf_fwd_input_name
179+
input_name = fwd_input_name
106180

107181
feature_flow = rasr.FlowNetwork()
108182
base_mapping = feature_flow.add_net(base_flow)
109-
tf_mapping = feature_flow.add_net(tf_flow)
183+
fwd_mapping = feature_flow.add_net(fwd_flow)
110184
feature_flow.interconnect_inputs(base_flow, base_mapping)
111-
feature_flow.interconnect(base_flow, base_mapping, tf_flow, tf_mapping, {base_output: input_name})
112-
feature_flow.interconnect_outputs(tf_flow, tf_mapping)
185+
feature_flow.interconnect(base_flow, base_mapping, fwd_flow, fwd_mapping, {base_output: input_name})
186+
feature_flow.interconnect_outputs(fwd_flow, fwd_mapping)
113187

114188
# ensure cache_mode as base feature net
115189
feature_flow.add_flags(base_flow.flags)
116190

117191
return feature_flow
192+
193+
194+
def add_tf_flow_to_base_flow(
195+
base_flow: rasr.FlowNetwork,
196+
tf_flow: rasr.FlowNetwork,
197+
tf_fwd_input_name: str = "tf-fwd-input",
198+
) -> rasr.FlowNetwork:
199+
"""
200+
Keep old name to avoid breaking setups
201+
"""
202+
return add_fwd_flow_to_base_flow(base_flow, tf_flow, tf_fwd_input_name)

0 commit comments

Comments
 (0)