@@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation):
55
55
of examples to compute attributions and cannot be performed on a single example.
56
56
57
57
By default, each scalar value within
58
- each input tensor is taken as a feature and shuffled independently, *unless*
59
- attribute() is called with enable_cross_tensor_attribution=True. Passing
60
- a feature mask, allows grouping features to be shuffled together .
58
+ each input tensor is taken as a feature and shuffled independently. Passing
59
+ a feature mask allows grouping features to be shuffled together (including
60
+ features defined across different input tensors) .
61
61
Each input scalar in the group will be given the same attribution value
62
62
equal to the change in target as a result of shuffling the entire feature
63
63
group.
@@ -92,12 +92,6 @@ def __init__(
92
92
"""
93
93
FeatureAblation .__init__ (self , forward_func = forward_func )
94
94
self .perm_func = perm_func
95
- # Minimum number of elements needed in each input tensor, when
96
- # `enable_cross_tensor_attribution` is False, otherwise the
97
- # attribution for the tensor will be skipped. Set to 1 to throw if any
98
- # input tensors only have one example
99
- self ._min_examples_per_batch = 2
100
- # Similar to above, when `enable_cross_tensor_attribution` is True.
101
95
# Considering the case when we permute multiple input tensors at once
102
96
# through `feature_mask`, we disregard the feature group if the 0th
103
97
# dim of *any* input tensor in the group is less than
@@ -115,7 +109,6 @@ def attribute( # type: ignore
115
109
feature_mask : Union [None , TensorOrTupleOfTensorsGeneric ] = None ,
116
110
perturbations_per_eval : int = 1 ,
117
111
show_progress : bool = False ,
118
- enable_cross_tensor_attribution : bool = True ,
119
112
** kwargs : Any ,
120
113
) -> TensorOrTupleOfTensorsGeneric :
121
114
r"""
@@ -187,18 +180,12 @@ def attribute( # type: ignore
187
180
input tensor. Each tensor should contain integers in
188
181
the range 0 to num_features - 1, and indices
189
182
corresponding to the same feature should have the
190
- same value. Note that features within each input
191
- tensor are ablated independently (not across
192
- tensors), unless enable_cross_tensor_attribution is
193
- True.
194
-
183
+ same value.
195
184
The first dimension of each mask must be 1, as we require
196
185
to have the same group of features for each input sample.
197
186
198
187
If None, then a feature mask is constructed which assigns
199
- each scalar within a tensor as a separate feature, which
200
- is permuted independently, unless
201
- enable_cross_tensor_attribution is True.
188
+ each scalar within a tensor as a separate feature.
202
189
Default: None
203
190
perturbations_per_eval (int, optional): Allows permutations
204
191
of multiple features to be processed simultaneously
@@ -217,10 +204,6 @@ def attribute( # type: ignore
217
204
(e.g. time estimation). Otherwise, it will fallback to
218
205
a simple output of progress.
219
206
Default: False
220
- enable_cross_tensor_attribution (bool, optional): If True, then
221
- features can be grouped across input tensors depending on
222
- the values in the feature mask.
223
- Default: False
224
207
**kwargs (Any, optional): Any additional arguments used by child
225
208
classes of :class:`.FeatureAblation` (such as
226
209
:class:`.Occlusion`) to construct ablations. These
@@ -292,7 +275,6 @@ def attribute( # type: ignore
292
275
feature_mask = feature_mask ,
293
276
perturbations_per_eval = perturbations_per_eval ,
294
277
show_progress = show_progress ,
295
- enable_cross_tensor_attribution = enable_cross_tensor_attribution ,
296
278
** kwargs ,
297
279
)
298
280
@@ -304,7 +286,6 @@ def attribute_future(
304
286
feature_mask : Union [None , TensorOrTupleOfTensorsGeneric ] = None ,
305
287
perturbations_per_eval : int = 1 ,
306
288
show_progress : bool = False ,
307
- enable_cross_tensor_attribution : bool = True ,
308
289
** kwargs : Any ,
309
290
) -> Future [TensorOrTupleOfTensorsGeneric ]:
310
291
"""
@@ -321,54 +302,9 @@ def attribute_future(
321
302
feature_mask = feature_mask ,
322
303
perturbations_per_eval = perturbations_per_eval ,
323
304
show_progress = show_progress ,
324
- enable_cross_tensor_attribution = enable_cross_tensor_attribution ,
325
305
** kwargs ,
326
306
)
327
307
328
- def _construct_ablated_input (
329
- self ,
330
- expanded_input : Tensor ,
331
- input_mask : Union [None , Tensor , Tuple [Tensor , ...]],
332
- baseline : Union [None , float , Tensor ],
333
- start_feature : int ,
334
- end_feature : int ,
335
- ** kwargs : Any ,
336
- ) -> Tuple [Tensor , Tensor ]:
337
- r"""
338
- This function permutes the features of `expanded_input` with a given
339
- feature mask and feature range. Permutation occurs via calling
340
- `self.perm_func` across each batch within `expanded_input`. As with
341
- `FeatureAblation._construct_ablated_input`:
342
- - `expanded_input.shape = (num_features, num_examples, ...)`
343
- - `num_features = end_feature - start_feature` (i.e. start and end is a
344
- half-closed interval)
345
- - `input_mask` is a tensor of the same shape as one input, which
346
- describes the locations of each feature via their "index"
347
-
348
- Since `baselines` is set to None for `FeatureAblation.attribute, this
349
- will be the zero tensor, however, it is not used.
350
- """
351
- assert (
352
- input_mask is not None
353
- and not isinstance (input_mask , tuple )
354
- and input_mask .shape [0 ] == 1
355
- ), (
356
- "input_mask.shape[0] != 1: pass in one mask in order to permute"
357
- "the same features for each input"
358
- )
359
- current_mask = torch .stack (
360
- [input_mask == j for j in range (start_feature , end_feature )], dim = 0
361
- ).bool ()
362
- current_mask = current_mask .to (expanded_input .device )
363
-
364
- output = torch .stack (
365
- [
366
- self .perm_func (x , mask .squeeze (0 ))
367
- for x , mask in zip (expanded_input , current_mask )
368
- ]
369
- )
370
- return output , current_mask
371
-
372
308
def _construct_ablated_input_across_tensors (
373
309
self ,
374
310
inputs : Tuple [Tensor , ...],
0 commit comments