Skip to content

Commit 3090696

Browse files
author
Johannes Ballé
committed
Merge pull request #133 from SourbhBh:master
PiperOrigin-RevId: 445518637 Change-Id: I893642f3097c2fb78a7203724f91a65d15e015be
2 parents aaedec8 + 8cd7152 commit 3090696

File tree

4 files changed

+299
-15
lines changed

4 files changed

+299
-15
lines changed

models/toy_sources/ramp.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2022 TensorFlow Compression contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Ramp process."""
16+
17+
import tensorflow as tf
18+
import tensorflow_probability as tfp
19+
20+
21+
class Ramp(tfp.distributions.Distribution):
22+
"""The "ramp": Y(t) = (t+V) mod 1 - 0.5, where V is uniform over [0, 1]."""
23+
24+
def __init__(self,
25+
index_points,
26+
phase=None,
27+
dtype=tf.float32,
28+
validate_args=False,
29+
allow_nan_stats=True,
30+
name="ramp"):
31+
"""Initializer.
32+
33+
Args:
34+
index_points: 1-D Tensor representing the locations at which to evaluate
35+
the process.
36+
phase: Float in [0,1]. Specifies a realization of V.
37+
dtype: Data type of the returned realization. Defaults to `tf.float32`.
38+
validate_args: required by `Distribution` class but unused.
39+
allow_nan_stats: required by `Distribution` class but unused.
40+
name: String. Name of the created object.
41+
"""
42+
parameters = dict(locals())
43+
with tf.name_scope(name) as name:
44+
self._index_points = tf.convert_to_tensor(
45+
index_points, dtype_hint=dtype, name="index_points")
46+
self._phase = phase
47+
super().__init__(
48+
dtype=dtype,
49+
reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
50+
validate_args=validate_args,
51+
allow_nan_stats=allow_nan_stats,
52+
parameters=parameters,
53+
name=name,
54+
)
55+
56+
@property
57+
def index_points(self):
58+
return self._index_points
59+
60+
@property
61+
def phase(self):
62+
return self._phase
63+
64+
def _batch_shape_tensor(self):
65+
return tf.constant([], dtype=tf.int32)
66+
67+
def _batch_shape(self):
68+
return tf.TensorShape([])
69+
70+
def _event_shape_tensor(self):
71+
return tf.shape(self.index_points)
72+
73+
def _event_shape(self):
74+
return self.index_points.shape
75+
76+
def _sample_n(self, n, seed=None):
77+
ind = self.index_points
78+
if self.phase is None:
79+
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
80+
else:
81+
phase = tf.fill((n, 1), tf.constant(self.phase, dtype=self.dtype))
82+
return (ind + phase) % 1 - .5

models/toy_sources/sawbridge.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,48 @@
1+
# Copyright 2020 TensorFlow Compression contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
"""Sawbridge process."""
216

317
import tensorflow as tf
418
import tensorflow_probability as tfp
519

620

721
class Sawbridge(tfp.distributions.Distribution):
8-
"""The "sawbridge": B(t) = t - 1(t > Z), where Z is uniform over [0,1]."""
22+
"""The "sawbridge": B(t) = t - 1(t > Z), where Z is uniform over [0,1].
923
10-
def __init__(self, index_points, stationary=True, order=1,
11-
dtype=tf.float32, validate_args=False, allow_nan_stats=True,
12-
name="sawbridge"):
24+
The stationary sawbridge is B(t + V mod 1), where V is uniform over [0,1] and
25+
Z, V are independent.
26+
"""
27+
28+
def __init__(self, index_points, phase=None, drop=None, stationary=True,
29+
order=1, dtype=tf.float32, validate_args=False,
30+
allow_nan_stats=True, name="sawbridge"):
1331
"""Initializer.
1432
1533
Args:
1634
index_points: 1-D `Tensor` representing the locations at which to evaluate
1735
the process. The intent is that all locations are in [0,1], but the
1836
process has a natural extrapolation outside this range so no error is
1937
thrown.
38+
phase: Float in [0,1] or `None` (default). Specifies realization of V.
39+
drop: Float in [0,1] or `None` (default). Specifies realization of Z.
2040
stationary: Boolean. Whether or not to "scramble" phase.
2141
order: Integer >= 1. The resulting process is a linear combination of
2242
`order` sawbridges.
23-
dtype: Data type of the returned realization at each timestep. Defaults to
24-
tf.float32.
25-
validate_args: required by tensorflow Distribution class but unused.
26-
allow_nan_stats: required by tensorflow Distribution class but unused.
43+
dtype: Data type of the returned realization. Defaults to `tf.float32`.
44+
validate_args: required by `Distribution` class but unused.
45+
allow_nan_stats: required by `Distribution` class but unused.
2746
name: String. Name of the created object.
2847
"""
2948
parameters = dict(locals())
@@ -32,6 +51,8 @@ def __init__(self, index_points, stationary=True, order=1,
3251
index_points, dtype_hint=dtype, name="index_points")
3352
self._stationary = bool(stationary)
3453
self._order = int(order)
54+
self._phase = phase
55+
self._drop = drop
3556
super().__init__(
3657
dtype=dtype,
3758
reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
@@ -53,6 +74,14 @@ def stationary(self):
5374
def order(self):
5475
return self._order
5576

77+
@property
78+
def phase(self):
79+
return self._phase
80+
81+
@property
82+
def drop(self):
83+
return self._drop
84+
5685
def _batch_shape_tensor(self):
5786
return tf.constant([], dtype=tf.int32)
5887

@@ -66,16 +95,21 @@ def _event_shape(self):
6695
return self.index_points.shape
6796

6897
def _sample_n(self, n, seed=None):
69-
uniform = tf.random.uniform((self.order, n), seed=seed, dtype=self.dtype)
98+
if self.drop is None:
99+
uniform = tf.random.uniform(
100+
(self.order, n, 1), seed=seed, dtype=self.dtype)
101+
else:
102+
uniform = tf.fill(
103+
(self.order, n, 1), tf.constant(self.drop, dtype=self.dtype))
70104
ind = self.index_points
71-
# ind shape: (time)
72-
uniform = tf.expand_dims(uniform, -1)
73-
# uniform shape: (order, n, 1)
74105
if self.stationary:
75-
ind += tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
106+
if self.phase is None:
107+
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
108+
else:
109+
phase = tf.constant(self.phase, dtype=self.dtype)
110+
ind += phase
76111
ind %= 1.
77-
less = tf.less(uniform, ind)
78-
# less shape: (order, n, time)
112+
less = tf.less(uniform, ind) # shape: (order, n, time)
79113
# Note:
80114
# ind[n] == 1 -> sample[n] == 0 always
81115
# ind[n] == 0 -> sample[n] == 0 always

models/toy_sources/sinusoid.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2022 TensorFlow Compression contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Sinusoid process."""
16+
17+
import numpy as np
18+
import tensorflow as tf
19+
import tensorflow_probability as tfp
20+
21+
22+
class Sinusoid(tfp.distributions.Distribution):
23+
"""The "sinusoid": P(t) = sin(2pi(t+V)), where V is uniform over [0,1]."""
24+
25+
def __init__(self,
26+
index_points,
27+
phase=None,
28+
dtype=tf.float32,
29+
validate_args=False,
30+
allow_nan_stats=True,
31+
name="sinusoid"):
32+
"""Initializer.
33+
34+
Args:
35+
index_points: 1-D `Tensor` representing the locations at which to
36+
evaluate the process. The intent is that all locations are in [0,1],
37+
but the process has a natural extrapolation outside this range so no
38+
error is thrown.
39+
phase: Float in [0,1] or `None` (default). Specifies realization of V.
40+
dtype: Data type of the returned realization at each timestep. Defaults
41+
to tf.float32.
42+
validate_args: required by tensorflow Distribution class but unused.
43+
allow_nan_stats: required by tensorflow Distribution class but unused.
44+
name: String. Name of the created object.
45+
"""
46+
parameters = dict(locals())
47+
with tf.name_scope(name) as name:
48+
self._index_points = tf.convert_to_tensor(
49+
index_points, dtype_hint=dtype, name="index_points")
50+
self._phase = phase
51+
super().__init__(
52+
dtype=dtype,
53+
reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
54+
validate_args=validate_args,
55+
allow_nan_stats=allow_nan_stats,
56+
parameters=parameters,
57+
name=name,
58+
)
59+
60+
@property
61+
def index_points(self):
62+
return self._index_points
63+
64+
@property
65+
def phase(self):
66+
return self._phase
67+
68+
def _batch_shape_tensor(self):
69+
return tf.constant([], dtype=tf.int32)
70+
71+
def _batch_shape(self):
72+
return tf.TensorShape([])
73+
74+
def _event_shape_tensor(self):
75+
return tf.shape(self.index_points)
76+
77+
def _event_shape(self):
78+
return self.index_points.shape
79+
80+
def _sample_n(self, n, seed=None):
81+
ind = self.index_points
82+
if self.phase is None:
83+
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
84+
else:
85+
phase = tf.fill((n, 1), tf.constant(self.phase, dtype=self.dtype))
86+
return tf.sin((2 * np.pi) * (ind + phase))

