Skip to content

Commit 1990a9a

Browse files
committed
Add SampleSet.as_regions() method.
1 parent 5979fef commit 1990a9a

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

riid/data/sampleset.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def difficulty_score(self, mean=10.0, std=3.0) -> float:
340340
@property
341341
def ecal(self):
342342
"""Get or set the ecal terms."""
343-
ecal_terms = self.info[list(self.ECAL_INFO_COLUMNS)].values
343+
ecal_terms = self.info[list(self.ECAL_INFO_COLUMNS)].to_numpy(dtype=float)
344344
return ecal_terms
345345

346346
@ecal.setter
@@ -543,6 +543,9 @@ def as_ecal(self, new_offset: float, new_gain: float,
543543
new_cubic: new cubic value, i.e. the 3-th e-cal term
544544
new_low_energy: new low energy term
545545
546+
Returns:
547+
A new `SampleSet` with `spectra` and `info` DataFrames
548+
546549
Raises:
547550
`ValueError` when no argument values are provided
548551
"""
@@ -589,6 +592,48 @@ def as_ecal(self, new_offset: float, new_gain: float,
589592
new_ss.info[ecal_cols] = new_ecal
590593
return new_ss
591594

595+
def as_regions(self, rois: list) -> SampleSet:
    """Obtains a new `SampleSet` where the spectra are limited to specific
    regions of interest (ROIs).

    Notes:
        - If your samples have disparate energy calibration terms, call `as_ecal()` first
          to align channel space, then you may call this function. Otherwise, it is possible
          to end up with a ragged array of spectra, which we do not support.
        - After this call, `spectra` will have columns filled in with energy values for
          convenience. As such, in the context of the returned `SampleSet`, the energy
          calibration terms in `info` will no longer have any meaning, and any subsequent
          calls to methods like `as_ecal()` would not make sense. This method is intended
          as a last step to be performed right before analysis of whatever kind.

    Args:
        rois: a list of 2-tuples where each tuple represents (low energy, high energy);
            a channel is kept when its energy satisfies `low <= energy < high` for
            at least one ROI

    Returns:
        A new `SampleSet` with only ROIs remaining in the `spectra` DataFrame

    Raises:
        `ValueError` when no ROIs are provided, or when the samples do not all
            share the same energy calibration terms
    """
    if not rois:
        raise ValueError("At least one ROI must be provided.")

    # A single channel mask is applied to every spectrum, so every sample must
    # map channels to energies identically.
    all_ecals = self.ecal
    all_ecals_are_same = np.isclose(all_ecals, all_ecals[0]).all()
    if not all_ecals_are_same:
        msg = "Spectra have different energy calibrations; consider `as_ecal()` first."
        raise ValueError(msg)

    # Calibrations are identical, so the first sample's channel energies
    # stand in for all samples.
    energies = self.get_channel_energies(0)
    mask = _get_energy_roi_masks(rois, energies)
    new_spectra = self.spectra.to_numpy(dtype=float)[:, mask]
    # Keep one row per sample.
    new_spectra = new_spectra.reshape((self.n_samples, -1))
    mask_energies = energies[mask]

    new_ss = self[:]
    new_ss.spectra = pd.DataFrame(new_spectra, columns=mask_energies)
    # Item assignment always writes a real column; attribute assignment
    # (`info.total_counts = ...`) only does so if the column already exists.
    # The values are assigned positionally because the rebuilt `spectra` frame
    # has a fresh default index.
    new_ss.info["total_counts"] = new_ss.spectra.sum(axis=1).to_numpy()
    return new_ss
636+
592637
def check_seed_health(self, dead_time_threshold=1.0):
593638
"""Checks health of all spectra and info assuming they are seeds.
594639
@@ -1905,6 +1950,14 @@ def _get_distance_df_from_values(distance_values: np.ndarray,
19051950
return distance_df
19061951

19071952

1953+
def _get_energy_roi_masks(rois: list, energies: np.ndarray) -> np.ndarray:
1954+
masks = np.zeros(energies.shape, dtype=bool)
1955+
for (elow, ehigh) in rois:
1956+
roi_mask = (elow <= energies) & (energies < ehigh)
1957+
masks |= roi_mask
1958+
return masks
1959+
1960+
19081961
class InvalidSampleSetFileError(Exception):
19091962
"""Missing or invalid keys in a file being read in as a `SampleSet`."""
19101963
pass

tests/sampleset_tests.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ def test_get_confidences(self):
894894
return_gross=True,
895895
rng=rng
896896
)
897-
_, synthetic_gross_ss = synth.generate(fg_seeds_ss, bg_seeds_ss[0])
897+
_, synthetic_gross_ss = synth.generate(fg_seeds_ss, bg_seeds_ss[0], verbose=False)
898898
synthetic_gross_ss.drop_sources(bg_seeds_ss.sources.columns.levels[2])
899899
synthetic_gross_ss.sources = synthetic_gross_ss.sources[fg_seeds_ss.sources.columns]
900900
synthetic_gross_ss.prediction_probas = pd.DataFrame(
@@ -907,7 +907,8 @@ def test_get_confidences(self):
907907
bg_cps=synth.bg_cps
908908
)
909909

910-
_, synthetic_mixed_gross_ss = synth.generate(mixed_fg_seeds_ss, bg_seeds_ss[0])
910+
_, synthetic_mixed_gross_ss = synth.generate(mixed_fg_seeds_ss, bg_seeds_ss[0],
911+
verbose=False)
911912
synthetic_mixed_gross_ss.drop_sources(bg_seeds_ss.sources.columns.levels[2])
912913
synthetic_mixed_gross_ss.sources = synthetic_mixed_gross_ss.sources[
913914
fg_seeds_ss.sources.columns
@@ -975,6 +976,60 @@ def test_get_confidences(self):
975976
bg_cps=None
976977
)
977978

979+
def test_get_energy_roi_masks(self):
    """Checks ROI masks computed over a 2-D array of energies.

    Two ROI lists are exercised (integer bounds and fractional bounds); the
    expected masks include the lower bound of each ROI and exclude the upper.
    """
    from riid.data.sampleset import _get_energy_roi_masks

    energies = np.array([
        [0, 1, 2, 3, 4, 5, 6],
        [0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    ])
    integer_rois = [(0, 2), (4, 8)]
    fractional_rois = [(0, 0.75), (2.0, 2.5), (3.0, 5.0)]
    expected_by_case = [
        np.array([
            [True, True, False, False, True, True, True],
            [True, True, True, True, False, False, False],
        ]),
        np.array([
            [True, False, True, True, True, False, False],
            [True, True, False, False, True, False, True],
        ]),
    ]

    # Compute every mask first, then compare each against its expectation.
    actual_by_case = [
        _get_energy_roi_masks(case_rois, energies)
        for case_rois in (integer_rois, fractional_rois)
    ]
    for actual, expected in zip(actual_by_case, expected_by_case):
        self.assertTrue(np.array_equal(actual, expected))
1007+
1008+
def test_as_regions(self):
    """Checks `as_regions()` argument validation and resulting channel counts.

    Covers three things: an empty ROI list is rejected, a `SampleSet` whose
    samples have differing energy calibrations is rejected, and for uniformly
    calibrated sets the number of remaining channels equals the number of
    `True` entries in the ROI mask.
    """
    from riid.data.sampleset import _get_energy_roi_masks

    regions = [(0, 100), (500, 550), (2400, 2600)]
    default_ss = get_dummy_seeds(n_channels=1000)
    regained_ss = get_dummy_seeds(n_channels=1000).as_ecal(0, 2500, 0, 0, 0)
    shifted_ss = get_dummy_seeds(n_channels=500).as_ecal(20, 2000, 0, 0, 0)

    # No ROIs at all must be rejected.
    self.assertRaises(ValueError, default_ss.as_regions, [])

    # Combining differently calibrated sets must be rejected.
    mixed_ss = SampleSet()
    mixed_ss.concat([default_ss, regained_ss])
    self.assertRaises(ValueError, mixed_ss.as_regions, regions)

    for sample_set in (default_ss, regained_ss, shifted_ss):
        energies = sample_set.get_channel_energies(0)
        n_expected = _get_energy_roi_masks(regions, energies).sum()
        regions_ss = sample_set.as_regions(regions)
        self.assertEqual(regions_ss.n_channels, n_expected)
1032+
9781033
def _assert_row_labels(self, level, actual, expected):
9791034
for i, (a, e) in enumerate(zip(actual, expected)):
9801035
self.assertEqual(

0 commit comments

Comments
 (0)