@@ -96,17 +96,34 @@ def instantiate_model_signature(model, signature, inputs=None, outputs=None):
96
96
return wrapped_import .prune (inputs , outputs )
97
97
98
98
99
- def compress_image (model , input_image ):
99
+ def compress_image (model , input_image , rd_parameter = None ):
100
100
"""Compresses an image tensor into a bitstring."""
101
101
sender = instantiate_model_signature (model , "sender" )
102
- tensors = sender (input_image )
102
+ if len (sender .inputs ) == 1 :
103
+ if rd_parameter is not None :
104
+ raise ValueError ("This model doesn't expect an RD parameter." )
105
+ tensors = sender (input_image )
106
+ elif len (sender .inputs ) == 2 :
107
+ if rd_parameter is None :
108
+ raise ValueError ("This model expects an RD parameter." )
109
+ rd_parameter = tf .constant (rd_parameter , dtype = sender .inputs [1 ].dtype )
110
+ tensors = sender (input_image , rd_parameter )
111
+ # Find RD parameter and expand it to a 1D tensor so it fits into the
112
+ # PackedTensors format.
113
+ for i , t in enumerate (tensors ):
114
+ if t .dtype .is_floating and t .shape .rank == 0 :
115
+ tensors [i ] = tf .expand_dims (t , 0 )
116
+ else :
117
+ raise RuntimeError ("Unexpected model signature." )
103
118
packed = tfc .PackedTensors ()
104
119
packed .model = model
105
120
packed .pack (tensors )
106
121
return packed .string
107
122
108
123
109
- def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
124
+ def compress (model , input_file , output_file ,
125
+ rd_parameter = None , rd_parameter_tolerance = None ,
126
+ target_bpp = None , bpp_strict = False ):
110
127
"""Compresses a PNG file to a TFCI file."""
111
128
if not output_file :
112
129
output_file = input_file + ".tfci"
@@ -117,21 +134,35 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
117
134
118
135
if not target_bpp :
119
136
# Just compress with a specific model.
120
- bitstring = compress_image (model , input_image )
137
+ bitstring = compress_image (model , input_image , rd_parameter = rd_parameter )
121
138
else :
122
139
# Get model list.
123
140
models = load_cached (model + ".models" )
124
141
models = models .decode ("ascii" ).split ()
125
142
126
- # Do a binary search over all RD points.
127
- lower = - 1
128
- upper = len (models )
143
+ try :
144
+ lower , upper = [float (m ) for m in models ]
145
+ use_rd_parameter = True
146
+ except ValueError :
147
+ lower = - 1
148
+ upper = len (models )
149
+ use_rd_parameter = False
150
+
151
+ # Do a binary search over RD points.
129
152
bpp = None
130
153
best_bitstring = None
131
154
best_bpp = None
132
- while bpp != target_bpp and upper - lower > 1 :
133
- i = (upper + lower ) // 2
134
- bitstring = compress_image (models [i ], input_image )
155
+ while bpp != target_bpp :
156
+ if use_rd_parameter :
157
+ if upper - lower <= rd_parameter_tolerance :
158
+ break
159
+ i = (upper + lower ) / 2
160
+ bitstring = compress_image (model , input_image , rd_parameter = i )
161
+ else :
162
+ if upper - lower < 2 :
163
+ break
164
+ i = (upper + lower ) // 2
165
+ bitstring = compress_image (models [i ], input_image )
135
166
bpp = 8 * len (bitstring ) / num_pixels
136
167
is_admissible = bpp <= target_bpp or not bpp_strict
137
168
is_better = (best_bpp is None or
@@ -162,6 +193,10 @@ def decompress(input_file, output_file):
162
193
packed = tfc .PackedTensors (f .read ())
163
194
receiver = instantiate_model_signature (packed .model , "receiver" )
164
195
tensors = packed .unpack ([t .dtype for t in receiver .inputs ])
196
+ # Find potential RD parameter and turn it back into a scalar.
197
+ for i , t in enumerate (tensors ):
198
+ if t .dtype .is_floating and t .shape == (1 ,):
199
+ tensors [i ] = tf .squeeze (t , 0 )
165
200
output_image , = receiver (* tensors )
166
201
write_png (output_file , output_image )
167
202
@@ -247,7 +282,17 @@ def parse_args(argv):
247
282
"'target_bpp' is provided, don't specify the index at the end of "
248
283
"the model identifier." )
249
284
compress_cmd .add_argument (
250
- "--target_bpp" , type = float ,
285
+ "--rd_parameter" , "-r" , type = float ,
286
+ help = "Rate-distortion parameter (for some models). Ignored if "
287
+ "'target_bpp' is set." )
288
+ compress_cmd .add_argument (
289
+ "--rd_parameter_tolerance" , type = float ,
290
+ default = 2 ** - 4 ,
291
+ help = "Tolerance for rate-distortion parameter. Only used if 'target_bpp' "
292
+ "is set for some models, to determine when to stop the binary "
293
+ "search." )
294
+ compress_cmd .add_argument (
295
+ "--target_bpp" , "-b" , type = float ,
251
296
help = "Target bits per pixel. If provided, a binary search is used to try "
252
297
"to match the given bpp as close as possible. In this case, don't "
253
298
"specify the index at the end of the model identifier. It will be "
@@ -323,6 +368,7 @@ def main(args):
323
368
# Invoke subcommand.
324
369
if args .command == "compress" :
325
370
compress (args .model , args .input_file , args .output_file ,
371
+ args .rd_parameter , args .rd_parameter_tolerance ,
326
372
args .target_bpp , args .bpp_strict )
327
373
elif args .command == "decompress" :
328
374
decompress (args .input_file , args .output_file )
0 commit comments