@@ -24,17 +24,15 @@ def __init__(self) -> None:
2424 super ().__init__ ()
2525 _log_api_usage_once (self )
2626
27- def _check_inputs (self , flat_inputs : List [Any ]) -> None :
27+ def check_inputs (self , flat_inputs : List [Any ]) -> None :
2828 pass
2929
30- # This exists for BC. When v2 was introduced, this method was private. Now
31- # it's publicly exposed as `make_params()`. It cannot be exposed as
32- # `get_params()` because there is already a `get_params()` methods for v2
33- # transforms: it's the v1's `get_params()` that we have to keep in order to
34- # guarantee 100% BC with v1. (It's defined in __init_subclass__ below).
35- def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
36- return self .make_params (flat_inputs )
37-
30+ # When v2 was introduced, this method was private and called
31+ # `_get_params()`. Now it's publicly exposed as `make_params()`. It cannot
32+ # be exposed as `get_params()` because there is already a `get_params()`
33+ # methods for v2 transforms: it's the v1's `get_params()` that we have to
34+ # keep in order to guarantee 100% BC with v1. (It's defined in
35+ # __init_subclass__ below).
3836 def make_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
3937 return dict ()
4038
@@ -48,7 +46,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
4846 def forward (self , * inputs : Any ) -> Any :
4947 flat_inputs , spec = tree_flatten (inputs if len (inputs ) > 1 else inputs [0 ])
5048
51- self ._check_inputs (flat_inputs )
49+ self .check_inputs (flat_inputs )
5250
5351 needs_transform_list = self ._needs_transform_list (flat_inputs )
5452 params = self .make_params (
@@ -161,12 +159,12 @@ def __init__(self, p: float = 0.5) -> None:
161159 def forward (self , * inputs : Any ) -> Any :
162160 # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
163161 # early afterwards in case the random check triggers. The same result could be achieved by calling
164- # `super().forward()` after the random check, but that would call `self._check_inputs ` twice.
162+ # `super().forward()` after the random check, but that would call `self.check_inputs ` twice.
165163
166164 inputs = inputs if len (inputs ) > 1 else inputs [0 ]
167165 flat_inputs , spec = tree_flatten (inputs )
168166
169- self ._check_inputs (flat_inputs )
167+ self .check_inputs (flat_inputs )
170168
171169 if torch .rand (1 ) >= self .p :
172170 return inputs
0 commit comments