Skip to content

Commit 5e98dfd

Browse files
author
jballe
committed
Creates fingerprint op.
PiperOrigin-RevId: 242570134
1 parent 0a0ec2f commit 5e98dfd

File tree

3 files changed

+178
-9
lines changed

3 files changed

+178
-9
lines changed

cc/kernels/range_coding_helper_kernels.cc

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/lib/core/errors.h"
2929
#include "tensorflow/core/lib/core/threadpool.h"
3030
#include "tensorflow/core/lib/gtl/array_slice.h"
31+
#include "tensorflow/core/platform/fingerprint.h"
3132
#include "tensorflow/core/platform/logging.h"
3233
#include "tensorflow/core/platform/macros.h"
3334
#include "tensorflow/core/platform/types.h"
@@ -37,19 +38,20 @@ namespace {
3738
namespace gtl = tensorflow::gtl;
3839
namespace thread = tensorflow::thread;
3940
using tensorflow::DEVICE_CPU;
41+
using tensorflow::Fingerprint64;
42+
using tensorflow::int32;
43+
using tensorflow::int64;
4044
using tensorflow::OpKernel;
4145
using tensorflow::OpKernelConstruction;
4246
using tensorflow::OpKernelContext;
47+
using tensorflow::string;
4348
using tensorflow::Tensor;
4449
using tensorflow::TensorShape;
4550
using tensorflow::TensorShapeUtils;
46-
using tensorflow::errors::InvalidArgument;
47-
using tensorflow::int32;
48-
using tensorflow::int64;
49-
using tensorflow::string;
50-
using tensorflow::uint8;
5151
using tensorflow::uint32;
5252
using tensorflow::uint64;
53+
using tensorflow::uint8;
54+
using tensorflow::errors::InvalidArgument;
5355

5456
class PmfToCdfOp : public OpKernel {
5557
public:
@@ -208,5 +210,62 @@ class PmfToCdfOp : public OpKernel {
208210
REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU),
209211
PmfToCdfOp);
210212

213+
class ArrayFingerprintOp : public tensorflow::OpKernel {
214+
public:
215+
using OpKernel::OpKernel;
216+
217+
void Compute(tensorflow::OpKernelContext* context) override {
218+
const Tensor& input = context->input(0);
219+
OP_REQUIRES(context, tensorflow::DataTypeCanUseMemcpy(input.dtype()),
220+
InvalidArgument("Data type not supported: ",
221+
tensorflow::DataTypeString(input.dtype())));
222+
223+
const int64 size =
224+
input.shape().num_elements() * tensorflow::DataTypeSize(input.dtype());
225+
auto data = input.bit_casted_shaped<char, 1>({size});
226+
227+
Tensor* output;
228+
OP_REQUIRES_OK(context,
229+
context->allocate_output(0, TensorShape{}, &output));
230+
231+
output->scalar<int64>()() =
232+
Fingerprint64({data.data(), static_cast<size_t>(data.size())});
233+
}
234+
};
235+
236+
REGISTER_KERNEL_BUILDER(Name("ArrayFingerprint").Device(tensorflow::DEVICE_CPU),
237+
ArrayFingerprintOp);
238+
239+
class CheckArrayFingerprintOp : public tensorflow::OpKernel {
240+
public:
241+
using OpKernel::OpKernel;
242+
243+
void Compute(tensorflow::OpKernelContext* context) override {
244+
const Tensor& input = context->input(0);
245+
const Tensor& fingerprint = context->input(1);
246+
OP_REQUIRES(context, tensorflow::DataTypeCanUseMemcpy(input.dtype()),
247+
InvalidArgument("Data type not supported: ",
248+
tensorflow::DataTypeString(input.dtype())));
249+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(fingerprint.shape()),
250+
InvalidArgument("`fingerprint` should be a scalar"));
251+
252+
const int64 size =
253+
input.shape().num_elements() * tensorflow::DataTypeSize(input.dtype());
254+
auto data = input.bit_casted_shaped<char, 1>({size});
255+
256+
OP_REQUIRES(
257+
context,
258+
fingerprint.scalar<int64>()() ==
259+
Fingerprint64({data.data(), static_cast<size_t>(data.size())}),
260+
tensorflow::errors::DataLoss("Fingerprint mismatch"));
261+
262+
context->set_output(0, input);
263+
}
264+
};
265+
266+
REGISTER_KERNEL_BUILDER(
267+
Name("CheckArrayFingerprint").Device(tensorflow::DEVICE_CPU),
268+
CheckArrayFingerprintOp);
269+
211270
} // namespace
212271
} // namespace tensorflow_compression

