
Commit 6260fb1

sarahtranfb authored and facebook-github-bot committed
Clean up enable_cross_tensor_attribution from FeaturePermutation (meta-pytorch#1649)
Summary: Defaulted to true everywhere in the D81948483 stack.
Differential Revision: D83107514
1 parent 17b8d8a commit 6260fb1
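
For context, here is a minimal sketch of a call site after this cleanup: the enable_cross_tensor_attribution argument no longer exists, and grouping features across input tensors through feature_mask is simply the default behavior. The toy forward_func, input shapes, and mask values below are invented for illustration and are not part of this commit.

import torch
from captum.attr import FeaturePermutation

# Toy model: a per-example score computed from two input tensors.
def forward_func(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    return x1.sum(dim=1) + x2.sum(dim=1)

x1 = torch.randn(8, 3)  # batch of 8 examples, 3 features
x2 = torch.randn(8, 2)  # batch of 8 examples, 2 features

# Feature ids may span tensors: id 0 groups x1[:, 0] with x2[:, 0], so both
# columns are shuffled together. Each mask's first dim must be 1, since the
# same grouping applies to every example in the batch.
mask1 = torch.tensor([[0, 1, 2]])
mask2 = torch.tensor([[0, 3]])

fp = FeaturePermutation(forward_func)
attr1, attr2 = fp.attribute((x1, x2), feature_mask=(mask1, mask2))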

File tree

2 files changed: +131 −239 lines

captum/attr/_core/feature_permutation.py

Lines changed: 5 additions & 69 deletions
@@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation):
     of examples to compute attributions and cannot be performed on a single example.

     By default, each scalar value within
-    each input tensor is taken as a feature and shuffled independently, *unless*
-    attribute() is called with enable_cross_tensor_attribution=True. Passing
-    a feature mask, allows grouping features to be shuffled together.
+    each input tensor is taken as a feature and shuffled independently. Passing
+    a feature mask allows grouping features to be shuffled together (including
+    features defined across different input tensors).
     Each input scalar in the group will be given the same attribution value
     equal to the change in target as a result of shuffling the entire feature
     group.
@@ -92,12 +92,6 @@ def __init__(
         """
         FeatureAblation.__init__(self, forward_func=forward_func)
         self.perm_func = perm_func
-        # Minimum number of elements needed in each input tensor, when
-        # `enable_cross_tensor_attribution` is False, otherwise the
-        # attribution for the tensor will be skipped. Set to 1 to throw if any
-        # input tensors only have one example
-        self._min_examples_per_batch = 2
-        # Similar to above, when `enable_cross_tensor_attribution` is True.
         # Considering the case when we permute multiple input tensors at once
         # through `feature_mask`, we disregard the feature group if the 0th
         # dim of *any* input tensor in the group is less than
@@ -115,7 +109,6 @@ def attribute( # type: ignore
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
-        enable_cross_tensor_attribution: bool = True,
         **kwargs: Any,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
@@ -187,18 +180,12 @@ def attribute( # type: ignore
                 input tensor. Each tensor should contain integers in
                 the range 0 to num_features - 1, and indices
                 corresponding to the same feature should have the
-                same value. Note that features within each input
-                tensor are ablated independently (not across
-                tensors), unless enable_cross_tensor_attribution is
-                True.
-
+                same value.
                 The first dimension of each mask must be 1, as we require
                 to have the same group of features for each input sample.

                 If None, then a feature mask is constructed which assigns
-                each scalar within a tensor as a separate feature, which
-                is permuted independently, unless
-                enable_cross_tensor_attribution is True.
+                each scalar within a tensor as a separate feature.
                 Default: None
             perturbations_per_eval (int, optional): Allows permutations
                 of multiple features to be processed simultaneously
@@ -217,10 +204,6 @@ def attribute( # type: ignore
                 (e.g. time estimation). Otherwise, it will fallback to
                 a simple output of progress.
                 Default: False
-            enable_cross_tensor_attribution (bool, optional): If True, then
-                features can be grouped across input tensors depending on
-                the values in the feature mask.
-                Default: False
             **kwargs (Any, optional): Any additional arguments used by child
                 classes of :class:`.FeatureAblation` (such as
                 :class:`.Occlusion`) to construct ablations. These
@@ -292,7 +275,6 @@ def attribute( # type: ignore
             feature_mask=feature_mask,
             perturbations_per_eval=perturbations_per_eval,
             show_progress=show_progress,
-            enable_cross_tensor_attribution=enable_cross_tensor_attribution,
             **kwargs,
         )

@@ -304,7 +286,6 @@ def attribute_future(
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
-        enable_cross_tensor_attribution: bool = True,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         """
@@ -321,54 +302,9 @@ def attribute_future(
             feature_mask=feature_mask,
             perturbations_per_eval=perturbations_per_eval,
             show_progress=show_progress,
-            enable_cross_tensor_attribution=enable_cross_tensor_attribution,
             **kwargs,
         )

-    def _construct_ablated_input(
-        self,
-        expanded_input: Tensor,
-        input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
-        baseline: Union[None, float, Tensor],
-        start_feature: int,
-        end_feature: int,
-        **kwargs: Any,
-    ) -> Tuple[Tensor, Tensor]:
-        r"""
-        This function permutes the features of `expanded_input` with a given
-        feature mask and feature range. Permutation occurs via calling
-        `self.perm_func` across each batch within `expanded_input`. As with
-        `FeatureAblation._construct_ablated_input`:
-        - `expanded_input.shape = (num_features, num_examples, ...)`
-        - `num_features = end_feature - start_feature` (i.e. start and end is a
-          half-closed interval)
-        - `input_mask` is a tensor of the same shape as one input, which
-          describes the locations of each feature via their "index"
-
-        Since `baselines` is set to None for `FeatureAblation.attribute, this
-        will be the zero tensor, however, it is not used.
-        """
-        assert (
-            input_mask is not None
-            and not isinstance(input_mask, tuple)
-            and input_mask.shape[0] == 1
-        ), (
-            "input_mask.shape[0] != 1: pass in one mask in order to permute"
-            "the same features for each input"
-        )
-        current_mask = torch.stack(
-            [input_mask == j for j in range(start_feature, end_feature)], dim=0
-        ).bool()
-        current_mask = current_mask.to(expanded_input.device)
-
-        output = torch.stack(
-            [
-                self.perm_func(x, mask.squeeze(0))
-                for x, mask in zip(expanded_input, current_mask)
-            ]
-        )
-        return output, current_mask
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
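
The removed _construct_ablated_input was the old per-tensor permutation path; the surviving _construct_ablated_input_across_tensors now covers all cases. Both delegate the actual shuffling to self.perm_func(x, mask). As a rough sketch of a function with that interface, here is a batch-wise permutation under one feature mask; this is an illustration only, not necessarily the library's exact default perm_func, and the name permute_feature is hypothetical.

import torch
from torch import Tensor

def permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
    # x: a batch of examples; feature_mask: boolean tensor marking one
    # feature group within the shape of a single example.
    n = x.size(0)
    perm = torch.randperm(n)
    # Re-draw until at least one example actually moves (requires n > 1).
    while n > 1 and bool((perm == torch.arange(n)).all()):
        perm = torch.randperm(n)
    mask = feature_mask.to(dtype=x.dtype)
    # Masked positions take values from a shuffled batch; the rest stay put.
    return x[perm] * mask + x * (1 - mask)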
