Skip to content

Commit b1f2064

Browse files
committed
Expose _get_params() method as make_params()
1 parent 6279faa commit b1f2064

File tree

9 files changed

+63
-55
lines changed

9 files changed

+63
-55
lines changed

references/segmentation/v2_extras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, size, fill=0):
1010
self.size = size
1111
self.fill = v2._utils._setup_fill_arg(fill)
1212

13-
def _get_params(self, sample):
13+
def make_params(self, sample):
1414
_, height, width = v2._utils.query_chw(sample)
1515
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
1616
needs_padding = any(padding)

test/test_prototype_transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test__copy_paste(self, label_type):
159159

160160

161161
class TestFixedSizeCrop:
162-
def test__get_params(self, mocker):
162+
def test_make_params(self, mocker):
163163
crop_size = (7, 7)
164164
batch_shape = (10,)
165165
canvas_size = (11, 5)
@@ -170,7 +170,7 @@ def test__get_params(self, mocker):
170170
make_image(size=canvas_size, color_space="RGB"),
171171
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
172172
]
173-
params = transform._get_params(flat_inputs)
173+
params = transform.make_params(flat_inputs)
174174

175175
assert params["needs_crop"]
176176
assert params["height"] <= crop_size[0]
@@ -191,7 +191,7 @@ def test__transform_culling(self, mocker):
191191

192192
is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
193193
mocker.patch(
194-
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
194+
"torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
195195
return_value=dict(
196196
needs_crop=True,
197197
top=0,
@@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker):
229229
canvas_size = (10, 10)
230230

231231
mocker.patch(
232-
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
232+
"torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params",
233233
return_value=dict(
234234
needs_crop=True,
235235
top=0,

test/test_transforms_v2.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed):
13551355
transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center)
13561356

13571357
torch.manual_seed(seed)
1358-
params = transform._get_params([bounding_boxes])
1358+
params = transform.make_params([bounding_boxes])
13591359

13601360
torch.manual_seed(seed)
13611361
actual = transform(bounding_boxes)
@@ -1369,14 +1369,14 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed):
13691369
@pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"])
13701370
@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"])
13711371
@pytest.mark.parametrize("seed", list(range(10)))
1372-
def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed):
1372+
def test_transform_make_params_bounds(self, degrees, translate, scale, shear, seed):
13731373
image = make_image()
13741374
height, width = F.get_size(image)
13751375

13761376
transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
13771377

13781378
torch.manual_seed(seed)
1379-
params = transform._get_params([image])
1379+
params = transform.make_params([image])
13801380

13811381
if isinstance(degrees, (int, float)):
13821382
assert -degrees <= params["angle"] <= degrees
@@ -1783,7 +1783,7 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
17831783
transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)
17841784

17851785
torch.manual_seed(seed)
1786-
params = transform._get_params([bounding_boxes])
1786+
params = transform.make_params([bounding_boxes])
17871787

17881788
torch.manual_seed(seed)
17891789
actual = transform(bounding_boxes)
@@ -1795,11 +1795,11 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
17951795

17961796
@pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
17971797
@pytest.mark.parametrize("seed", list(range(10)))
1798-
def test_transform_get_params_bounds(self, degrees, seed):
1798+
def test_transform_make_params_bounds(self, degrees, seed):
17991799
transform = transforms.RandomRotation(degrees=degrees)
18001800

18011801
torch.manual_seed(seed)
1802-
params = transform._get_params([])
1802+
params = transform.make_params([])
18031803

18041804
if isinstance(degrees, (int, float)):
18051805
assert -degrees <= params["angle"] <= degrees
@@ -2996,7 +2996,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype,
29962996

29972997
with freeze_rng_state():
29982998
torch.manual_seed(seed)
2999-
params = transform._get_params([bounding_boxes])
2999+
params = transform.make_params([bounding_boxes])
30003000
assert not params.pop("needs_pad")
30013001
del params["padding"]
30023002
assert params.pop("needs_crop")
@@ -3129,9 +3129,9 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed):
31293129