cc/kernels/range_coding_helper_kernels_test.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ using tensorflow::NodeDefBuilder;
4040
using tensorflow::OpsTestBase;
4141
using tensorflow::ShapeInferenceTestOp;
4242
using tensorflow::Tensor;
43+
using tensorflow::TensorShape;
4344
using tensorflow::TTypes;
4445

4546
class PmfToQuantizedCdfOpTest : public OpsTestBase {
@@ -151,6 +152,87 @@ TEST_F(PmfToQuantizedCdfOpTest, ShapeFn) {
151152
INFER_OK(op, "[3,4,5]", "[d0_0,d0_1,6]");
152153
}
153154

155+
class FingerprintOpTest : public tensorflow::OpsTestBase {
156+
protected:
157+
void MakeFingerprintOp(Tensor* tensor) {
158+
TF_ASSERT_OK(tensorflow::NodeDefBuilder("fingerprint", "ArrayFingerprint")
159+
.Input(tensorflow::FakeInput(tensor->dtype()))
160+
.Finalize(node_def()));
161+
TF_ASSERT_OK(InitOp());
162+
163+
inputs_.clear();
164+
inputs_.emplace_back(tensor);
165+
}
166+
167+
void MakeCheckFingerprintOp(Tensor* tensor, Tensor* fingerprint) {
168+
TF_ASSERT_OK(
169+
tensorflow::NodeDefBuilder("check_fingerprint", "CheckArrayFingerprint")
170+
.Input(tensorflow::FakeInput(tensor->dtype()))
171+
.Input(tensorflow::FakeInput(fingerprint->dtype()))
172+
.Finalize(node_def()));
173+
TF_ASSERT_OK(InitOp());
174+
175+
inputs_.clear();
176+
inputs_.emplace_back(tensor);
177+
inputs_.emplace_back(fingerprint);
178+
}
179+
};
180+
181+
TEST_F(FingerprintOpTest, Verify) {
182+
std::random_device rd;
183+
random::PhiloxRandom gen(rd(), rd());
184+
random::SimplePhilox rand(&gen);
185+
for (tensorflow::DataType dtype : tensorflow::kRealNumberTypes) {
186+
const int rank = rand.Uniform(4);
187+
188+
TensorShape shape;
189+
for (int i = 0; i < rank; ++i) {
190+
shape.AddDim(rand.Uniform(9) + 1);
191+
}
192+
193+
Tensor tensor(dtype, shape);
194+
195+
const int64 length = shape.num_elements() * tensorflow::DataTypeSize(dtype);
196+
auto buffer = tensor.bit_casted_shaped<char, 1>({length});
197+
buffer.setRandom();
198+
199+
MakeFingerprintOp(&tensor);
200+
TF_ASSERT_OK(RunOpKernel());
201+
202+
Tensor fingerprint = *GetOutput(0);
203+
204+
MakeCheckFingerprintOp(&tensor, &fingerprint);
205+
TF_ASSERT_OK(RunOpKernel());
206+
207+
// Change one byte in the buffer.
208+
const int64 pos = rand.Uniform(length);
209+
buffer(pos) = ~buffer(pos);
210+
211+
MakeCheckFingerprintOp(&tensor, &fingerprint);
212+
ASSERT_FALSE(RunOpKernel().ok());
213+
}
214+
}
215+
216+
TEST_F(FingerprintOpTest, FingerprintShapeFn) {
217+
tensorflow::ShapeInferenceTestOp op("ArrayFingerprint");
218+
219+
INFER_OK(op, "?", "[]");
220+
INFER_OK(op, "[]", "[]");
221+
INFER_OK(op, "[1]", "[]");
222+
INFER_OK(op, "[1,2]", "[]");
223+
INFER_OK(op, "[1,2,3]", "[]");
224+
}
225+
226+
TEST_F(FingerprintOpTest, CheckFingerprintShapeFn) {
227+
tensorflow::ShapeInferenceTestOp op("CheckArrayFingerprint");
228+
229+
INFER_OK(op, "?;?", "in0");
230+
INFER_OK(op, "[];?", "in0");
231+
INFER_OK(op, "[1,2];?", "in0");
232+
INFER_OK(op, "[1,2,3];?", "in0");
233+
INFER_ERROR("rank 0", op, "?;[1]");
234+
}
235+
154236
} // namespace
155237
} // namespace tensorflow_compression
156238

