Skip to content

Commit 36087d6

Browse files
committed
Expose check_inputs
1 parent b1f2064 commit 36087d6

File tree

4 files changed

+16
-18
lines changed

4 files changed

+16
-18
lines changed

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535

3636
self.padding_mode = padding_mode
3737

38-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
38+
def check_inputs(self, flat_inputs: List[Any]) -> None:
3939
if not has_any(
4040
flat_inputs,
4141
PIL.Image.Image,

torchvision/transforms/v2/_geometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An
366366
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
367367
return self._call_kernel(F.five_crop, inpt, self.size)
368368

369-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
369+
def check_inputs(self, flat_inputs: List[Any]) -> None:
370370
if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
371371
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
372372

@@ -408,7 +408,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An
408408
)
409409
return super()._call_kernel(functional, inpt, *args, **kwargs)
410410

411-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
411+
def check_inputs(self, flat_inputs: List[Any]) -> None:
412412
if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
413413
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
414414

@@ -1132,7 +1132,7 @@ def __init__(
11321132
self.options = sampler_options
11331133
self.trials = trials
11341134

1135-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
1135+
def check_inputs(self, flat_inputs: List[Any]) -> None:
11361136
if not (
11371137
has_all(flat_inputs, tv_tensors.BoundingBoxes)
11381138
and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)

torchvision/transforms/v2/_misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso
9999
self.transformation_matrix = transformation_matrix
100100
self.mean_vector = mean_vector
101101

102-
def _check_inputs(self, sample: Any) -> Any:
102+
def check_inputs(self, sample: Any) -> Any:
103103
if has_any(sample, PIL.Image.Image):
104104
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
105105

@@ -157,7 +157,7 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool =
157157
self.std = list(std)
158158
self.inplace = inplace
159159

160-
def _check_inputs(self, sample: Any) -> Any:
160+
def check_inputs(self, sample: Any) -> Any:
161161
if has_any(sample, PIL.Image.Image):
162162
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
163163

torchvision/transforms/v2/_transform.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,15 @@ def __init__(self) -> None:
2424
super().__init__()
2525
_log_api_usage_once(self)
2626

27-
def _check_inputs(self, flat_inputs: List[Any]) -> None:
27+
def check_inputs(self, flat_inputs: List[Any]) -> None:
2828
pass
2929

30-
# This exists for BC. When v2 was introduced, this method was private. Now
31-
# it's publicly exposed as `make_params()`. It cannot be exposed as
32-
# `get_params()` because there is already a `get_params()` methods for v2
33-
# transforms: it's the v1's `get_params()` that we have to keep in order to
34-
# guarantee 100% BC with v1. (It's defined in __init_subclass__ below).
35-
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
36-
return self.make_params(flat_inputs)
37-
30+
# When v2 was introduced, this method was private and called
31+
# `_get_params()`. Now it's publicly exposed as `make_params()`. It cannot
32+
# be exposed as `get_params()` because there is already a `get_params()`
33+
# methods for v2 transforms: it's the v1's `get_params()` that we have to
34+
# keep in order to guarantee 100% BC with v1. (It's defined in
35+
# __init_subclass__ below).
3836
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
3937
return dict()
4038

@@ -48,7 +46,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
4846
def forward(self, *inputs: Any) -> Any:
4947
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
5048

51-
self._check_inputs(flat_inputs)
49+
self.check_inputs(flat_inputs)
5250

5351
needs_transform_list = self._needs_transform_list(flat_inputs)
5452
params = self.make_params(
@@ -161,12 +159,12 @@ def __init__(self, p: float = 0.5) -> None:
161159
def forward(self, *inputs: Any) -> Any:
162160
# We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
163161
# early afterwards in case the random check triggers. The same result could be achieved by calling
164-
# `super().forward()` after the random check, but that would call `self._check_inputs` twice.
162+
# `super().forward()` after the random check, but that would call `self.check_inputs` twice.
165163

166164
inputs = inputs if len(inputs) > 1 else inputs[0]
167165
flat_inputs, spec = tree_flatten(inputs)
168166

169-
self._check_inputs(flat_inputs)
167+
self.check_inputs(flat_inputs)
170168

171169
if torch.rand(1) >= self.p:
172170
return inputs

0 commit comments

Comments
 (0)