31303130
with freeze_rng_state():
31313131
torch.manual_seed(seed)
3132-
# This emulates the random apply check that happens before _get_params is called
3132+
# This emulates the random apply check that happens before make_params is called
31333133
torch.rand(1)
3134-
params = transform._get_params([image])
3134+
params = transform.make_params([image])
31353135

31363136
torch.manual_seed(seed)
31373137
actual = transform(image)
@@ -3159,7 +3159,7 @@ def test_transform_errors(self):
31593159
transform = transforms.RandomErasing(value=[1, 2, 3, 4])
31603160

31613161
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
3162-
transform._get_params([make_image()])
3162+
transform.make_params([make_image()])
31633163

31643164

31653165
class TestGaussianBlur:
@@ -3244,9 +3244,9 @@ def test_assertions(self):
32443244
transforms.GaussianBlur(3, sigma={})
32453245

32463246
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
3247-
def test__get_params(self, sigma):
3247+
def test_make_params(self, sigma):
32483248
transform = transforms.GaussianBlur(3, sigma=sigma)
3249-
params = transform._get_params([])
3249+
params = transform.make_params([])
32503250

32513251
if isinstance(sigma, float):
32523252
assert params["sigma"][0] == params["sigma"][1] == sigma
@@ -5251,7 +5251,7 @@ def test_transform_params_correctness(self, side_range, make_input, device):
52515251
input = make_input()
52525252
height, width = F.get_size(input)
52535253

5254-
params = transform._get_params([input])
5254+
params = transform.make_params([input])
52555255
assert "padding" in params
52565256

52575257
padding = params["padding"]
@@ -5305,13 +5305,13 @@ def test_transform(self, make_input, device):
53055305

53065306
check_transform(transforms.ScaleJitter(self.TARGET_SIZE), make_input(self.INPUT_SIZE, device=device))
53075307

5308-
def test__get_params(self):
5308+
def test_make_params(self):
53095309
input_size = self.INPUT_SIZE
53105310
target_size = self.TARGET_SIZE
53115311
scale_range = (0.5, 1.5)
53125312

53135313
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
5314-
params = transform._get_params([make_image(input_size)])
5314+
params = transform.make_params([make_image(input_size)])
53155315

53165316
assert "size" in params
53175317
size = params["size"]
@@ -5580,7 +5580,7 @@ def was_applied(output, inpt):
55805580
class TestRandomIoUCrop:
55815581
@pytest.mark.parametrize("device", cpu_and_cuda())
55825582
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
5583-
def test__get_params(self, device, options):
5583+
def test_make_params(self, device, options):
55845584
orig_h, orig_w = size = (24, 32)
55855585
image = make_image(size)
55865586
bboxes = tv_tensors.BoundingBoxes(
@@ -5596,7 +5596,7 @@ def test__get_params(self, device, options):
55965596
n_samples = 5
55975597
for _ in range(n_samples):
55985598

5599-
params = transform._get_params(sample)
5599+
params = transform.make_params(sample)
56005600

56015601
if options == [2.0]:
56025602
assert len(params) == 0
@@ -5622,8 +5622,8 @@ def test__transform_empty_params(self, mocker):
56225622
bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
56235623
label = torch.tensor([1])
56245624
sample = [image, bboxes, label]
5625-
# Let's mock transform._get_params to control the output:
5626-
transform._get_params = mocker.MagicMock(return_value={})
5625+
# Let's mock transform.make_params to control the output:
5626+
transform.make_params = mocker.MagicMock(return_value={})
56275627
output = transform(sample)
56285628
torch.testing.assert_close(output, sample)
56295629

@@ -5648,7 +5648,7 @@ def test__transform(self, mocker):
56485648
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
56495649

56505650
params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
5651-
transform._get_params = mocker.MagicMock(return_value=params)
5651+
transform.make_params = mocker.MagicMock(return_value=params)
56525652
output = transform(sample)
56535653

56545654
# check number of bboxes vs number of labels:
@@ -5662,13 +5662,13 @@ def test__transform(self, mocker):
56625662

56635663
class TestRandomShortestSize:
56645664
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
5665-
def test__get_params(self, min_size, max_size):
5665+
def test_make_params(self, min_size, max_size):
56665666
canvas_size = (3, 10)
56675667

56685668
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True)
56695669