cc/ops/range_coding_ops.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ encoded: A range-coded scalar string.
8787
precision: The number of bits for probability quantization. Must be <= 16.
8888
)doc");
8989

90-
9190
REGISTER_OP("RangeDecode")
9291
.Input("encoded: string")
9392
.Input("shape: int32")
@@ -120,7 +119,6 @@ precision: The number of bits for probability quantization. Must be <= 16, and
120119
must match the precision used by RangeEncode that produced `encoded`.
121120
)doc");
122121

123-
124122
REGISTER_OP("UnboundedIndexRangeEncode")
125123
.Input("data: int32")
126124
.Input("index: int32")
@@ -198,7 +196,6 @@ overflow_width: The bit width of the variable-length overflow code. Must be <=
198196
precision.
199197
)doc");
200198

201-
202199
REGISTER_OP("UnboundedIndexRangeDecode")
203200
.Input("encoded: string")
204201
.Input("index: int32")
@@ -239,7 +236,6 @@ overflow_width: The bit width of the variable-length overflow code. Must be <=
239236
produced `encoded`.
240237
)doc");
241238

242-
243239
REGISTER_OP("PmfToQuantizedCdf")
244240
.Input("pmf: float")
245241
.Output("cdf: int32")
@@ -268,6 +264,38 @@ Note that the input PMF is pre-quantization. The input PMF is not normalized
268264
by this op prior to quantization. Therefore the user is responsible for
269265
normalizing PMF if necessary.
270266
)doc");
267+
268+
REGISTER_OP("ArrayFingerprint")
269+
.Input("input: T")
270+
.Output("fingerprint: int64")
271+
.Attr("T: realnumbertype")
272+
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
273+
.Doc(R"doc(
274+
Produces fingerprint of the input data.
275+
276+
input: Tensor to be fingerprinted.
277+
fingerprint: Fingerprint value of input.
278+
)doc");
279+
280+
REGISTER_OP("CheckArrayFingerprint")
281+
.Input("input: T")
282+
.Input("fingerprint: int64")
283+
.Output("output: T")
284+
.Attr("T: realnumbertype")
285+
.SetShapeFn([](InferenceContext* c) {
286+
ShapeHandle unused;
287+
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
288+
c->set_output(0, c->input(0));
289+
return tensorflow::Status::OK();
290+
})
291+
.Doc(R"doc(
292+
Computes the fingerprint of `input` and checks the computed value against
293+
`fingerprint`. If the check fails, then this op returns an error status.
294+
295+
input: Tensor to be fingerprinted and checked.
296+
fingerprint: Fingerprint value to be checked against.
297+
output: The same as input.
298+
)doc");
271299
// clang-format on
272300

273301
} // namespace

0 commit comments

Comments
 (0)