@@ -591,16 +591,40 @@ def _prune_and_rename(
591591 ),
592592 )
593593 for modifier in sample ['modifiers' ]
594- if modifier ['name' ] not in prune_modifiers
595- and modifier ['type' ] not in prune_modifier_types
594+ if (
595+ channel ['name' ] not in prune_channels
596+ and prune_channels != []
597+ ) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this modifier for every channel
598+ or (
599+ sample ['name' ] not in prune_samples
600+ and prune_samples != []
601+ ) # want to remove only if sample is in prune_samples or if prune_samples is empty, i.e. we want to prune this modifier for every sample
602+ or (
603+ modifier ['name' ] not in prune_modifiers
604+ and modifier ['type' ] not in prune_modifier_types
605+ )
606+ or prune_measurements
607+ != [] # need to keep the modifier in case it is used in another measurement
596608 ],
597609 }
598610 for sample in channel ['samples' ]
599- if sample ['name' ] not in prune_samples
611+ if (
612+ channel ['name' ] not in prune_channels
613+ and prune_channels != []
614+ ) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this sample for every channel
615+ or sample ['name' ] not in prune_samples
616+ or prune_modifiers
617+ != [] # we only want to remove this sample if we did not specify modifiers to prune
618+ or prune_modifier_types != []
600619 ],
601620 }
602621 for channel in self ['channels' ]
603622 if channel ['name' ] not in prune_channels
623+ or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
624+ prune_samples != []
625+ or prune_modifiers != []
626+ or prune_modifier_types != []
627+ )
604628 ],
605629 'measurements' : [
606630 {
@@ -615,8 +639,14 @@ def _prune_and_rename(
615639 parameter ['name' ], parameter ['name' ]
616640 ),
617641 )
618- for parameter in measurement ['config' ]['parameters' ]
619- if parameter ['name' ] not in prune_modifiers
642+ for parameter in measurement ['config' ][
643+ 'parameters'
644+ ] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty
645+ if (
646+ measurement ['name' ] not in prune_measurements
647+ and prune_measurements != []
648+ )
649+ or parameter ['name' ] not in prune_modifiers
620650 ],
621651 'poi' : rename_modifiers .get (
622652 measurement ['config' ]['poi' ],
@@ -626,6 +656,8 @@ def _prune_and_rename(
626656 }
627657 for measurement in self ['measurements' ]
628658 if measurement ['name' ] not in prune_measurements
659+ or prune_modifiers
660+ != [] # we only want to remove this measurement if we did not specify parameters to remove
629661 ],
630662 'observations' : [
631663 dict (
@@ -634,6 +666,11 @@ def _prune_and_rename(
634666 )
635667 for observation in self ['observations' ]
636668 if observation ['name' ] not in prune_channels
669+ or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
670+ prune_samples != []
671+ or prune_modifiers != []
672+ or prune_modifier_types != []
673+ )
637674 ],
638675 'version' : self ['version' ],
639676 }
@@ -646,6 +683,7 @@ def prune(
646683 samples = None ,
647684 channels = None ,
648685 measurements = None ,
686+ mode = "logical_or" ,
649687 ):
650688 """
651689 Return a new, pruned workspace specification. This will not modify the original workspace.
@@ -658,6 +696,7 @@ def prune(
658696 samples: A :obj:`list` of samples to prune.
659697 channels: A :obj:`list` of channels to prune.
660698 measurements: A :obj:`list` of measurements to prune.
699+ mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`.
661700
662701 Returns:
663702 ~pyhf.workspace.Workspace: A new workspace object with the specified components removed
@@ -666,19 +705,41 @@ def prune(
666705 ~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace.
667706
668707 """
708+
709+ if mode not in ["logical_and" , "logical_or" ]:
710+ raise ValueError (
711+ "Pruning mode must be either `logical_and` or `logical_or`."
712+ )
713+
669714 # avoid mutable defaults
670715 modifiers = [] if modifiers is None else modifiers
671716 modifier_types = [] if modifier_types is None else modifier_types
672717 samples = [] if samples is None else samples
673718 channels = [] if channels is None else channels
674719 measurements = [] if measurements is None else measurements
675720
676- return self ._prune_and_rename (
677- prune_modifiers = modifiers ,
678- prune_modifier_types = modifier_types ,
679- prune_samples = samples ,
680- prune_channels = channels ,
681- prune_measurements = measurements ,
721+ if mode == "logical_and" :
722+ if samples != [] and measurements != []:
723+ raise ValueError (
724+ "Pruning of measurements and samples cannot be run with mode `logical_and`."
725+ )
726+ if modifier_types != [] and measurements != []:
727+ raise ValueError (
728+ "Pruning of measurements and modifier_types cannot be run with mode `logical_and`."
729+ )
730+ return self ._prune_and_rename (
731+ prune_modifiers = modifiers ,
732+ prune_modifier_types = modifier_types ,
733+ prune_samples = samples ,
734+ prune_channels = channels ,
735+ prune_measurements = measurements ,
736+ )
737+ return (
738+ self ._prune_and_rename (prune_modifiers = modifiers )
739+ ._prune_and_rename (prune_modifier_types = modifier_types )
740+ ._prune_and_rename (prune_samples = samples )
741+ ._prune_and_rename (prune_channels = channels )
742+ ._prune_and_rename (prune_measurements = measurements )
682743 )
683744
684745 def rename (self , modifiers = None , samples = None , channels = None , measurements = None ):
0 commit comments