
Commit e816859

SiegeLordEx authored and tensorflower-gardener committed
FunMC: Add AIS/SMC.
PiperOrigin-RevId: 492849049
1 parent 6fc9a9e

File tree

5 files changed: +489 −40 lines

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/tf_on_jax.py

Lines changed: 5 additions & 1 deletion
@@ -161,8 +161,10 @@ def _get_static_value(value):
 
 tf.newaxis = None
 
+_impl_np()(jnp.cumsum)
 _impl_np()(jnp.exp)
 _impl_np()(jnp.einsum)
+_impl_np()(jnp.floor)
 _impl_np()(jnp.float32)
 _impl_np()(jnp.float64)
 _impl_np()(jnp.int32)
@@ -179,20 +181,22 @@ def _get_static_value(value):
 _impl_np()(jnp.zeros_like)
 _impl_np()(jnp.transpose)
 _impl_np(name='fill')(jnp.full)
+_impl_np(['nn'])(jax.nn.softmax)
 _impl_np(['math'])(jnp.ceil)
 _impl_np(['math'])(jnp.log)
 _impl_np(['math'], name='mod')(jnp.mod)
 _impl_np(['math'])(jnp.sqrt)
 _impl_np(['math'], name='is_finite')(jnp.isfinite)
+_impl_np(['math'], name='is_nan')(jnp.isnan)
 _impl_np(['math'], name='pow')(jnp.power)
+_impl_np(['math'], name='reduce_all')(jnp.all)
 _impl_np(['math'], name='reduce_prod')(jnp.prod)
 _impl_np(['math'], name='reduce_variance')(jnp.var)
 _impl_np(name='abs')(jnp.abs)
 _impl_np(name='Tensor')(jnp.ndarray)
 _impl_np(name='concat')(jnp.concatenate)
 _impl_np(name='constant')(jnp.array)
 _impl_np(name='expand_dims')(jnp.expand_dims)
-_impl_np(['math'], name='reduce_all')(jnp.all)
 _impl_np(name='reduce_max')(jnp.max)
 _impl_np(name='reduce_mean')(jnp.mean)
 _impl_np(name='reduce_sum')(jnp.sum)
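The `_impl_np` calls above register JAX functions on FunMC's TF-compatible shim, optionally under a sub-module path (e.g. `['math']`) and a TF-style alias (which is how `jnp.all` is exposed as `tf.math.reduce_all`); the `reduce_all` line is not new here, just moved up to keep the `math` registrations grouped. Below is a minimal sketch of that registration pattern, assuming a decorator-factory design; the real `_impl_np` in tf_on_jax.py may differ in details.

```python
import types

import jax.numpy as jnp

# Hypothetical stand-in for the shim's `tf` namespace; the real module
# builds a much fuller TF surface.
tf = types.SimpleNamespace(math=types.SimpleNamespace(),
                           nn=types.SimpleNamespace())


def _impl_np(path=(), name=None):
  """Returns a decorator that registers `fn` on the fake `tf` namespace."""
  def decorator(fn):
    target = tf
    for part in path:  # Descend into a sub-module, e.g. `tf.math`.
      target = getattr(target, part)
    setattr(target, name if name is not None else fn.__name__, fn)
    return fn
  return decorator


# Mirrors two registrations from the diff above.
_impl_np(['math'], name='reduce_all')(jnp.all)  # -> tf.math.reduce_all
_impl_np()(jnp.floor)                           # -> tf.floor
assert bool(tf.math.reduce_all(jnp.array([True, True])))
```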

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py

Lines changed: 16 additions & 3 deletions
@@ -20,12 +20,12 @@
 from jax import lax
 from jax import random
 from jax import tree_util
-from jax.example_libraries import stax
 import jax.numpy as jnp
 
 __all__ = [
     'assert_same_shallow_tree',
     'block_until_ready',
+    'diff',
     'flatten_tree',
     'get_shallow_tree',
     'inverse_fn',
@@ -38,6 +38,7 @@
     'random_integer',
     'random_normal',
     'random_uniform',
+    'repeat',
     'split_seed',
     'trace',
     'value_and_grad',
@@ -143,7 +144,7 @@ def body(state):
 
 def random_categorical(logits, num_samples, seed):
   """Returns a sample from a categorical distribution. `logits` must be 2D."""
-  probs = stax.softmax(logits)
+  probs = jax.nn.softmax(logits)
   cum_sum = jnp.cumsum(probs, axis=-1)
 
   eta = random.uniform(
@@ -211,6 +212,7 @@ def wrapper(i, state_untraced_traced):
     state, untraced, traced = fn(state)
     trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays, traced)
     return (state, untraced, trace_arrays)
+
   state, untraced, traced = lax.fori_loop(
       jnp.asarray(0, num_steps.dtype),
       num_steps,
@@ -250,7 +252,6 @@ def scale_by_two(x):
   assert y_extra == 3
   assert y_ldj == jnp.log(2)
   ```
-
   """
   value, (extra, ldj) = fn(args)
   return value, (extra, ldj), ldj
@@ -307,11 +308,13 @@ def block_until_ready(tensors):
   Returns:
     tensors: Tensors that are guaranteed to be ready to materialize.
   """
+
   def _block_until_ready(tensor):
     if hasattr(tensor, 'block_until_ready'):
       return tensor.block_until_ready()
     else:
       return tensor
+
   return map_tree(_block_until_ready, tensors)
 
 
@@ -326,3 +329,13 @@ def named_call(f=None, name=None):
     return functools.partial(named_call, name=name)
 
   return jax.named_call(f, name=name)
+
+
+def diff(x, prepend=None):
+  """Like jnp.diff."""
+  return jnp.diff(x, prepend=prepend)
+
+
+def repeat(x, repeats, total_repeat_length=None):
+  """Like jnp.repeat."""
+  return jnp.repeat(x, repeats, total_repeat_length=total_repeat_length)
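Besides replacing `stax.softmax` with `jax.nn.softmax` (dropping the `jax.example_libraries` import), this file adds `diff` and `repeat` wrappers that the TF backend mirrors below. A usage sketch of the underlying `jnp` calls, assuming only stock JAX; `total_repeat_length` pins the output length statically, which keeps the result's shape fixed under `jax.jit`:

```python
import jax.numpy as jnp

x = jnp.array([1., 3., 6.])
# First differences with a value prepended, as in jnp.diff.
print(jnp.diff(x, prepend=jnp.array([0.])))  # [1. 2. 3.]

# Element-wise repeats; total_repeat_length fixes the output shape to 3
# even when `repeats` is a traced array under jit.
print(jnp.repeat(jnp.array([1, 2]), jnp.array([2, 1]), total_repeat_length=3))
# [1 1 2]
```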

spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py

Lines changed: 25 additions & 3 deletions
@@ -26,6 +26,7 @@
 __all__ = [
     'assert_same_shallow_tree',
     'block_until_ready',
+    'diff',
     'flatten_tree',
     'get_shallow_tree',
     'inverse_fn',
@@ -38,6 +39,7 @@
     'random_integer',
     'random_normal',
     'random_uniform',
+    'repeat',
     'split_seed',
     'trace',
     'value_and_ldj',
@@ -177,7 +179,9 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10):
     state, first_untraced, first_traced = fn(state)
     arrays = tf.nest.map_structure(
         lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
-            v.dtype, size=num_steps, element_shape=v.shape).write(0, v),
+            v.dtype,
+            size=num_steps,
+            element_shape=v.shape).write(0, v),
         first_traced)
     start_idx = 1
   else:
@@ -189,7 +193,10 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10):
 
     arrays = tf.nest.map_structure(
         lambda spec: tf.TensorArray(  # pylint: disable=g-long-lambda
-            spec.dtype, size=num_steps, element_shape=spec.shape), traced_spec)
+            spec.dtype,
+            size=num_steps,
+            element_shape=spec.shape),
+        traced_spec)
     first_untraced = tf.nest.map_structure(
         lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec)
     start_idx = 0
@@ -266,7 +273,6 @@ def scale_by_two(x):
   assert y_extra == 3
   assert y_ldj == np.log(2)
   ```
-
   """
   value, (extra, ldj) = fn(args)
   return value, (extra, ldj), ldj
@@ -348,3 +354,19 @@ def wrapped(*args, **kwargs):
     return f(*args, **kwargs)
 
   return wrapped
+
+
+def diff(x, prepend=None):
+  """Like jnp.diff."""
+  if prepend is not None:
+    x = tf.concat([tf.convert_to_tensor(prepend, dtype=x.dtype)[tf.newaxis], x],
+                  0)
+  return x[1:] - x[:-1]
+
+
+def repeat(x, repeats, total_repeat_length=None):
+  """Like jnp.repeat."""
+  res = tf.repeat(x, repeats)
+  if total_repeat_length is not None:
+    res.set_shape([total_repeat_length] + [None] * (len(res.shape) - 1))
+  return res
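TensorFlow has no direct equivalents for these keywords, so `diff` is assembled from `tf.concat` and a shifted subtraction, and `repeat` wraps `tf.repeat` with a `set_shape` call; `set_shape` only annotates static shape information and never changes values, so graph-mode callers get a known leading dimension even when `repeats` is dynamic. A quick check of what the helpers compute, inlining the same operations rather than importing fun_mc (TensorFlow 2.x assumed):

```python
import tensorflow as tf

x = tf.constant([1., 3., 6.])
# diff(x, prepend=0.): prepend the value, then subtract shifted slices.
d = tf.concat([tf.convert_to_tensor(0., dtype=x.dtype)[tf.newaxis], x], 0)
print((d[1:] - d[:-1]).numpy())  # [1. 2. 3.] -- matches the JAX backend.

# repeat(x, repeats, total_repeat_length=3): tf.repeat plus a static
# shape annotation on the leading dimension.
r = tf.repeat(tf.constant([1, 2]), [2, 1])
r.set_shape([3])
print(r.numpy())  # [1 1 2]
```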
