
Commit 0fdd655

Expose _transform method as transform
1 parent 36087d6 commit 0fdd655

File tree

14 files changed: +63 -63 lines changed

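Across every file below the change is mechanical: the private per-input hook `_transform(inpt, params)` on `torchvision.transforms.v2.Transform` and its subclasses becomes the public `transform(inpt, params)`. With the hook public, a custom v2 transform can be written by overriding `make_params` and `transform`. A minimal sketch against the renamed hooks (requires a torchvision build that includes this commit; the class name, sigma range, and noise logic are illustrative, not part of this commit):

    import torch
    from torchvision.transforms import v2

    class RandomNoise(v2.Transform):
        # Hypothetical example transform; only the hook names
        # (make_params / transform) come from this commit.
        def __init__(self, sigma_range=(0.01, 0.1)):
            super().__init__()
            self.sigma_range = sigma_range

        def make_params(self, flat_inputs):
            # Sample one sigma per call so every input in the sample
            # receives the same noise level.
            return dict(sigma=torch.empty(1).uniform_(*self.sigma_range).item())

        def transform(self, inpt, params):  # previously named _transform
            return inpt + params["sigma"] * torch.randn_like(inpt)

    noisy = RandomNoise()(torch.rand(3, 8, 8))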

references/segmentation/v2_extras.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def make_params(self, sample):
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
 
-    def _transform(self, inpt, params):
+    def transform(self, inpt, params):
         if not params["needs_padding"]:
             return inpt
 

test/test_transforms_v2.py

Lines changed: 2 additions & 2 deletions
@@ -1843,7 +1843,7 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand):
 
 class TestContainerTransforms:
     class BuiltinTransform(transforms.Transform):
-        def _transform(self, inpt, params):
+        def transform(self, inpt, params):
             return inpt
 
     class PackedInputTransform(nn.Module):
@@ -5544,7 +5544,7 @@ def split_on_pure_tensor(to_split):
         return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others
 
     class CopyCloneTransform(transforms.Transform):
-        def _transform(self, inpt, params):
+        def transform(self, inpt, params):
             return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()
 
         @staticmethod

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             needs_pad=needs_pad,
         )
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_crop"]:
             inpt = self._call_kernel(
                 F.crop,

torchvision/prototype/transforms/_misc.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
         )
         self.dims = dims
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
@@ -61,7 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
         )
         self.dims = dims
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def __init__(self, num_categories: int = -1):
         super().__init__()
         self.num_categories = num_categories
 
-    def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
+    def transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
         num_categories = self.num_categories
         if num_categories == -1 and inpt.categories is not None:
             num_categories = len(inpt.categories)

torchvision/transforms/v2/_augment.py

Lines changed: 5 additions & 5 deletions
@@ -134,7 +134,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(i=i, j=j, h=h, w=w, v=v)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["v"] is not None:
             inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
 
@@ -190,7 +190,7 @@ def forward(self, *inputs):
         # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
         needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
         flat_outputs = [
-            self._transform(inpt, params) if needs_transform else inpt
+            self.transform(inpt, params) if needs_transform else inpt
             for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
         ]
 
@@ -246,7 +246,7 @@ class MixUp(_BaseMixUpCutMix):
     def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         lam = params["lam"]
 
         if inpt is params["labels"]:
@@ -314,7 +314,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(box=box, lam_adjusted=lam_adjusted)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
         elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
@@ -365,5 +365,5 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
         return dict(quality=quality)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
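Note on the `forward` hunk above: `_BaseMixUpCutMix` routes the label tensor through the now-public `transform` hook so labels are mixed alongside the images. A usage sketch (batch shape and `num_classes` are illustrative):

    import torch
    from torchvision.transforms import v2

    mixup = v2.MixUp(num_classes=10)
    images = torch.rand(4, 3, 32, 32)       # batched images
    labels = torch.randint(0, 10, (4,))     # integer class labels
    images, labels = mixup(images, labels)  # labels come back soft, shape (4, 10)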

torchvision/transforms/v2/_color.py

Lines changed: 12 additions & 12 deletions
@@ -25,7 +25,7 @@ def __init__(self, num_output_channels: int = 1):
         super().__init__()
         self.num_output_channels = num_output_channels
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
 
 
@@ -50,7 +50,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         num_input_channels, *_ = query_chw(flat_inputs)
         return dict(num_input_channels=num_input_channels)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
 
 
@@ -64,7 +64,7 @@ class RGB(Transform):
     def __init__(self):
         super().__init__()
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.grayscale_to_rgb, inpt)
 
 
@@ -152,7 +152,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         output = inpt
         brightness_factor = params["brightness_factor"]
         contrast_factor = params["contrast_factor"]
@@ -177,7 +177,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         num_channels, *_ = query_chw(flat_inputs)
         return dict(permutation=torch.randperm(num_channels))
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.permute_channels, inpt, params["permutation"])
 
 
@@ -235,7 +235,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
         return params
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
             inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
@@ -264,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomEqualize
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.equalize, inpt)
 
 
@@ -281,7 +281,7 @@ class RandomInvert(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomInvert
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.invert, inpt)
 
 
@@ -304,7 +304,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.bits = bits
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.posterize, inpt, bits=self.bits)
 
 
@@ -332,7 +332,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.threshold = threshold
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
 
 
@@ -349,7 +349,7 @@ class RandomAutocontrast(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomAutocontrast
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.autocontrast, inpt)
 
 
@@ -372,5 +372,5 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.sharpness_factor = sharpness_factor
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
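Each color transform above keeps the same split: `make_params` samples the random factors, and the renamed `transform` applies them through a functional kernel. A usage sketch of one public entry point (parameter values illustrative):

    import torch
    from torchvision.transforms import v2

    jitter = v2.ColorJitter(brightness=0.5, hue=0.1)
    out = jitter(torch.rand(3, 8, 8))  # make_params samples factors; transform applies them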

torchvision/transforms/v2/_deprecated.py

Lines changed: 1 addition & 1 deletion
@@ -46,5 +46,5 @@ def __init__(self) -> None:
         )
         super().__init__()
 
-    def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
+    def transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
         return _F.to_tensor(inpt)

torchvision/transforms/v2/_geometry.py

Lines changed: 18 additions & 18 deletions
@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomHorizontalFlip
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.horizontal_flip, inpt)
 
 
@@ -62,7 +62,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomVerticalFlip
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.vertical_flip, inpt)
 
 
@@ -156,7 +156,7 @@ def __init__(
         self.max_size = max_size
         self.antialias = antialias
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(
             F.resize,
             inpt,
@@ -189,7 +189,7 @@ def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.center_crop, inpt, output_size=self.size)
 
 
@@ -306,7 +306,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(top=i, left=j, height=h, width=w)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(
             F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
         )
@@ -363,7 +363,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An
         )
         return super()._call_kernel(functional, inpt, *args, **kwargs)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.five_crop, inpt, self.size)
 
     def check_inputs(self, flat_inputs: List[Any]) -> None:
@@ -412,7 +412,7 @@ def check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
 
 
@@ -483,7 +483,7 @@ def __init__(
         self._fill = _setup_fill_arg(fill)
         self.padding_mode = padding_mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
 
@@ -551,7 +551,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(padding=padding)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(F.pad, inpt, **params, fill=fill)
 
@@ -622,7 +622,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         return dict(angle=angle)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(
             F.rotate,
@@ -743,7 +743,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         shear = (shear_x, shear_y)
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(
             F.affine,
@@ -897,7 +897,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             padding=padding,
         )
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
             inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
@@ -982,7 +982,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
         return dict(coefficients=perspective_coeffs)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(
             F.perspective,
@@ -1074,7 +1074,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
         return self._call_kernel(
             F.elastic,
@@ -1194,7 +1194,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
         if len(params) < 1:
             return inpt
@@ -1272,7 +1272,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(size=(new_height, new_width))
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(
             F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
@@ -1340,7 +1340,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         return dict(size=(new_height, new_width))
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(
             F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
@@ -1410,7 +1410,7 @@ def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         size = int(torch.randint(self.min_size, self.max_size, ()))
         return dict(size=[size])
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(
             F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
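The geometry transforms above all dispatch the same way: the base class flattens the sample and calls the now-public `transform` once per supported leaf, which keeps images and bounding boxes in sync. A small sketch (sizes and coordinates illustrative):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    flip = v2.RandomHorizontalFlip(p=1.0)
    img = tv_tensors.Image(torch.rand(3, 4, 4))
    boxes = tv_tensors.BoundingBoxes([[0, 0, 2, 2]], format="XYXY", canvas_size=(4, 4))
    img, boxes = flip(img, boxes)  # transform() runs once for each leaf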

torchvision/transforms/v2/_meta.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
         super().__init__()
         self.format = format
 
-    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
+    def transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.convert_bounding_box_format(inpt, new_format=self.format)  # type: ignore[return-value, arg-type]
 
 
@@ -32,5 +32,5 @@ class ClampBoundingBoxes(Transform):
 
     _transformed_types = (tv_tensors.BoundingBoxes,)
 
-    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
+    def transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
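Both meta transforms act only on `BoundingBoxes` inputs (the hunk above shows `_transformed_types` for `ClampBoundingBoxes`), so everything else passes through untouched. A usage sketch (box values illustrative):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    boxes = tv_tensors.BoundingBoxes([[10, 10, 20, 20]], format="XYXY", canvas_size=(32, 32))
    boxes = v2.ConvertBoundingBoxFormat("CXCYWH")(boxes)  # transform() converts the format
    boxes = v2.ClampBoundingBoxes()(boxes)                # clamps coordinates to the canvas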
