Skip to content

Commit 9a2e72e

Browse files
GooglerSung Jin Hwang
authored andcommitted
Internal change
PiperOrigin-RevId: 256796501 Change-Id: Ide683b4eb0bf82e73b9fd2decf89f90993b59c5c
1 parent 504f159 commit 9a2e72e

File tree

7 files changed

+206
-117
lines changed

7 files changed

+206
-117
lines changed

tensorflow_compression/cc/kernels/range_coding_helper_kernels.cc

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -232,37 +232,5 @@ class ArrayFingerprintOp : public tensorflow::OpKernel {
232232

233233
REGISTER_KERNEL_BUILDER(Name("ArrayFingerprint").Device(tensorflow::DEVICE_CPU),
234234
ArrayFingerprintOp);
235-
236-
class CheckArrayFingerprintOp : public tensorflow::OpKernel {
237-
public:
238-
using OpKernel::OpKernel;
239-
240-
void Compute(tensorflow::OpKernelContext* context) override {
241-
const Tensor& input = context->input(0);
242-
const Tensor& fingerprint = context->input(1);
243-
OP_REQUIRES(context, tensorflow::DataTypeCanUseMemcpy(input.dtype()),
244-
InvalidArgument("Data type not supported: ",
245-
tensorflow::DataTypeString(input.dtype())));
246-
OP_REQUIRES(context, TensorShapeUtils::IsScalar(fingerprint.shape()),
247-
InvalidArgument("`fingerprint` should be a scalar"));
248-
249-
const int64 size =
250-
input.shape().num_elements() * tensorflow::DataTypeSize(input.dtype());
251-
auto data = input.bit_casted_shaped<char, 1>({size});
252-
253-
OP_REQUIRES(context,
254-
static_cast<uint64>(fingerprint.scalar<int64>()()) ==
255-
::util::Fingerprint64(data.data(),
256-
static_cast<size_t>(data.size())),
257-
tensorflow::errors::DataLoss("Fingerprint mismatch"));
258-
259-
context->set_output(0, input);
260-
}
261-
};
262-
263-
REGISTER_KERNEL_BUILDER(
264-
Name("CheckArrayFingerprint").Device(tensorflow::DEVICE_CPU),
265-
CheckArrayFingerprintOp);
266-
267235
} // namespace
268236
} // namespace tensorflow_compression

tensorflow_compression/cc/kernels/range_coding_helper_kernels_test.cc

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -161,19 +161,6 @@ class FingerprintOpTest : public tensorflow::OpsTestBase {
161161
inputs_.clear();
162162
inputs_.emplace_back(tensor);
163163
}
164-
165-
void MakeCheckFingerprintOp(Tensor* tensor, Tensor* fingerprint) {
166-
TF_ASSERT_OK(
167-
tensorflow::NodeDefBuilder("check_fingerprint", "CheckArrayFingerprint")
168-
.Input(tensorflow::FakeInput(tensor->dtype()))
169-
.Input(tensorflow::FakeInput(fingerprint->dtype()))
170-
.Finalize(node_def()));
171-
TF_ASSERT_OK(InitOp());
172-
173-
inputs_.clear();
174-
inputs_.emplace_back(tensor);
175-
inputs_.emplace_back(fingerprint);
176-
}
177164
};
178165

179166
TEST_F(FingerprintOpTest, Verify) {
@@ -196,18 +183,17 @@ TEST_F(FingerprintOpTest, Verify) {
196183

197184
MakeFingerprintOp(&tensor);
198185
TF_ASSERT_OK(RunOpKernel());
199-
200-
Tensor fingerprint = *GetOutput(0);
201-
202-
MakeCheckFingerprintOp(&tensor, &fingerprint);
203-
TF_ASSERT_OK(RunOpKernel());
186+
Tensor fingerprint0 = *GetOutput(0);
204187

205188
// Change one byte in the buffer.
206189
const int64 pos = rand.Uniform(length);
207190
buffer(pos) = ~buffer(pos);
208191

209-
MakeCheckFingerprintOp(&tensor, &fingerprint);
210-
ASSERT_FALSE(RunOpKernel().ok());
192+
MakeFingerprintOp(&tensor);
193+
TF_ASSERT_OK(RunOpKernel());
194+
Tensor fingerprint1 = *GetOutput(0);
195+
196+
EXPECT_NE(fingerprint0.tensor_data(), fingerprint1.tensor_data());
211197
}
212198
}
213199

