Skip to content

Commit 61602f0

Browse files
Johannes Ballé authored and copybara-github committed
Extends TFCI format to models that support multiple RD tradeoffs with one transform.
PiperOrigin-RevId: 426021580 Change-Id: Ia7bff1d294ac417109215d2450e5c2c903fb2add
1 parent 0646fb1 commit 61602f0

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

models/tfci.py

Lines changed: 57 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -96,17 +96,34 @@ def instantiate_model_signature(model, signature, inputs=None, outputs=None):
9696
return wrapped_import.prune(inputs, outputs)
9797

9898

99-
def compress_image(model, input_image):
99+
def compress_image(model, input_image, rd_parameter=None):
100100
"""Compresses an image tensor into a bitstring."""
101101
sender = instantiate_model_signature(model, "sender")
102-
tensors = sender(input_image)
102+
if len(sender.inputs) == 1:
103+
if rd_parameter is not None:
104+
raise ValueError("This model doesn't expect an RD parameter.")
105+
tensors = sender(input_image)
106+
elif len(sender.inputs) == 2:
107+
if rd_parameter is None:
108+
raise ValueError("This model expects an RD parameter.")
109+
rd_parameter = tf.constant(rd_parameter, dtype=sender.inputs[1].dtype)
110+
tensors = sender(input_image, rd_parameter)
111+
# Find RD parameter and expand it to a 1D tensor so it fits into the
112+
# PackedTensors format.
113+
for i, t in enumerate(tensors):
114+
if t.dtype.is_floating and t.shape.rank == 0:
115+
tensors[i] = tf.expand_dims(t, 0)
116+
else:
117+
raise RuntimeError("Unexpected model signature.")
103118
packed = tfc.PackedTensors()
104119
packed.model = model
105120
packed.pack(tensors)
106121
return packed.string
107122

108123

109-
def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
124+
def compress(model, input_file, output_file,
125+
rd_parameter=None, rd_parameter_tolerance=None,
126+
target_bpp=None, bpp_strict=False):
110127
"""Compresses a PNG file to a TFCI file."""
111128
if not output_file:
112129
output_file = input_file + ".tfci"
@@ -117,21 +134,35 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
117134

118135
if not target_bpp:
119136
# Just compress with a specific model.
120-
bitstring = compress_image(model, input_image)
137+
bitstring = compress_image(model, input_image, rd_parameter=rd_parameter)
121138
else:
122139
# Get model list.
123140
models = load_cached(model + ".models")
124141
models = models.decode("ascii").split()
125142

126-
# Do a binary search over all RD points.
127-
lower = -1
128-
upper = len(models)
143+
try:
144+
lower, upper = [float(m) for m in models]
145+
use_rd_parameter = True
146+
except ValueError:
147+
lower = -1
148+
upper = len(models)
149+
use_rd_parameter = False
150+
151+
# Do a binary search over RD points.
129152
bpp = None
130153
best_bitstring = None
131154
best_bpp = None
132-
while bpp != target_bpp and upper - lower > 1:
133-
i = (upper + lower) // 2
134-
bitstring = compress_image(models[i], input_image)
155+
while bpp != target_bpp:
156+
if use_rd_parameter:
157+
if upper - lower <= rd_parameter_tolerance:
158+
break
159+
i = (upper + lower) / 2
160+
bitstring = compress_image(model, input_image, rd_parameter=i)
161+
else:
162+
if upper - lower < 2:
163+
break
164+
i = (upper + lower) // 2
165+
bitstring = compress_image(models[i], input_image)
135166
bpp = 8 * len(bitstring) / num_pixels
136167
is_admissible = bpp <= target_bpp or not bpp_strict
137168
is_better = (best_bpp is None or
@@ -162,6 +193,10 @@ def decompress(input_file, output_file):
162193
packed = tfc.PackedTensors(f.read())
163194
receiver = instantiate_model_signature(packed.model, "receiver")
164195
tensors = packed.unpack([t.dtype for t in receiver.inputs])
196+
# Find potential RD parameter and turn it back into a scalar.
197+
for i, t in enumerate(tensors):
198+
if t.dtype.is_floating and t.shape == (1,):
199+
tensors[i] = tf.squeeze(t, 0)
165200
output_image, = receiver(*tensors)
166201
write_png(output_file, output_image)
167202

@@ -247,7 +282,17 @@ def parse_args(argv):
247282
"'target_bpp' is provided, don't specify the index at the end of "
248283
"the model identifier.")
249284
compress_cmd.add_argument(
250-
"--target_bpp", type=float,
285+
"--rd_parameter", "-r", type=float,
286+
help="Rate-distortion parameter (for some models). Ignored if "
287+
"'target_bpp' is set.")
288+
compress_cmd.add_argument(
289+
"--rd_parameter_tolerance", type=float,
290+
default=2 ** -4,
291+
help="Tolerance for rate-distortion parameter. Only used if 'target_bpp' "
292+
"is set for some models, to determine when to stop the binary "
293+
"search.")
294+
compress_cmd.add_argument(
295+
"--target_bpp", "-b", type=float,
251296
help="Target bits per pixel. If provided, a binary search is used to try "
252297
"to match the given bpp as close as possible. In this case, don't "
253298
"specify the index at the end of the model identifier. It will be "
@@ -323,6 +368,7 @@ def main(args):
323368
# Invoke subcommand.
324369
if args.command == "compress":
325370
compress(args.model, args.input_file, args.output_file,
371+
args.rd_parameter, args.rd_parameter_tolerance,
326372
args.target_bpp, args.bpp_strict)
327373
elif args.command == "decompress":
328374
decompress(args.input_file, args.output_file)

tensorflow_compression/python/util/packed_tensors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def pack(self, tensors):
7171
raise RuntimeError(f"Unexpected tensor rank: {tensor.shape.rank}.")
7272
if tensor.dtype.is_integer:
7373
feature.int64_list.value[:] = tensor.numpy()
74+
elif tensor.dtype.is_floating:
75+
feature.float_list.value[:] = tensor.numpy()
7476
elif tensor.dtype == tf.string:
7577
feature.bytes_list.value[:] = tensor.numpy()
7678
else:
@@ -89,6 +91,8 @@ def unpack(self, dtypes):
8991
feature = self._example.features.feature[chr(i + 1)]
9092
if dtype.is_integer:
9193
tensors.append(tf.constant(feature.int64_list.value, dtype=dtype))
94+
elif dtype.is_floating:
95+
tensors.append(tf.constant(feature.float_list.value, dtype=dtype))
9296
elif dtype == tf.string:
9397
tensors.append(tf.constant(feature.bytes_list.value, dtype=dtype))
9498
else:

0 commit comments

Comments
 (0)