Skip to content

Commit 6977c78

Browse files
committed
Add correctness test for pad
1 parent e88e19f commit 6977c78

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

test/test_transforms_v2.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4661,7 +4661,36 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
46614661

46624662
assert_equal(actual, expected)
46634663

4664-
# TODOKP need keypoint correctness tests
4664+
def _reference_pad_keypoints(self, keypoints, *, padding):
4665+
if isinstance(padding, int):
4666+
padding = [padding]
4667+
left, top, right, bottom = padding * (4 // len(padding))
4668+
4669+
affine_matrix = np.array(
4670+
[
4671+
[1, 0, left],
4672+
[0, 1, top],
4673+
],
4674+
)
4675+
4676+
height = keypoints.canvas_size[0] + top + bottom
4677+
width = keypoints.canvas_size[1] + left + right
4678+
4679+
return reference_affine_keypoints_helper(
4680+
keypoints, affine_matrix=affine_matrix, new_canvas_size=(height, width)
4681+
)
4682+
4683+
@pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
4684+
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
4685+
@pytest.mark.parametrize("device", cpu_and_cuda())
4686+
@pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
4687+
def test_keypoints_correctness(self, padding, dtype, device, fn):
4688+
keypoints = make_keypoints(dtype=dtype, device=device)
4689+
4690+
actual = fn(keypoints, padding=padding)
4691+
expected = self._reference_pad_keypoints(keypoints, padding=padding)
4692+
4693+
assert_equal(actual, expected)
46654694

46664695

46674696
class TestCenterCrop:

0 commit comments

Comments
 (0)