@@ -220,17 +206,6 @@ TEST_F(FingerprintOpTest, FingerprintShapeFn) {
220206
INFER_OK(op, "[1,2]", "[]");
221207
INFER_OK(op, "[1,2,3]", "[]");
222208
}
223-
224-
TEST_F(FingerprintOpTest, CheckFingerprintShapeFn) {
225-
tensorflow::ShapeInferenceTestOp op("CheckArrayFingerprint");
226-
227-
INFER_OK(op, "?;?", "in0");
228-
INFER_OK(op, "[];?", "in0");
229-
INFER_OK(op, "[1,2];?", "in0");
230-
INFER_OK(op, "[1,2,3];?", "in0");
231-
INFER_ERROR("rank 0", op, "?;[1]");
232-
}
233-
234209
} // namespace
235210
} // namespace tensorflow_compression
236211

tensorflow_compression/cc/ops/range_coding_ops.cc

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ using tensorflow::shape_inference::DimensionHandle;
2626
using tensorflow::shape_inference::InferenceContext;
2727
using tensorflow::shape_inference::ShapeHandle;
2828

29-
// clang-format off
3029
REGISTER_OP("RangeEncode")
3130
.Input("data: int16")
3231
.Input("cdf: int32")
@@ -96,7 +95,7 @@ REGISTER_OP("RangeDecode")
9695
.Output("decoded: int16")
9796
.Attr("precision: int >= 1")
9897
.Attr("debug_level: int = 1")
99-
.SetShapeFn([] (InferenceContext* c) {
98+
.SetShapeFn([](InferenceContext* c) {
10099
ShapeHandle out;
101100
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
102101
c->set_output(0, out);
@@ -133,7 +132,7 @@ REGISTER_OP("UnboundedIndexRangeEncode")
133132
.Attr("precision: int >= 1")
134133
.Attr("overflow_width: int >= 1")
135134
.Attr("debug_level: int = 1")
136-
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
135+
.SetShapeFn(shape_inference::ScalarShape)
137136
.Doc(R"doc(
138137
Range encodes unbounded integer `data` using an indexed probability table.
139138
@@ -211,7 +210,7 @@ REGISTER_OP("UnboundedIndexRangeDecode")
211210
.Attr("precision: int >= 1")
212211
.Attr("overflow_width: int >= 1")
213212
.Attr("debug_level: int = 1")
214-
.SetShapeFn([] (InferenceContext* c) {
213+
.SetShapeFn([](InferenceContext* c) {
215214
c->set_output(0, c->input(1));
216215
return Status::OK();
217216
})
@@ -246,7 +245,7 @@ REGISTER_OP("PmfToQuantizedCdf")
246245
.Input("pmf: float")
247246
.Output("cdf: int32")
248247
.Attr("precision: int >= 1")
249-
.SetShapeFn([] (InferenceContext* c) {
248+
.SetShapeFn([](InferenceContext* c) {
250249
ShapeHandle in;
251250
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
252251
DimensionHandle last;
@@ -275,34 +274,12 @@ REGISTER_OP("ArrayFingerprint")
275274
.Input("input: T")
276275
.Output("fingerprint: int64")
277276
.Attr("T: realnumbertype")
278-
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
277+
.SetShapeFn(shape_inference::ScalarShape)
279278
.Doc(R"doc(
280279
Produces fingerprint of the input data.
281280
282281
input: Tensor to be fingerprinted.
283282
fingerprint: Fingerprint value of input.
284283
)doc");
285-
286-
REGISTER_OP("CheckArrayFingerprint")
287-
.Input("input: T")
288-
.Input("fingerprint: int64")
289-
.Output("output: T")
290-
.Attr("T: realnumbertype")
291-
.SetShapeFn([](InferenceContext* c) {
292-
ShapeHandle unused;
293-
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
294-
c->set_output(0, c->input(0));
295-
return tensorflow::Status::OK();
296-
})
297-
.Doc(R"doc(
298-
Computes the fingerprint of `input` and checks the computed value against
299-
`fingerprint`. If the check fails, then this op returns an error status.
300-
301-
input: Tensor to be fingerprinted and checked.
302-
fingerprint: Fingerprint value to be checked against.
303-
output: The same as input.
304-
)doc");
305-
// clang-format on
306-
307284
} // namespace
308285
} // namespace tensorflow_compression

tensorflow_compression/python/layers/entropy_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,10 @@ def cdf_initializer(shape, dtype=None, partition_info=None):
633633
# here, or the variable will return the wrong dynamic shape later. A
634634
# placeholder with default gets the trick done (see initializer above).
635635
quantized_cdf = self.add_variable(
636-
"quantized_cdf", shape=None, dtype=tf.int32, trainable=False,
636+
"quantized_cdf",
637+
shape=(channels, None),
638+
dtype=tf.int32,
639+
trainable=False,
637640
initializer=cdf_initializer)
638641
cdf_length = self.add_variable(
639642
"cdf_length", shape=(channels,), dtype=tf.int32, trainable=False,

tensorflow_compression/python/layers/parameterizers.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,30 @@ class Parameterizer(object):
4949

5050

5151
class StaticParameterizer(Parameterizer):
52-
"""A parameterizer that always returns a constant tensor.
52+
"""A parameterizer that returns a non-variable.
5353
54-
No variables are created, hence the parameter never changes.
54+
No variables are created, and `getter` is ignored. If `value` is a `Tensor`,
55+
the parameter can depend on some other computation. Otherwise, it never
56+
changes.
5557
5658
Args:
57-
initializer: An initializer object which will be called to produce the
58-
static parameter.
59+
value: Either a constant or `Tensor` value, or a callable which returns such
60+
a thing given a shape and dtype argument (for example, an initializer).
5961
"""
6062

61-
def __init__(self, initializer):
62-
self.initializer = initializer
63+
def __init__(self, value):
64+
self.value = value
6365

6466
def __call__(self, getter, name, shape, dtype, initializer, regularizer=None):
65-
del getter, name, initializer, regularizer # unused
66-
return self.initializer(shape, dtype)
67+
del getter, name, initializer # unused
68+
if regularizer is not None:
69+
raise NotImplementedError("Regularizers are not currently supported for "
70+
"static parameterizers.")
71+
if callable(self.value):
72+
# Treat value as initializer.
73+
return self.value(shape, dtype=dtype)
74+
else:
75+
return self.value
6776

6877

6978
class RDFTParameterizer(Parameterizer):
@@ -92,35 +101,41 @@ def __call__(self, getter, name, shape, dtype, initializer, regularizer=None):
92101
size = var_shape[0]
93102
for s in var_shape[1:-2]:
94103
size *= s
95-
irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
96104
if self.dc:
97105
rdft_shape = (size, var_shape[-2] * var_shape[-1])
98106
else:
99-
irdft_matrix = irdft_matrix[:, 1:]
100107
rdft_shape = (size - 1, var_shape[-2] * var_shape[-1])
101108
rdft_dtype = var_dtype
102109
rdft_name = name + "_rdft"
103110

104111
def rdft_initializer(shape, dtype=None, partition_info=None):
112+
"""Initializer wrapper."""
105113
assert tuple(shape) == rdft_shape, shape
106114
assert dtype == rdft_dtype, dtype
107115
init = initializer(
108116
var_shape, dtype=var_dtype, partition_info=partition_info)
109117
init = tf.reshape(init, (-1, rdft_shape[-1]))
118+
irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
119+
if not self.dc:
120+
irdft_matrix = irdft_matrix[:, 1:]
110121
init = tf.linalg.matmul(irdft_matrix, init, transpose_a=True)
111122
return init
112123

113124
def reparam(rdft):
125+
irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype)
126+
if not self.dc:
127+
irdft_matrix = irdft_matrix[:, 1:]
114128
var = tf.linalg.matmul(irdft_matrix, rdft)
115129
var = tf.reshape(var, var_shape)
116130
return var
117131

132+
reparam_regularizer = None
118133
if regularizer is not None:
119-
regularizer = lambda rdft: regularizer(reparam(rdft))
134+
reparam_regularizer = lambda rdft: regularizer(reparam(rdft))
120135

121136
rdft = getter(
122137
name=rdft_name, shape=rdft_shape, dtype=rdft_dtype,
123-
initializer=rdft_initializer, regularizer=regularizer)
138+
initializer=rdft_initializer, regularizer=reparam_regularizer)
124139
return reparam(rdft)
125140

126141

@@ -165,10 +180,11 @@ def reparam(var):
165180
var = tf.math.square(var) - pedestal
166181
return var
167182

183+
reparam_regularizer = None
168184
if regularizer is not None:
169-
regularizer = lambda var: regularizer(reparam(var))
185+
reparam_regularizer = lambda var: regularizer(reparam(var))
170186

171187
var = getter(
172188
name=reparam_name, shape=shape, dtype=dtype,
173-
initializer=reparam_initializer, regularizer=regularizer)
189+
initializer=reparam_initializer, regularizer=reparam_regularizer)
174190
return reparam(var)

0 commit comments

Comments
 (0)