Skip to content

Commit aa4a0bb

Browse files
committed
add mode logical_and to pruning
1 parent 5e55207 commit aa4a0bb

File tree

1 file changed

+72
-11
lines changed

1 file changed

+72
-11
lines changed

src/pyhf/workspace.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,40 @@ def _prune_and_rename(
591591
),
592592
)
593593
for modifier in sample['modifiers']
594-
if modifier['name'] not in prune_modifiers
595-
and modifier['type'] not in prune_modifier_types
594+
if (
595+
channel['name'] not in prune_channels
596+
and prune_channels != []
597+
) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this modifier for every channel
598+
or (
599+
sample['name'] not in prune_samples
600+
and prune_samples != []
601+
) # want to remove only if sample is in prune_samples or if prune_samples is empty, i.e. we want to prune this modifier for every sample
602+
or (
603+
modifier['name'] not in prune_modifiers
604+
and modifier['type'] not in prune_modifier_types
605+
)
606+
or prune_measurements
607+
!= [] # need to keep the modifier in case it is used in another measurement
596608
],
597609
}
598610
for sample in channel['samples']
599-
if sample['name'] not in prune_samples
611+
if (
612+
channel['name'] not in prune_channels
613+
and prune_channels != []
614+
) # want to remove only if channel is in prune_channels or if prune_channels is empty, i.e. we want to prune this sample for every channel
615+
or sample['name'] not in prune_samples
616+
or prune_modifiers
617+
!= [] # we only want to remove this sample if we did not specify modifiers to prune
618+
or prune_modifier_types != []
600619
],
601620
}
602621
for channel in self['channels']
603622
if channel['name'] not in prune_channels
623+
or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
624+
prune_samples != []
625+
or prune_modifiers != []
626+
or prune_modifier_types != []
627+
)
604628
],
605629
'measurements': [
606630
{
@@ -615,8 +639,14 @@ def _prune_and_rename(
615639
parameter['name'], parameter['name']
616640
),
617641
)
618-
for parameter in measurement['config']['parameters']
619-
if parameter['name'] not in prune_modifiers
642+
for parameter in measurement['config'][
643+
'parameters'
644+
] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty
645+
if (
646+
measurement['name'] not in prune_measurements
647+
and prune_measurements != []
648+
)
649+
or parameter['name'] not in prune_modifiers
620650
],
621651
'poi': rename_modifiers.get(
622652
measurement['config']['poi'],
@@ -626,6 +656,8 @@ def _prune_and_rename(
626656
}
627657
for measurement in self['measurements']
628658
if measurement['name'] not in prune_measurements
659+
or prune_modifiers
660+
!= [] # we only want to remove this measurement if we did not specify parameters to remove
629661
],
630662
'observations': [
631663
dict(
@@ -634,6 +666,11 @@ def _prune_and_rename(
634666
)
635667
for observation in self['observations']
636668
if observation['name'] not in prune_channels
669+
or ( # we only want to remove this channel if we did not specify any samples or modifiers to prune
670+
prune_samples != []
671+
or prune_modifiers != []
672+
or prune_modifier_types != []
673+
)
637674
],
638675
'version': self['version'],
639676
}
@@ -646,6 +683,7 @@ def prune(
646683
samples=None,
647684
channels=None,
648685
measurements=None,
686+
mode="logical_or",
649687
):
650688
"""
651689
Return a new, pruned workspace specification. This will not modify the original workspace.
@@ -658,6 +696,7 @@ def prune(
658696
samples: A :obj:`list` of samples to prune.
659697
channels: A :obj:`list` of channels to prune.
660698
measurements: A :obj:`list` of measurements to prune.
699+
mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`.
661700
662701
Returns:
663702
~pyhf.workspace.Workspace: A new workspace object with the specified components removed
@@ -666,19 +705,41 @@ def prune(
666705
~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace.
667706
668707
"""
708+
709+
if mode not in ["logical_and", "logical_or"]:
710+
raise ValueError(
711+
"Pruning mode must be either `logical_and` or `logical_or`."
712+
)
713+
669714
# avoid mutable defaults
670715
modifiers = [] if modifiers is None else modifiers
671716
modifier_types = [] if modifier_types is None else modifier_types
672717
samples = [] if samples is None else samples
673718
channels = [] if channels is None else channels
674719
measurements = [] if measurements is None else measurements
675720

676-
return self._prune_and_rename(
677-
prune_modifiers=modifiers,
678-
prune_modifier_types=modifier_types,
679-
prune_samples=samples,
680-
prune_channels=channels,
681-
prune_measurements=measurements,
721+
if mode == "logical_and":
722+
if samples != [] and measurements != []:
723+
raise ValueError(
724+
"Pruning of measurements and samples cannot be run with mode `logical_and`."
725+
)
726+
if modifier_types != [] and measurements != []:
727+
raise ValueError(
728+
"Pruning of measurements and modifier_types cannot be run with mode `logical_and`."
729+
)
730+
return self._prune_and_rename(
731+
prune_modifiers=modifiers,
732+
prune_modifier_types=modifier_types,
733+
prune_samples=samples,
734+
prune_channels=channels,
735+
prune_measurements=measurements,
736+
)
737+
return (
738+
self._prune_and_rename(prune_modifiers=modifiers)
739+
._prune_and_rename(prune_modifier_types=modifier_types)
740+
._prune_and_rename(prune_samples=samples)
741+
._prune_and_rename(prune_channels=channels)
742+
._prune_and_rename(prune_measurements=measurements)
682743
)
683744

684745
def rename(self, modifiers=None, samples=None, channels=None, measurements=None):

0 commit comments

Comments
 (0)