Skip to content

Commit 8415c37

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Inference Gym: Add a smoothing option to PlasmaSpectroscopy problem.
I didn't add a new dataset, as I wasn't able to produce a good example of multimodality with this setting on. PiperOrigin-RevId: 380846216
1 parent b351dbe commit 8415c37

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

spinoffs/inference_gym/inference_gym/targets/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ py_test(
393393
srcs_version = "PY3",
394394
deps = [
395395
":plasma_spectroscopy",
396+
# absl/testing:parameterized dep,
396397
# numpy dep,
397398
# tensorflow dep,
398399
# tensorflow_probability/python/internal:test_util dep,

spinoffs/inference_gym/inference_gym/targets/plasma_spectroscopy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
velocity_scale=0.5,
9393
absolute_noise_scale=0.5,
9494
relative_noise_scale=0.05,
95+
use_bump_function=False,
9596
name='plasma_spectroscopy',
9697
pretty_name='Plasma Spectroscopy',
9798
):
@@ -124,6 +125,8 @@ def __init__(
124125
observation noise.
125126
relative_noise_scale: Float `Tensor` scalar. Absolute noise scale of the
126127
observation noise.
128+
use_bump_function: Python `bool`. If True, use a bump function to smoothly
129+
decay the plasma blob to 0 at the edges.
127130
name: Python `str` name prefixed to Ops created by this class.
128131
pretty_name: A Python `str`. The pretty name of this model.
129132
@@ -151,6 +154,7 @@ def __init__(
151154
self._velocity_scale = velocity_scale
152155
self._absolute_noise_scale = absolute_noise_scale
153156
self._relative_noise_scale = relative_noise_scale
157+
self._use_bump_function = use_bump_function
154158

155159
@tfd.JointDistributionCoroutine
156160
def prior():
@@ -262,6 +266,10 @@ def absolute_noise_scale(self):
262266
def relative_noise_scale(self):
263267
return self._relative_noise_scale
264268

269+
@property
270+
def use_bump_function(self):
271+
return self._use_bump_function
272+
265273
def forward_model(self, sample):
266274
"""The forward model.
267275
@@ -314,6 +322,10 @@ def forward_model(self, sample):
314322
-(wavelengths[:, tf.newaxis] - doppler_shifted_center_wavelength)**2
315323
/ (2 * bandwidth**2))
316324

325+
if self.use_bump_function:
326+
emissivity *= tfp.math.round_exponential_bump_function(
327+
tf.linspace(-1., 1., self.num_shells))
328+
317329
x = tf.linspace(-self.outer_shell_radius, self.outer_shell_radius,
318330
self.num_integration_points)
319331
y = tf.linspace(-self.sensor_span, self.sensor_span, self.num_sensors)

spinoffs/inference_gym/inference_gym/targets/plasma_spectroscopy_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ============================================================================
1616
"""Tests for inference_gym.targets.plasma_spectroscopy."""
1717

18+
from absl.testing import parameterized
1819
import numpy as np
1920
import tensorflow.compat.v2 as tf
2021

@@ -36,13 +37,21 @@ def _test_dataset():
3637
@test_util.multi_backend_test(globals(), 'targets.plasma_spectroscopy_test')
3738
class PlasmaSpectroscopyTest(test_util.InferenceGymTestCase):
3839

39-
def testBasic(self):
40+
@parameterized.named_parameters(
41+
('Smooth', True),
42+
('NotSmooth', False),
43+
)
44+
def testBasic(self, use_bump_function):
4045
"""Checks that you get finite values given unconstrained samples.
4146
4247
We check `unnormalized_log_prob` as well as the values of the sample
4348
transformations.
49+
50+
Args:
51+
use_bump_function: Whether to use the bump function.
4452
"""
45-
model = plasma_spectroscopy.PlasmaSpectroscopy(**_test_dataset())
53+
model = plasma_spectroscopy.PlasmaSpectroscopy(
54+
**_test_dataset(), use_bump_function=use_bump_function)
4655
self.validate_log_prob_and_transforms(
4756
model,
4857
sample_transformation_shapes=dict(

0 commit comments

Comments
 (0)