Skip to content

Commit 0e564fd

Browse files
committed
More fixes for the non-blob weight export
1 parent 1711e97 commit 0e564fd

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

dnn/torch/lossgen/export_lossgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def c_export(args, model):
5252

5353
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
5454

55-
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
55+
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False, add_typedef=True)
5656
writer.header.write(
5757
f"""
5858
#include "opus_types.h"

dnn/torch/weight-exchange/wexchange/c_export/common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def print_vector(writer, vector, name, dtype='float', reshape_8x4=False, static=
6464

6565
if debug_float:
6666
f.write('#ifndef DISABLE_DEBUG_FLOAT\n')
67-
if binary_blob:
68-
f.write(
67+
f.write(
6968
f'''
7069
#define WEIGHTS_{name}_DEFINED
7170
#define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{dtype_suffix[dtype]}
@@ -384,4 +383,4 @@ def print_tconv1d_layer(writer : CWriter,
384383
writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
385384
writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")
386385
writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n")
387-
writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")
386+
writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")

0 commit comments

Comments
 (0)