Skip to content

Commit a2f8f8e

Browse files
authored
port tests for F.erase and transforms.RandomErasing (#7902)
1 parent a06df0d commit a2f8f8e

5 files changed

+141
-103
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -540,53 +540,6 @@ def test__get_params(self):
540540
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()
541541

542542

543-
class TestRandomErasing:
544-
def test_assertions(self):
545-
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
546-
transforms.RandomErasing(value={})
547-
548-
with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
549-
transforms.RandomErasing(value="abc")
550-
551-
with pytest.raises(TypeError, match="Scale should be a sequence"):
552-
transforms.RandomErasing(scale=123)
553-
554-
with pytest.raises(TypeError, match="Ratio should be a sequence"):
555-
transforms.RandomErasing(ratio=123)
556-
557-
with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
558-
transforms.RandomErasing(scale=[-1, 2])
559-
560-
image = make_image((24, 32))
561-
562-
transform = transforms.RandomErasing(value=[1, 2, 3, 4])
563-
564-
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
565-
transform._get_params([image])
566-
567-
@pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
568-
def test__get_params(self, value):
569-
image = make_image((24, 32))
570-
num_channels, height, width = F.get_dimensions(image)
571-
572-
transform = transforms.RandomErasing(value=value)
573-
params = transform._get_params([image])
574-
575-
v = params["v"]
576-
h, w = params["h"], params["w"]
577-
i, j = params["i"], params["j"]
578-
assert isinstance(v, torch.Tensor)
579-
if value == "random":
580-
assert v.shape == (num_channels, h, w)
581-
elif isinstance(value, (int, float)):
582-
assert v.shape == (1, 1, 1)
583-
elif isinstance(value, (list, tuple)):
584-
assert v.shape == (num_channels, 1, 1)
585-
586-
assert 0 <= i <= height - h
587-
assert 0 <= j <= width - w
588-
589-
590543
class TestTransform:
591544
@pytest.mark.parametrize(
592545
"inpt_type",

test/test_transforms_v2_consistency.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,6 @@ def __init__(
276276
],
277277
closeness_kwargs=dict(rtol=0, atol=21),
278278
),
279-
ConsistencyConfig(
280-
v2_transforms.RandomErasing,
281-
legacy_transforms.RandomErasing,
282-
[
283-
ArgsKwargs(p=0),
284-
ArgsKwargs(p=1),
285-
ArgsKwargs(p=1, scale=(0.3, 0.7)),
286-
ArgsKwargs(p=1, ratio=(0.5, 1.5)),
287-
ArgsKwargs(p=1, value=1),
288-
ArgsKwargs(p=1, value=(1, 2, 3)),
289-
ArgsKwargs(p=1, value="random"),
290-
],
291-
supports_pil=False,
292-
),
293279
ConsistencyConfig(
294280
v2_transforms.ColorJitter,
295281
legacy_transforms.ColorJitter,
@@ -550,7 +536,6 @@ def test_call_consistency(config, args_kwargs):
550536
)
551537
for transform_cls, get_params_args_kwargs in [
552538
(v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
553-
(v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
554539
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
555540
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
556541
(v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),

test/test_transforms_v2_refactored.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2718,3 +2718,144 @@ def test_errors(self):
27182718

27192719
with pytest.raises(ValueError, match="Padding mode should be either"):
27202720
transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
2721+
2722+
2723+
class TestErase:
2724+
INPUT_SIZE = (17, 11)
2725+
FUNCTIONAL_KWARGS = dict(
2726+
zip("ijhwv", [2, 2, 10, 8, torch.tensor(0.0, dtype=torch.float32, device="cpu").reshape(-1, 1, 1)])
2727+
)
2728+
2729+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2730+
@pytest.mark.parametrize("device", cpu_and_cuda())
2731+
def test_kernel_image(self, dtype, device):
2732+
check_kernel(F.erase_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **self.FUNCTIONAL_KWARGS)
2733+
2734+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2735+
@pytest.mark.parametrize("device", cpu_and_cuda())
2736+
def test_kernel_image_inplace(self, dtype, device):
2737+
input = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
2738+
input_version = input._version
2739+
2740+
output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS)
2741+
assert output_out_of_place.data_ptr() != input.data_ptr()
2742+
assert output_out_of_place is not input
2743+
2744+
output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True)
2745+
assert output_inplace.data_ptr() == input.data_ptr()
2746+
assert output_inplace._version > input_version
2747+
assert output_inplace is input
2748+
2749+
assert_equal(output_inplace, output_out_of_place)
2750+
2751+
def test_kernel_video(self):
2752+
check_kernel(F.erase_video, make_video(self.INPUT_SIZE), **self.FUNCTIONAL_KWARGS)
2753+
2754+
@pytest.mark.parametrize(
2755+
"make_input",
2756+
[make_image_tensor, make_image_pil, make_image, make_video],
2757+
)
2758+
def test_functional(self, make_input):
2759+
check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS)
2760+
2761+
@pytest.mark.parametrize(
2762+
("kernel", "input_type"),
2763+
[
2764+
(F.erase_image, torch.Tensor),
2765+
(F._erase_image_pil, PIL.Image.Image),
2766+
(F.erase_image, tv_tensors.Image),
2767+
(F.erase_video, tv_tensors.Video),
2768+
],
2769+
)
2770+
def test_functional_signature(self, kernel, input_type):
2771+
check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type)
2772+
2773+
@pytest.mark.parametrize(
2774+
"make_input",
2775+
[make_image_tensor, make_image_pil, make_image, make_video],
2776+
)
2777+
@pytest.mark.parametrize("device", cpu_and_cuda())
2778+
def test_transform(self, make_input, device):
2779+
check_transform(transforms.RandomErasing(p=1), make_input(device=device))
2780+
2781+
def _reference_erase_image(self, image, *, i, j, h, w, v):
2782+
mask = torch.zeros_like(image, dtype=torch.bool)
2783+
mask[..., i : i + h, j : j + w] = True
2784+
2785+
# The broadcasting and type casting logic is handled automagically in the kernel through indexing
2786+
value = torch.broadcast_to(v, (*image.shape[:-2], h, w)).to(image)
2787+
2788+
erased_image = torch.empty_like(image)
2789+
erased_image[mask] = value.flatten()
2790+
erased_image[~mask] = image[~mask]
2791+
2792+
return erased_image
2793+
2794+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2795+
@pytest.mark.parametrize("device", cpu_and_cuda())
2796+
def test_functional_image_correctness(self, dtype, device):
2797+
image = make_image(dtype=dtype, device=device)
2798+
2799+
actual = F.erase(image, **self.FUNCTIONAL_KWARGS)
2800+
expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)
2801+
2802+
assert_equal(actual, expected)
2803+
2804+
@param_value_parametrization(
2805+
scale=[(0.1, 0.2), [0.0, 1.0]],
2806+
ratio=[(0.3, 0.7), [0.1, 5.0]],
2807+
value=[0, 0.5, (0, 1, 0), [-0.2, 0.0, 1.3], "random"],
2808+
)
2809+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
2810+
@pytest.mark.parametrize("device", cpu_and_cuda())
2811+
@pytest.mark.parametrize("seed", list(range(5)))
2812+
def test_transform_image_correctness(self, param, value, dtype, device, seed):
2813+
transform = transforms.RandomErasing(**{param: value}, p=1)
2814+
2815+
image = make_image(dtype=dtype, device=device)
2816+
2817+
with freeze_rng_state():
2818+
torch.manual_seed(seed)
2819+
# This emulates the random apply check that happens before _get_params is called
2820+
torch.rand(1)
2821+
params = transform._get_params([image])
2822+
2823+
torch.manual_seed(seed)
2824+
actual = transform(image)
2825+
2826+
expected = self._reference_erase_image(image, **params)
2827+
2828+
assert_equal(actual, expected)
2829+
2830+
def test_transform_errors(self):
2831+
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
2832+
transforms.RandomErasing(value={})
2833+
2834+
with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
2835+
transforms.RandomErasing(value="abc")
2836+
2837+
with pytest.raises(TypeError, match="Scale should be a sequence"):
2838+
transforms.RandomErasing(scale=123)
2839+
2840+
with pytest.raises(TypeError, match="Ratio should be a sequence"):
2841+
transforms.RandomErasing(ratio=123)
2842+
2843+
with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
2844+
transforms.RandomErasing(scale=[-1, 2])
2845+
2846+
transform = transforms.RandomErasing(value=[1, 2, 3, 4])
2847+
2848+
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
2849+
transform._get_params([make_image()])
2850+
2851+
@pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_mask])
2852+
def test_transform_passthrough(self, make_input):
2853+
transform = transforms.RandomErasing(p=1)
2854+
2855+
input = make_input(self.INPUT_SIZE)
2856+
2857+
with pytest.warns(UserWarning, match="currently passing through inputs of type"):
2858+
# RandomErasing requires an image or video to be present
2859+
_, output = transform(make_image(self.INPUT_SIZE), input)
2860+
2861+
assert output is input

