|
23 | 23 | from inference_gym.internal.datasets import sp500_closing_prices as sp500_closing_prices_lib
|
24 | 24 | from inference_gym.internal.datasets import synthetic_item_response_theory as synthetic_item_response_theory_lib
|
25 | 25 | from inference_gym.internal.datasets import synthetic_log_gaussian_cox_process as synthetic_log_gaussian_cox_process_lib
|
| 26 | +from inference_gym.internal.datasets import synthetic_plasma_spectroscopy as synthetic_plasma_spectroscopy_lib |
26 | 27 |
|
27 | 28 | __all__ = [
|
28 | 29 | 'brownian_motion_missing_middle_observations',
|
|
32 | 33 | 'sp500_closing_prices',
|
33 | 34 | 'synthetic_item_response_theory',
|
34 | 35 | 'synthetic_log_gaussian_cox_process',
|
| 36 | + 'synthetic_plasma_spectroscopy', |
35 | 37 | ]
|
36 | 38 |
|
37 | 39 |
|
@@ -556,3 +558,22 @@ def synthetic_log_gaussian_cox_process(
|
556 | 558 | train_extents=extents,
|
557 | 559 | train_counts=counts,
|
558 | 560 | )
|
| 561 | + |
| 562 | + |
| 563 | +def synthetic_plasma_spectroscopy(): |
| 564 | + """Synthetic dataset sampled from the PlasmaSpectroscopy model. |
| 565 | +
|
| 566 | + Returns: |
| 567 | + dataset: A Dict with the following keys: |
| 568 | + measurements: Float `Tensor` with shape [num_wavelengths, num_sensors]. |
| 569 | + The spectrometer measurements. |
| 570 | + wavelengths: Float `Tensor` with shape [num_wavelengths]. Wavelengths |
| 571 | + measured by the spectrometers. |
| 572 | + center_wavelength: Float `Tensor` scalar. The center wavelength of the |
| 573 | + target emission line. |
| 574 | + """ |
| 575 | + return dict( |
| 576 | + measurements=synthetic_plasma_spectroscopy_lib.MEASUREMENTS, |
| 577 | + wavelengths=synthetic_plasma_spectroscopy_lib.WAVELENGTHS, |
| 578 | + center_wavelength=synthetic_plasma_spectroscopy_lib.CENTER_WAVELENGTH, |
| 579 | + ) |
0 commit comments