Skip to content

Commit 9d4fc05

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Inference Gym: Add the Plasma Spectroscopy model.
I also had to fix a whole bunch of things in the interpolation.py, as that was broken in Numpy and JAX. In Numpy, it was doing broadcasting via in-place mutation, which isn't allowed. In JAX, it was violating omnistaging non-static-shape rules. PiperOrigin-RevId: 374929788
1 parent e26f9bb commit 9d4fc05

File tree

10 files changed

+2373
-57
lines changed

10 files changed

+2373
-57
lines changed

spinoffs/inference_gym/inference_gym/internal/data.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from inference_gym.internal.datasets import sp500_closing_prices as sp500_closing_prices_lib
2424
from inference_gym.internal.datasets import synthetic_item_response_theory as synthetic_item_response_theory_lib
2525
from inference_gym.internal.datasets import synthetic_log_gaussian_cox_process as synthetic_log_gaussian_cox_process_lib
26+
from inference_gym.internal.datasets import synthetic_plasma_spectroscopy as synthetic_plasma_spectroscopy_lib
2627

2728
__all__ = [
2829
'brownian_motion_missing_middle_observations',
@@ -32,6 +33,7 @@
3233
'sp500_closing_prices',
3334
'synthetic_item_response_theory',
3435
'synthetic_log_gaussian_cox_process',
36+
'synthetic_plasma_spectroscopy',
3537
]
3638

3739

@@ -556,3 +558,22 @@ def synthetic_log_gaussian_cox_process(
556558
train_extents=extents,
557559
train_counts=counts,
558560
)
561+
562+
563+
def synthetic_plasma_spectroscopy():
564+
"""Synthetic dataset sampled from the PlasmaSpectroscopy model.
565+
566+
Returns:
567+
dataset: A Dict with the following keys:
568+
measurements: Float `Tensor` with shape [num_wavelengths, num_sensors].
569+
The spectrometer measurements.
570+
wavelengths: Float `Tensor` with shape [num_wavelengths]. Wavelengths
571+
measured by the spectrometers.
572+
center_wavelength: Float `Tensor` scalar. The center wavelength of the
573+
target emission line.
574+
"""
575+
return dict(
576+
measurements=synthetic_plasma_spectroscopy_lib.MEASUREMENTS,
577+
wavelengths=synthetic_plasma_spectroscopy_lib.WAVELENGTHS,
578+
center_wavelength=synthetic_plasma_spectroscopy_lib.CENTER_WAVELENGTH,
579+
)

spinoffs/inference_gym/inference_gym/internal/datasets/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ py_library(
3737
":sp500_closing_prices",
3838
":synthetic_item_response_theory",
3939
":synthetic_log_gaussian_cox_process",
40+
":synthetic_plasma_spectroscopy",
4041
],
4142
)
4243

@@ -89,3 +90,13 @@ py_library(
8990
# numpy dep,
9091
],
9192
)
93+
94+
# pytype_strict
95+
py_library(
96+
name = "synthetic_plasma_spectroscopy",
97+
srcs = ["synthetic_plasma_spectroscopy.py"],
98+
srcs_version = "PY3",
99+
deps = [
100+
# numpy dep,
101+
],
102+
)

0 commit comments

Comments
 (0)