test/transforms_v2_dispatcher_infos.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
269269
},
270270
pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
271271
),
272-
DispatcherInfo(
273-
F.erase,
274-
kernels={
275-
tv_tensors.Image: F.erase_image,
276-
tv_tensors.Video: F.erase_video,
277-
},
278-
pil_kernel_info=PILKernelInfo(F._erase_image_pil),
279-
test_marks=[
280-
skip_dispatch_tv_tensor,
281-
],
282-
),
283272
DispatcherInfo(
284273
F.adjust_contrast,
285274
kernels={

test/transforms_v2_kernel_infos.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,36 +1123,6 @@ def sample_inputs_adjust_sharpness_video():
11231123
)
11241124

11251125

1126-
def sample_inputs_erase_image_tensor():
1127-
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
1128-
# FIXME: make the parameters more diverse
1129-
h, w = 6, 7
1130-
v = torch.rand(image_loader.num_channels, h, w)
1131-
yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)
1132-
1133-
1134-
def sample_inputs_erase_video():
1135-
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
1136-
# FIXME: make the parameters more diverse
1137-
h, w = 6, 7
1138-
v = torch.rand(video_loader.num_channels, h, w)
1139-
yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)
1140-
1141-
1142-
KERNEL_INFOS.extend(
1143-
[
1144-
KernelInfo(
1145-
F.erase_image,
1146-
kernel_name="erase_image_tensor",
1147-
sample_inputs_fn=sample_inputs_erase_image_tensor,
1148-
),
1149-
KernelInfo(
1150-
F.erase_video,
1151-
sample_inputs_fn=sample_inputs_erase_video,
1152-
),
1153-
]
1154-
)
1155-
11561126
_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
11571127

11581128

0 commit comments

Comments
 (0)