Skip to content

Commit 1711e97

Browse files
author
Jan Buethe
committed
fixed enable_binary_blob option for CWriter
1 parent 2056881 commit 1711e97

File tree

3 files changed

+40
-41
lines changed

3 files changed

+40
-41
lines changed

dnn/torch/lossgen/export_lossgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def c_export(args, model):
5252

5353
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
5454

55-
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
55+
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
5656
writer.header.write(
5757
f"""
5858
#include "opus_types.h"

dnn/torch/weight-exchange/wexchange/c_export/c_writer.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -120,50 +120,49 @@ def __init__(self,
120120
def _finalize_header(self):
121121

122122
# create model type
123-
if self.enable_binary_blob:
124-
if self.add_typedef:
125-
self.header.write(f"\ntypedef struct {{")
126-
else:
127-
self.header.write(f"\nstruct {self.model_struct_name} {{")
128-
for name, data in self.layer_dict.items():
129-
layer_type = data[0]
130-
self.header.write(f"\n {layer_type} {name};")
131-
if self.add_typedef:
132-
self.header.write(f"\n}} {self.model_struct_name};\n")
133-
else:
134-
self.header.write(f"\n}};\n")
135-
136-
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
137-
self.header.write(f"\n{init_prototype};\n")
123+
if self.add_typedef:
124+
self.header.write(f"\ntypedef struct {{")
125+
else:
126+
self.header.write(f"\nstruct {self.model_struct_name} {{")
127+
for name, data in self.layer_dict.items():
128+
layer_type = data[0]
129+
self.header.write(f"\n {layer_type} {name};")
130+
if self.add_typedef:
131+
self.header.write(f"\n}} {self.model_struct_name};\n")
132+
else:
133+
self.header.write(f"\n}};\n")
134+
135+
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
136+
self.header.write(f"\n{init_prototype};\n")
138137

139138
self.header.write(f"\n#endif /* {self.header_guard} */\n")
140139

141140
def _finalize_source(self):
142141

143-
if self.enable_binary_blob:
144-
# create weight array
145-
if len(set(self.weight_arrays)) != len(self.weight_arrays):
146-
raise ValueError("error: detected duplicates in weight arrays")
147-
self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
148-
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
149-
for name in self.weight_arrays:
150-
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
151-
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
152-
self.source.write(f"#endif\n")
153-
self.source.write(" {NULL, 0, 0, NULL}\n")
154-
self.source.write("};\n")
155-
156-
self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
157-
158-
# create init function definition
159-
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
160-
self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
161-
self.source.write(f"{init_prototype} {{\n")
162-
for name, data in self.layer_dict.items():
163-
self.source.write(f" if ({data[1]}) return 1;\n")
164-
self.source.write(" return 0;\n")
165-
self.source.write("}\n")
166-
self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
142+
143+
# create weight array
144+
if len(set(self.weight_arrays)) != len(self.weight_arrays):
145+
raise ValueError("error: detected duplicates in weight arrays")
146+
if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
147+
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
148+
for name in self.weight_arrays:
149+
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
150+
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
151+
self.source.write(f"#endif\n")
152+
self.source.write(" {NULL, 0, 0, NULL}\n")
153+
self.source.write("};\n")
154+
155+
if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
156+
157+
# create init function definition
158+
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
159+
if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
160+
self.source.write(f"{init_prototype} {{\n")
161+
for name, data in self.layer_dict.items():
162+
self.source.write(f" if ({data[1]}) return 1;\n")
163+
self.source.write(" return 0;\n")
164+
self.source.write("}\n")
165+
if self.enable_binary_blob:self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
167166

168167

169168
def close(self):

dnn/torch/weight-exchange/wexchange/c_export/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def print_vector(writer, vector, name, dtype='float', reshape_8x4=False, static=
5454
#ifndef USE_WEIGHTS_FILE
5555
'''
5656
)
57-
writer.weight_arrays.append(name)
57+
writer.weight_arrays.append(name)
5858

5959
if reshape_8x4:
6060
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))

0 commit comments

Comments
 (0)