Skip to content

Commit 99864d3

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds op for stochastic rounding.
PiperOrigin-RevId: 492303397 Change-Id: I3d98126948c75e62ca8aed2b85947f530ae03953
1 parent f1afde6 commit 99864d3

File tree

5 files changed

+243
-0
lines changed

5 files changed

+243
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/* Copyright 2022 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include <chrono>
#include <cmath>
#include <cstdint>
#include <random>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/status.h"
23+
24+
namespace tensorflow_compression {
25+
namespace {
26+
namespace errors = tensorflow::errors;
27+
using tensorflow::DEVICE_CPU;
28+
using tensorflow::OpKernel;
29+
using tensorflow::OpKernelConstruction;
30+
using tensorflow::OpKernelContext;
31+
using tensorflow::Tensor;
32+
33+
// xoshiro256+ pseudo-random number generator, adapted from
// https://prng.di.unimi.it/xoshiro256plus.c
//
// Args:
//   state: Pointer to four 64-bit words of generator state. All four words are
//     read and updated in place. An all-zero state is a fixed point of the
//     update (every output would be 0), so callers must seed it with at least
//     one nonzero bit.
//
// Returns:
//   The next 64-bit output, i.e. the sum of state words 0 and 3 as they were
//   before this call advanced the state.
inline uint64_t next_random(uint64_t* state) {
  const uint64_t result = state[0] + state[3];
  const uint64_t t = state[1] << 17;
  state[2] ^= state[0];
  state[3] ^= state[1];
  state[1] ^= state[2];
  state[0] ^= state[3];
  state[2] ^= t;
  // Rotate state[3] left by 45 bits.
  state[3] = (state[3] << 45) | (state[3] >> (64 - 45));
  return result;
}
46+
47+
template <typename T>
48+
class StochasticRoundOp : public OpKernel {
49+
public:
50+
explicit StochasticRoundOp(OpKernelConstruction* context)
51+
: OpKernel(context) {}
52+
53+
void Compute(OpKernelContext* context) override {
54+
const Tensor& inputs_tensor = context->input(0);
55+
auto inputs = inputs_tensor.flat<T>();
56+
57+
OP_REQUIRES(context, context->input(1).dims() == 0,
58+
errors::InvalidArgument("step_size must be a scalar."));
59+
const float step_size = context->input(1).scalar<float>()();
60+
61+
auto seed = context->input(2).flat<int32_t>();
62+
63+
Tensor* outputs_tensor;
64+
OP_REQUIRES_OK(context, context->allocate_output(0, inputs_tensor.shape(),
65+
&outputs_tensor));
66+
auto outputs = outputs_tensor->flat<int32_t>();
67+
68+
uint64_t random_state[4];
69+
70+
if (seed.size()) {
71+
std::seed_seq seq(seed.data(), seed.data() + seed.size());
72+
seq.generate(reinterpret_cast<uint32_t*>(random_state),
73+
reinterpret_cast<uint32_t*>(random_state + 4));
74+
} else {
75+
// Seed the random state from system clock, in a best-effort fashion.
76+
uint64_t seed =
77+
std::chrono::high_resolution_clock::now().time_since_epoch().count();
78+
std::seed_seq seq{seed, seed >> 32};
79+
seq.generate(reinterpret_cast<uint32_t*>(random_state),
80+
reinterpret_cast<uint32_t*>(random_state + 4));
81+
}
82+
83+
for (int64_t i = 0; i < inputs.size(); ++i) {
84+
// Promote 16-bit types to 32 bit.
85+
float number = static_cast<float>(inputs(i)) / step_size;
86+
float integral = std::floor(number);
87+
outputs(i) = integral;
88+
// Regardless of T, comparing in float32 is accurate enough here.
89+
float fractional = number - integral;
90+
float random =
91+
(next_random(random_state) >> 40) * 0x1.0p-24f; // from [0, 1)
92+
if (random < fractional) {
93+
++outputs(i);
94+
}
95+
}
96+
}
97+
};
98+
99+
// Registers the CPU kernel once per element type accepted by the op's "T"
// attribute (bfloat16, float16, float32).
#define TFC_REGISTER_STOCHASTIC_ROUND_CPU(type)                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StochasticRound").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      StochasticRoundOp<type>)

TFC_REGISTER_STOCHASTIC_ROUND_CPU(tensorflow::bfloat16);
TFC_REGISTER_STOCHASTIC_ROUND_CPU(Eigen::half);
TFC_REGISTER_STOCHASTIC_ROUND_CPU(float);

#undef TFC_REGISTER_STOCHASTIC_ROUND_CPU
109+
110+
} // namespace
111+
} // namespace tensorflow_compression
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2022 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/core/framework/common_shape_fns.h"
16+
#include "tensorflow/core/framework/op.h"
17+
18+
namespace tensorflow_compression {
19+
namespace {
20+
21+
// Op definition for stochastic rounding. The output shape is taken unchanged
// from the first input (`inputs`).
// NOTE(review): with an empty `seed` the kernel seeds from the system clock, so
// results are nondeterministic; consider whether the op should be marked
// stateful so that graph optimizations do not merge or fold calls -- confirm.
REGISTER_OP("StochasticRound")
    .Attr("T: {bfloat16, float16, float32}")
    .Input("inputs: T")
    .Input("step_size: float32")
    .Input("seed: int32")
    .Output("outputs: int32")
    .SetShapeFn(tensorflow::shape_inference::UnchangedShape)
    .Doc(R"doc(
Rounds `inputs / step_size` stochastically.

This op computes the elementwise function:

output = {
  floor(x)      with prob. 1 - p
  floor(x) + 1  with prob. p
}
where x = inputs / step_size and p = x - floor(x).

inputs: Floating point tensor to be rounded.
step_size: Scalar tensor. Step size for rounding.
seed: Arbitrary shape tensor. Seed for random number generator. If it has no
  elements, seeding is attempted from system time.
outputs: Integer tensor of same shape as `inputs`, containing rounded values.
)doc");
45+
46+
} // namespace
47+
} // namespace tensorflow_compression

tensorflow_compression/python/ops/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ py_test(
6363
deps = [":round_ops"],
6464
)
6565

66+
# Runs quantization_ops_test.py against the ops exposed through :gen_ops.
py_test(
    name = "quantization_ops_test",
    srcs = ["quantization_ops_test.py"],
    deps = [":gen_ops"],
)
71+
6672
filegroup(
6773
name = "py_src",
6874
srcs = glob(["*.py"]),

tensorflow_compression/python/ops/gen_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,6 @@
3434
"pmf_to_quantized_cdf",
3535
"run_length_gamma_decode",
3636
"run_length_gamma_encode",
37+
"stochastic_round",
3738
]
3839
# pylint:enable=undefined-all-variable
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2022 Google LLC. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Quantization tests."""
16+
17+
import time
18+
from absl.testing import parameterized
19+
import tensorflow as tf
20+
from tensorflow_compression.python.ops import gen_ops
21+
22+
23+
class QuantizationOpsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the `stochastic_round` op exposed through `gen_ops`."""

  @parameterized.parameters(tf.bfloat16, tf.float16, tf.float32)
  def test_difference_is_at_most_one(self, dtype):
    """Rounded values never differ from the input by more than one."""
    x = tf.random.uniform((100,), -100., 100., dtype=dtype)
    y = gen_ops.stochastic_round(x, 1., ())
    self.assertEqual(y.dtype, tf.int32)
    self.assertAllClose(x, y, atol=1, rtol=0)

  def test_identical_seed_yields_identical_output(self):
    """Equal seeds reproduce the output; a different seed changes it."""
    x = tf.random.uniform((100,), -100., 100., dtype=tf.float32)
    seed_a = (123, 456)
    seed_b = (456, 789)
    first = gen_ops.stochastic_round(x, 1., seed_a)
    self.assertEqual(first.dtype, tf.int32)
    repeat = gen_ops.stochastic_round(x, 1., seed_a)
    self.assertEqual(repeat.dtype, tf.int32)
    other = gen_ops.stochastic_round(x, 1., seed_b)
    self.assertEqual(other.dtype, tf.int32)
    self.assertAllEqual(first, repeat)
    self.assertNotAllEqual(first, other)

  def test_clock_seed_yields_different_output(self):
    """With an empty seed, two calls separated in time give different output."""
    x = tf.random.uniform((100,), -100., 100., dtype=tf.float32)
    before = gen_ops.stochastic_round(x, 1., ())
    self.assertEqual(before.dtype, tf.int32)
    time.sleep(1.)  # Ensure even on a low-resolution clock, we change seed.
    after = gen_ops.stochastic_round(x, 1., ())
    self.assertEqual(after.dtype, tf.int32)
    self.assertNotAllEqual(before, after)

  @parameterized.parameters(1., .75, 1e-4)
  def test_rounding_is_deterministic_at_integers(self, step_size):
    """Exact multiples of the step size round to their integer multiple."""
    expected = tf.random.uniform((100,), -100, 100, dtype=tf.int32)
    scaled = step_size * tf.cast(expected, tf.float32)
    actual = gen_ops.stochastic_round(scaled, step_size, ())
    self.assertEqual(actual.dtype, tf.int32)
    self.assertAllEqual(expected, actual)

  @parameterized.parameters(1., .75, 1e-4)
  def test_difference_at_half_integers_is_at_most_one_half(self, step_size):
    """Half-integer quotients round to one of the two adjacent integers."""
    half_integers = tf.range(-10, 10, dtype=tf.float32) + .5
    y = gen_ops.stochastic_round(step_size * half_integers, step_size, ())
    self.assertEqual(y.dtype, tf.int32)
    self.assertAllClose(half_integers, y, atol=.5, rtol=0)

  def test_rounding_is_unbiased(self):
    """The mean over many roundings of a value approaches the value itself."""
    num_values = 20
    num_trials = 100000
    x = tf.random.uniform((num_values,), -100., 100., dtype=tf.float32)
    tiled = tf.broadcast_to(x, (num_trials, num_values))
    y = gen_ops.stochastic_round(tiled, 1., ())
    self.assertEqual(y.dtype, tf.int32)
    mean = tf.reduce_mean(tf.cast(y, tf.float32), axis=0)
    self.assertAllClose(x, mean, atol=5e-3, rtol=0)


if __name__ == "__main__":
  tf.test.main()

0 commit comments

Comments
 (0)