Skip to content

Commit 5979fef

Browse files
committed
Add support for custom labels in plot_spectra().
1 parent 5426826 commit 5979fef

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

riid/visualize.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
289289
figsize: tuple = (12.8, 7.2), xscale: str = "linear", yscale: str = "log",
290290
xlim: tuple = (None, None), ylim: tuple = (None, None),
291291
ylabel: str = None, title: str = None, legend_loc: str = None,
292-
target_level="Isotope") -> tuple:
292+
target_level="Isotope", labels=None) -> tuple:
293293
"""Plot spectra in a `SampleSet`.
294294
295295
Args:
@@ -304,6 +304,7 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
304304
title: plot title
305305
legend_loc: location in which to place the legend
306306
target_level: `SampleSet.sources` column level to use in legend
307+
labels: custom list of labels
307308
308309
Returns:
309310
Tuple (Figure, Axes) of matplotlib objects
@@ -314,10 +315,11 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
314315
- `ValueError` when `limit` is not None and less than 1
315316
"""
316317
fig, ax = plt.subplots(figsize=figsize)
317-
if ss.sources.empty:
318-
labels = list(range(ss.n_samples))
319-
else:
320-
labels = ss.get_labels(target_level=target_level)
318+
if not labels:
319+
if ss.sources.empty:
320+
labels = list(range(ss.n_samples))
321+
else:
322+
labels = ss.get_labels(target_level=target_level)
321323

322324
for i in range(ss.n_samples):
323325
label = labels[i]

0 commit comments

Comments
 (0)