@@ -289,7 +289,7 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
289289 figsize : tuple = (12.8 , 7.2 ), xscale : str = "linear" , yscale : str = "log" ,
290290 xlim : tuple = (None , None ), ylim : tuple = (None , None ),
291291 ylabel : str = None , title : str = None , legend_loc : str = None ,
292- target_level = "Isotope" ) -> tuple :
292+ target_level = "Isotope" , labels = None ) -> tuple :
293293 """Plot spectra in a `SampleSet`.
294294
295295 Args:
@@ -304,6 +304,7 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
304304 title: plot title
305305 legend_loc: location in which to place the legend
306306 target_level: `SampleSet.sources` column level to use in legend
307+ labels: custom list of labels
307308
308309 Returns:
309310 Tuple (Figure, Axes) of matplotlib objects
@@ -314,10 +315,11 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
314315 - `ValueError` when `limit` is not None and less than 1
315316 """
316317 fig , ax = plt .subplots (figsize = figsize )
317- if ss .sources .empty :
318- labels = list (range (ss .n_samples ))
319- else :
320- labels = ss .get_labels (target_level = target_level )
318+ if not labels :
319+ if ss .sources .empty :
320+ labels = list (range (ss .n_samples ))
321+ else :
322+ labels = ss .get_labels (target_level = target_level )
321323
322324 for i in range (ss .n_samples ):
323325 label = labels [i ]
0 commit comments