Skip to content

Commit 26919df

Browse files
SiegeLordExjburnim
authored andcommitted
Add tf.nn.conv to the JAX substrate.
Also, remove another line adjusting rewrite from the NumPy rewrite system. PiperOrigin-RevId: 396021511
1 parent 636d2e4 commit 26919df

File tree

3 files changed

+160
-4
lines changed

3 files changed

+160
-4
lines changed

tensorflow_probability/python/internal/backend/jax/rewrite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main(argv):
4545
contents = contents.replace('scipy.special', 'jax.scipy.special')
4646
if FLAGS.rewrite_numpy_import:
4747
contents = contents.replace('\nimport numpy as np',
48-
'\nimport numpy as onp\nimport jax.numpy as np')
48+
'\nimport numpy as onp; import jax.numpy as np')
4949
else:
5050
contents = contents.replace('\nimport numpy as np',
5151
'\nimport numpy as np; onp = np')

tensorflow_probability/python/internal/backend/numpy/nn.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import collections
22+
2123
# Dependency imports
2224
import numpy as np
2325

@@ -35,6 +37,7 @@
3537

3638

3739
__all__ = [
40+
'conv2d',
3841
'l2_normalize',
3942
'log_softmax',
4043
'moments',
@@ -48,6 +51,106 @@
4851
]
4952

5053

54+
JAX_MODE = False
55+
56+
57+
if JAX_MODE:
58+
import jax # pylint: disable=g-import-not-at-top
59+
60+
61+
# Borrowed from TensorFlow.
62+
def _get_sequence(value, n, channel_index, name):
63+
"""Formats a value input for gen_nn_ops."""
64+
# Performance is fast-pathed for common cases:
65+
# `None`, `list`, `tuple` and `int`.
66+
if value is None:
67+
return [1] * (n + 2)
68+
69+
# Always convert `value` to a `list`.
70+
if isinstance(value, list):
71+
pass
72+
elif isinstance(value, tuple):
73+
value = list(value)
74+
elif isinstance(value, int):
75+
value = [value]
76+
elif not isinstance(value, collections.abc.Sized):
77+
value = [value]
78+
else:
79+
value = list(value) # Try casting to a list.
80+
81+
len_value = len(value)
82+
83+
# Fully specified, including batch and channel dims.
84+
if len_value == n + 2:
85+
return value
86+
87+
# Apply value to spatial dims only.
88+
if len_value == 1:
89+
value = value * n # Broadcast to spatial dimensions.
90+
elif len_value != n:
91+
raise ValueError('{} should be of length 1, {} or {} but was {}'.format(
92+
name, n, n + 2, len_value))
93+
94+
# Add batch and channel dims (always 1).
95+
if channel_index == 1:
96+
return [1, 1] + value
97+
else:
98+
return [1] + value + [1]
99+
100+
101+
def _conv2d(
102+
input, # pylint: disable=redefined-builtin
103+
filters,
104+
strides,
105+
padding,
106+
data_format='NHWC',
107+
dilations=None,
108+
name=None):
109+
"""tf.nn.conv2d implementation."""
110+
del name
111+
if not JAX_MODE:
112+
raise NotImplementedError('tf.nn.conv2d not implemented in NumPy.')
113+
114+
if dilations is not None:
115+
raise ValueError('Dilations not yet supported')
116+
117+
channel_index = 1 if data_format.startswith('NC') else 3
118+
119+
window_strides = _get_sequence(strides, 2, channel_index, 'strides')
120+
if window_strides[0] != 1:
121+
raise ValueError(
122+
f'Stride != 1 not supported for batch dimension. `strides`: {strides} '
123+
f'`data_format`: {data_format}')
124+
if window_strides[channel_index] != 1:
125+
raise ValueError(
126+
'Stride != 1 not supported for channel dimension. `strides`: '
127+
f'{strides} '
128+
f'`data_format`: {data_format}')
129+
window_strides = [
130+
e for i, e in enumerate(window_strides) if i not in [0, channel_index]
131+
]
132+
133+
if isinstance(padding, (list, tuple)):
134+
if padding[0] != (0, 0):
135+
raise ValueError(
136+
f'Padding not supported for batch dimension. `padding`: {padding} '
137+
f'`data_format`: {data_format}')
138+
if padding[channel_index] != (0, 0):
139+
raise ValueError(
140+
'Padding not supported for channel dimension. `padding`: '
141+
f'{padding} '
142+
f'`data_format`: {data_format}')
143+
padding = [s for i, s in enumerate(padding) if i not in [0, channel_index]]
144+
145+
return jax.lax.conv_general_dilated(
146+
lhs=input,
147+
rhs=filters,
148+
window_strides=window_strides,
149+
padding=padding,
150+
dimension_numbers=(data_format, 'HWIO', data_format),
151+
)
152+
153+
51154
def _sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name,unused-argument
52155
_sentinel=None,
53156
labels=None,
@@ -96,6 +199,11 @@ def _softmax_cross_entropy_with_logits( # pylint: disable=invalid-name,unused-a
96199

97200
# --- Begin Public Functions --------------------------------------------------
98201

202+
conv2d = utils.copy_docstring(
203+
'tf.nn.conv2d',
204+
_conv2d)
205+
206+
99207
l2_normalize = utils.copy_docstring(
100208
'tf.nn.l2_normalize',
101209
l2_normalize)

tensorflow_probability/python/internal/backend/numpy/numpy_test.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,54 @@ def argsort_params(draw):
616616
True)) # stable sort
617617

