Skip to content

Commit 1a9b176

Browse files
Johannes Ballé and copybara-github
authored and committed
Updates ops to TF2 API and removes remaining PY2 idioms.
PiperOrigin-RevId: 354751327 Change-Id: Ieb1198645d56a663316a99950d2e230d6c20f28a
1 parent 575df27 commit 1a9b176

File tree

10 files changed

+149
-273
lines changed

10 files changed

+149
-273
lines changed

tensorflow_compression/python/ops/BUILD

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,14 @@ py_library(
1010
srcs_version = "PY3",
1111
)
1212

13-
py_library(
14-
name = "padding_ops",
15-
srcs = ["padding_ops.py"],
16-
srcs_version = "PY3",
17-
)
18-
19-
py_library(
20-
name = "range_coding_ops",
21-
srcs = ["range_coding_ops.py"],
22-
data = ["//tensorflow_compression/cc:libtensorflow_compression.so"],
23-
srcs_version = "PY3",
24-
deps = [":namespace_helper"],
25-
)
26-
27-
py_library(
28-
name = "soft_round_ops",
29-
srcs = ["soft_round_ops.py"],
30-
srcs_version = "PY3",
31-
)
32-
33-
py_library(
34-
name = "spectral_ops",
35-
srcs = ["spectral_ops.py"],
36-
srcs_version = "PY3",
13+
py_test(
14+
name = "math_ops_test",
15+
srcs = ["math_ops_test.py"],
16+
python_version = "PY3",
17+
deps = [
18+
":math_ops",
19+
":soft_round_ops",
20+
],
3721
)
3822

3923
py_library(
@@ -43,14 +27,10 @@ py_library(
4327
visibility = ["//visibility:private"],
4428
)
4529

46-
py_test(
47-
name = "math_ops_test",
48-
srcs = ["math_ops_test.py"],
49-
python_version = "PY3",
50-
deps = [
51-
":math_ops",
52-
":soft_round_ops",
53-
],
30+
py_library(
31+
name = "padding_ops",
32+
srcs = ["padding_ops.py"],
33+
srcs_version = "PY3",
5434
)
5535

5636
py_test(
@@ -60,18 +40,25 @@ py_test(
6040
deps = [":padding_ops"],
6141
)
6242

43+
py_library(
44+
name = "range_coding_ops",
45+
srcs = ["range_coding_ops.py"],
46+
data = ["//tensorflow_compression/cc:libtensorflow_compression.so"],
47+
srcs_version = "PY3",
48+
deps = [":namespace_helper"],
49+
)
50+
6351
py_test(
6452
name = "range_coding_ops_test",
6553
srcs = ["range_coding_ops_test.py"],
6654
python_version = "PY3",
6755
deps = [":range_coding_ops"],
6856
)
6957

70-
py_test(
71-
name = "spectral_ops_test",
72-
srcs = ["spectral_ops_test.py"],
73-
python_version = "PY3",
74-
deps = [":spectral_ops"],
58+
py_library(
59+
name = "soft_round_ops",
60+
srcs = ["soft_round_ops.py"],
61+
srcs_version = "PY3",
7562
)
7663

7764
py_test(
@@ -81,6 +68,19 @@ py_test(
8168
deps = [":soft_round_ops"],
8269
)
8370

71+
py_library(
72+
name = "spectral_ops",
73+
srcs = ["spectral_ops.py"],
74+
srcs_version = "PY3",
75+
)
76+
77+
py_test(
78+
name = "spectral_ops_test",
79+
srcs = ["spectral_ops_test.py"],
80+
python_version = "PY3",
81+
deps = [":spectral_ops"],
82+
)
83+
8484
filegroup(
8585
name = "py_src",
8686
srcs = glob(["*.py"]),

tensorflow_compression/python/ops/math_ops.py

Lines changed: 57 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
# ==============================================================================
1515
"""Math operations."""
1616

17-
from __future__ import absolute_import
18-
from __future__ import division
19-
from __future__ import print_function
20-
21-
import tensorflow.compat.v1 as tf
17+
import tensorflow as tf
2218

2319

2420
__all__ = [
@@ -28,54 +24,8 @@
2824
]
2925

3026

31-
@tf.RegisterGradient("IdentityFirstOfTwoInputs")
32-
def _identity_first_of_two_inputs_grad(op, grad):
33-
"""Gradient for `lower_bound` or `upper_bound` if `gradient == 'identity'`.
34-
35-
Args:
36-
op: The op for which to calculate a gradient.
37-
grad: Gradient with respect to the output of the op.
38-
39-
Returns:
40-
Gradient with respect to the inputs of the op.
41-
"""
42-
del op # unused
43-
return [grad, None]
44-
45-
46-
@tf.RegisterGradient("UpperBound")
47-
def _upper_bound_grad(op, grad):
48-
"""Gradient for `upper_bound` if `gradient == 'identity_if_towards'`.
49-
50-
Args:
51-
op: The op for which to calculate a gradient.
52-
grad: Gradient with respect to the output of the op.
53-
54-
Returns:
55-
Gradient with respect to the inputs of the op.
56-
"""
57-
inputs, bound = op.inputs
58-
pass_through_if = tf.logical_or(inputs <= bound, grad > 0)
59-
return [tf.cast(pass_through_if, grad.dtype) * grad, None]
60-
61-
62-
@tf.RegisterGradient("LowerBound")
63-
def _lower_bound_grad(op, grad):
64-
"""Gradient for `lower_bound` if `gradient == 'identity_if_towards'`.
65-
66-
Args:
67-
op: The op for which to calculate a gradient.
68-
grad: Gradient with respect to the output of the op.
69-
70-
Returns:
71-
Gradient with respect to the inputs of the op.
72-
"""
73-
inputs, bound = op.inputs
74-
pass_through_if = tf.logical_or(inputs >= bound, grad < 0)
75-
return [tf.cast(pass_through_if, grad.dtype) * grad, None]
76-
77-
78-
def upper_bound(inputs, bound, gradient="identity_if_towards", name=None):
27+
def upper_bound(inputs, bound, gradient="identity_if_towards",
28+
name="upper_bound"):
7929
"""Same as `tf.minimum`, but with helpful gradient for `inputs > bound`.
8030
8131
This function behaves just like `tf.minimum`, but the behavior of the gradient
@@ -110,27 +60,37 @@ def upper_bound(inputs, bound, gradient="identity_if_towards", name=None):
11060
Raises:
11161
ValueError: for invalid value of `gradient`.
11262
"""
113-
try:
114-
gradient = {
115-
"identity_if_towards": "UpperBound",
116-
"identity": "IdentityFirstOfTwoInputs",
117-
"disconnected": None,
118-
}[gradient]
119-
except KeyError:
120-
raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient))
121-
122-
with tf.name_scope(name, "UpperBound", [inputs, bound]) as scope:
63+
with tf.name_scope(name) as scope:
12364
inputs = tf.convert_to_tensor(inputs, name="inputs")
124-
bound = tf.convert_to_tensor(
125-
bound, name="bound", dtype=inputs.dtype)
126-
if gradient:
127-
with tf.get_default_graph().gradient_override_map({"Minimum": gradient}):
128-
return tf.minimum(inputs, bound, name=scope)
129-
else:
130-
return tf.minimum(inputs, bound, name=scope)
65+
bound = tf.convert_to_tensor(bound, name="bound", dtype=inputs.dtype)
66+
67+
def identity_if_towards_grad(grad):
68+
"""Gradient if gradient == 'identity_if_towards'."""
69+
pass_through_if = tf.logical_or(inputs <= bound, grad > 0)
70+
return (tf.cast(pass_through_if, grad.dtype) * grad, None)
71+
72+
def disconnected_grad(grad):
73+
"""Gradient if gradient == 'disconnected'."""
74+
return (tf.cast(inputs <= bound, grad.dtype) * grad, None)
75+
76+
try:
77+
gradient = {
78+
"identity_if_towards": identity_if_towards_grad,
79+
"identity": lambda grad: (grad, None),
80+
"disconnected": disconnected_grad,
81+
}[gradient]
82+
except KeyError:
83+
raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient))
84+
85+
@tf.custom_gradient
86+
def _upper_bound(inputs, bound):
87+
return tf.minimum(inputs, bound, name=scope), gradient
88+
89+
return _upper_bound(inputs, bound)
13190

13291

133-
def lower_bound(inputs, bound, gradient="identity_if_towards", name=None):
92+
def lower_bound(inputs, bound, gradient="identity_if_towards",
93+
name="lower_bound"):
13494
"""Same as `tf.maximum`, but with helpful gradient for `inputs < bound`.
13595
13696
This function behaves just like `tf.maximum`, but the behavior of the gradient
@@ -165,24 +125,33 @@ def lower_bound(inputs, bound, gradient="identity_if_towards", name=None):
165125
Raises:
166126
ValueError: for invalid value of `gradient`.
167127
"""
168-
try:
169-
gradient = {
170-
"identity_if_towards": "LowerBound",
171-
"identity": "IdentityFirstOfTwoInputs",
172-
"disconnected": None,
173-
}[gradient]
174-
except KeyError:
175-
raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient))
176-
177-
with tf.name_scope(name, "LowerBound", [inputs, bound]) as scope:
128+
with tf.name_scope(name) as scope:
178129
inputs = tf.convert_to_tensor(inputs, name="inputs")
179-
bound = tf.convert_to_tensor(
180-
bound, name="bound", dtype=inputs.dtype)
181-
if gradient:
182-
with tf.get_default_graph().gradient_override_map({"Maximum": gradient}):
183-
return tf.maximum(inputs, bound, name=scope)
184-
else:
185-
return tf.maximum(inputs, bound, name=scope)
130+
bound = tf.convert_to_tensor(bound, name="bound", dtype=inputs.dtype)
131+
132+
def identity_if_towards_grad(grad):
133+
"""Gradient if gradient == 'identity_if_towards'."""
134+
pass_through_if = tf.logical_or(inputs >= bound, grad < 0)
135+
return (tf.cast(pass_through_if, grad.dtype) * grad, None)
136+
137+
def disconnected_grad(grad):
138+
"""Gradient if gradient == 'disconnected'."""
139+
return (tf.cast(inputs >= bound, grad.dtype) * grad, None)
140+
141+
try:
142+
gradient = {
143+
"identity_if_towards": identity_if_towards_grad,
144+
"identity": lambda grad: (grad, None),
145+
"disconnected": disconnected_grad,
146+
}[gradient]
147+
except KeyError:
148+
raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient))
149+
150+
@tf.custom_gradient
151+
def _lower_bound(inputs, bound):
152+
return tf.maximum(inputs, bound, name=scope), gradient
153+
154+
return _lower_bound(inputs, bound)
186155

187156

188157
def perturb_and_apply(f, x, *args, u=None, x_plus_u=None, expected_grads=True):

0 commit comments

Comments (0)