models/toy_sources/sphere.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2022 TensorFlow Compression contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Sphere process."""
16+
17+
import tensorflow as tf
18+
import tensorflow_probability as tfp
19+
20+
21+
class Sphere(tfp.distributions.Distribution):
22+
"""Uniform distribution on the unit hypersphere."""
23+
24+
def __init__(self,
25+
order=2,
26+
width=0.,
27+
dtype=tf.float32,
28+
validate_args=False,
29+
allow_nan_stats=True,
30+
name="sphere"):
31+
"""Initializer.
32+
33+
Arguments:
34+
order: Integer >= 1. The dimensionality of the sphere.
35+
width: Float in [0,1]. Allows for realizations to be approximately
36+
uniformly distributed in a band between radius `1 - width` and
37+
`1 + width` (for `width` << 1).
38+
dtype: Data type of the returned realization. Defaults to `tf.float32`.
39+
validate_args: required by `Distribution` class but unused.
40+
allow_nan_stats: required by `Distribution` class but unused.
41+
name: String. Name of the created object.
42+
"""
43+
parameters = dict(locals())
44+
self._order = int(order)
45+
self._width = float(width)
46+
super().__init__(
47+
dtype=dtype,
48+
reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
49+
validate_args=validate_args,
50+
allow_nan_stats=allow_nan_stats,
51+
parameters=parameters,
52+
name=name,
53+
)
54+
55+
@property
56+
def order(self):
57+
return self._order
58+
59+
@property
60+
def width(self):
61+
return self._width
62+
63+
def _batch_shape_tensor(self):
64+
return tf.constant([], dtype=tf.int32)
65+
66+
def _batch_shape(self):
67+
return tf.TensorShape([])
68+
69+
def _event_shape_tensor(self):
70+
return tf.constant([self.order], dtype=tf.int32)
71+
72+
def _event_shape(self):
73+
return tf.TensorShape([self.order])
74+
75+
def _sample_n(self, n, seed=None):
76+
samples = tf.random.normal((n, self.order), seed=seed, dtype=self.dtype)
77+
radius = tf.math.sqrt(tf.reduce_sum(tf.square(samples), -1, keepdims=True))
78+
if self.width:
79+
radius *= tf.random.uniform(
80+
(n, 1), minval=1. - self.width / 2., maxval=1. + self.width / 2.,
81+
seed=seed, dtype=self.dtype)
82+
return samples / radius

0 commit comments

Comments
 (0)