56705670
sample = make_image(canvas_size)
5671-
params = transform._get_params([sample])
5671+
params = transform.make_params([sample])
56725672

56735673
assert "size" in params
56745674
size = params["size"]
@@ -5685,14 +5685,14 @@ def test__get_params(self, min_size, max_size):
56855685

56865686

56875687
class TestRandomResize:
5688-
def test__get_params(self):
5688+
def test_make_params(self):
56895689
min_size = 3
56905690
max_size = 6
56915691

56925692
transform = transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
56935693

56945694
for _ in range(10):
5695-
params = transform._get_params([])
5695+
params = transform.make_params([])
56965696

56975697
assert isinstance(params["size"], list) and len(params["size"]) == 1
56985698
size = params["size"][0]
@@ -6148,12 +6148,12 @@ def test_transform_image_correctness(self, quality, color_space, seed):
61486148

61496149
@pytest.mark.parametrize("quality", [5, (10, 20)])
61506150
@pytest.mark.parametrize("seed", list(range(10)))
6151-
def test_transform_get_params_bounds(self, quality, seed):
6151+
def test_transform_make_params_bounds(self, quality, seed):
61526152
transform = transforms.JPEG(quality=quality)
61536153

61546154
with freeze_rng_state():
61556155
torch.manual_seed(seed)
6156-
params = transform._get_params([])
6156+
params = transform.make_params([])
61576157

61586158
if isinstance(quality, int):
61596159
assert params["quality"] == quality

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
5353
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
5454
)
5555

56-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
56+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
5757
height, width = query_size(flat_inputs)
5858
new_height = min(height, self.crop_height)
5959
new_width = min(width, self.crop_width)

torchvision/transforms/v2/_augment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An
9696
)
9797
return super()._call_kernel(functional, inpt, *args, **kwargs)
9898

99-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
99+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
100100
img_c, img_h, img_w = query_chw(flat_inputs)
101101

102102
if self.value is not None and not (len(self.value) in (1, img_c)):
@@ -181,7 +181,7 @@ def forward(self, *inputs):
181181
params = {
182182
"labels": labels,
183183
"batch_size": labels.shape[0],
184-
**self._get_params(
184+
**self.make_params(
185185
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
186186
),
187187
}
@@ -243,7 +243,7 @@ class MixUp(_BaseMixUpCutMix):
243243
It can also be a callable that takes the same input as the transform, and returns the labels.
244244
"""
245245

246-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
246+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
247247
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
248248

249249
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -292,7 +292,7 @@ class CutMix(_BaseMixUpCutMix):
292292
It can also be a callable that takes the same input as the transform, and returns the labels.
293293
"""
294294

295-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
295+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
296296
lam = float(self._dist.sample(())) # type: ignore[arg-type]
297297

298298
H, W = query_size(flat_inputs)
@@ -361,7 +361,7 @@ def __init__(self, quality: Union[int, Sequence[int]]):
361361

362362
self.quality = quality
363363

364-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
364+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
365365
quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
366366
return dict(quality=quality)
367367

torchvision/transforms/v2/_color.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class RandomGrayscale(_RandomApplyTransform):
4646
def __init__(self, p: float = 0.1) -> None:
4747
super().__init__(p=p)
4848

49-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
49+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
5050
num_input_channels, *_ = query_chw(flat_inputs)
5151
return dict(num_input_channels=num_input_channels)
5252

@@ -142,7 +142,7 @@ def _check_input(
142142
def _generate_value(left: float, right: float) -> float:
143143
return torch.empty(1).uniform_(left, right).item()
144144

145-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
145+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
146146
fn_idx = torch.randperm(4)
147147

148148
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
@@ -173,7 +173,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
173173
class RandomChannelPermutation(Transform):
174174
"""Randomly permute the channels of an image or video"""
175175

176-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
176+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
177177
num_channels, *_ = query_chw(flat_inputs)
178178
return dict(permutation=torch.randperm(num_channels))
179179

@@ -220,7 +220,7 @@ def __init__(
220220
self.saturation = saturation
221221
self.p = p
222222

223-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
223+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
224224
num_channels, *_ = query_chw(flat_inputs)
225225
params: Dict[str, Any] = {
226226
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None

0 commit comments

Comments
 (0)