618618

619+
@hps.composite
620+
def conv2d_params(draw):
621+
# NCHW is GPU-only
622+
# data_format = draw(hps.sampled_from(['NHWC', 'NCHW']))
623+
data_format = draw(hps.just('NHWC'))
624+
625+
input_shape = draw(shapes(4, 4, min_side=5, max_side=10))
626+
if data_format.startswith('NC'):
627+
channels = input_shape[1]
628+
else:
629+
channels = input_shape[3]
630+
filter_shape = draw(shapes(3, 3, min_side=1, max_side=4))
631+
filter_shape = filter_shape[:2] + (channels, filter_shape[-1])
632+
633+
input_ = draw(
634+
single_arrays(
635+
batch_shape=(),
636+
shape=hps.just(input_shape),
637+
))
638+
filters = draw(single_arrays(
639+
batch_shape=(),
640+
shape=hps.just(filter_shape),
641+
))
642+
small = hps.integers(0, 5)
643+
small_pos = hps.integers(1, 5)
644+
strides = draw(hps.one_of(small_pos, hps.tuples(small_pos, small_pos)))
645+
if isinstance(strides, tuple) and len(strides) == 2 and draw(hps.booleans()):
646+
if data_format.startswith('NC'):
647+
strides = (1, 1) + strides
648+
else:
649+
strides = (1,) + strides + (1,)
650+
651+
zeros = (0, 0)
652+
explicit_padding = (
653+
draw(hps.tuples(small, small)),
654+
draw(hps.tuples(small, small)),
655+
)
656+
if data_format.startswith('NC'):
657+
explicit_padding = (zeros, zeros) + explicit_padding
658+
else:
659+
explicit_padding = (zeros,) + explicit_padding + (zeros,)
660+
padding = draw(
661+
hps.one_of(
662+
hps.just(explicit_padding), hps.sampled_from(['SAME', 'VALID'])))
663+
664+
return (input_, filters, strides, padding, data_format)
665+
666+
619667
@hps.composite
620668
def sparse_xent_params(draw):
621669
num_classes = draw(hps.integers(1, 6))
@@ -1223,6 +1271,7 @@ def _not_implemented(*args, **kwargs):
12231271
[n_same_shape(n=2, elements=[floats(), positive_floats()])]),
12241272
TestCase('math.xlog1py',
12251273
[n_same_shape(n=2, elements=[floats(), positive_floats()])]),
1274+
TestCase('nn.conv2d', [conv2d_params()], disabled=NUMPY_MODE),
12261275
TestCase(
12271276
'nn.sparse_softmax_cross_entropy_with_logits', [sparse_xent_params()],
12281277
rtol=1e-4,
@@ -1276,9 +1325,8 @@ def _not_implemented(*args, **kwargs):
12761325
xla_disabled=True),
12771326
TestCase('histogram_fixed_width_bins',
12781327
[histogram_fixed_width_bins_params()]),
1279-
TestCase(
1280-
'argsort', [argsort_params()],
1281-
xla_const_args=(1, 2, 3)), # axis, direction, stable-sort
1328+
TestCase('argsort', [argsort_params()],
1329+
xla_const_args=(1, 2, 3)), # axis, direction, stable-sort
12821330
]
12831331

12841332

0 commit comments

Comments
 (0)