Skip to content

Commit 57ff399

Browse files
Johannes Ballécopybara-github
authored andcommitted
Updates toy sources to use stateless random number generator.
PiperOrigin-RevId: 464610112 Change-Id: I686969628aea8e68973f35689270d3976c5d6b9d
1 parent 5811c40 commit 57ff399

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

models/toy_sources/ramp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _event_shape(self):
7676
def _sample_n(self, n, seed=None):
7777
ind = self.index_points
7878
if self.phase is None:
79-
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
79+
phase = tf.random.stateless_uniform((n, 1), seed=seed, dtype=self.dtype)
8080
else:
8181
phase = tf.fill((n, 1), tf.constant(self.phase, dtype=self.dtype))
8282
return (ind + phase) % 1 - .5

models/toy_sources/sawbridge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,15 @@ def _event_shape(self):
9696

9797
def _sample_n(self, n, seed=None):
9898
if self.drop is None:
99-
uniform = tf.random.uniform(
99+
uniform = tf.random.stateless_uniform(
100100
(self.order, n, 1), seed=seed, dtype=self.dtype)
101101
else:
102102
uniform = tf.fill(
103103
(self.order, n, 1), tf.constant(self.drop, dtype=self.dtype))
104104
ind = self.index_points
105105
if self.stationary:
106106
if self.phase is None:
107-
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
107+
phase = tf.random.stateless_uniform((n, 1), seed=seed, dtype=self.dtype)
108108
else:
109109
phase = tf.constant(self.phase, dtype=self.dtype)
110110
ind += phase

models/toy_sources/sinusoid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _event_shape(self):
8080
def _sample_n(self, n, seed=None):
8181
ind = self.index_points
8282
if self.phase is None:
83-
phase = tf.random.uniform((n, 1), seed=seed, dtype=self.dtype)
83+
phase = tf.random.stateless_uniform((n, 1), seed=seed, dtype=self.dtype)
8484
else:
8585
phase = tf.fill((n, 1), tf.constant(self.phase, dtype=self.dtype))
8686
return tf.sin((2 * np.pi) * (ind + phase))

models/toy_sources/sphere.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,11 @@ def _event_shape(self):
7373
return tf.TensorShape([self.order])
7474

7575
def _sample_n(self, n, seed=None):
76-
samples = tf.random.normal((n, self.order), seed=seed, dtype=self.dtype)
76+
samples = tf.random.stateless_normal(
77+
(n, self.order), seed=seed, dtype=self.dtype)
7778
radius = tf.math.sqrt(tf.reduce_sum(tf.square(samples), -1, keepdims=True))
7879
if self.width:
79-
radius *= tf.random.uniform(
80+
radius *= tf.random.stateless_uniform(
8081
(n, 1), minval=1. - self.width / 2., maxval=1. + self.width / 2.,
8182
seed=seed, dtype=self.dtype)
8283
return samples / radius

0 commit comments

Comments
 (0)