Skip to content

Commit d1409e6

Browse files
authored
Allow injecting validators for component verification
Differential Revision: D70936713 Pull Request resolved: #1016
1 parent 5efd2b0 commit d1409e6

File tree

4 files changed

+101
-54
lines changed

4 files changed

+101
-54
lines changed

torchx/specs/file_linter.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,18 @@ class TorchFunctionVisitor(ast.NodeVisitor):
244244
245245
"""
246246

247-
def __init__(self, component_function_name: str) -> None:
248-
self.validators = [
249-
TorchxFunctionArgsValidator(),
250-
TorchxReturnValidator(),
251-
]
247+
def __init__(
248+
self,
249+
component_function_name: str,
250+
validators: Optional[List[TorchxFunctionValidator]],
251+
) -> None:
252+
if validators is None:
253+
self.validators: List[TorchxFunctionValidator] = [
254+
TorchxFunctionArgsValidator(),
255+
TorchxReturnValidator(),
256+
]
257+
else:
258+
self.validators = validators
252259
self.linter_errors: List[LinterMessage] = []
253260
self.component_function_name = component_function_name
254261
self.visited_function = False
@@ -264,7 +271,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
264271
self.linter_errors += validator.validate(node)
265272

266273

267-
def validate(path: str, component_function: str) -> List[LinterMessage]:
274+
def validate(
275+
path: str,
276+
component_function: str,
277+
validators: Optional[List[TorchxFunctionValidator]],
278+
) -> List[LinterMessage]:
268279
"""
269280
Validates the function to make sure it complies the component standard.
270281
@@ -293,7 +304,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]:
293304
severity="error",
294305
)
295306
return [linter_message]
296-
visitor = TorchFunctionVisitor(component_function)
307+
visitor = TorchFunctionVisitor(component_function, validators)
297308
visitor.visit(module)
298309
linter_errors = visitor.linter_errors
299310
if not visitor.visited_function:

torchx/specs/finder.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Callable, Dict, Generator, List, Optional, Union
2020

2121
from torchx.specs import AppDef
22-
from torchx.specs.file_linter import get_fn_docstring, validate
22+
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
2323
from torchx.util import entrypoints
2424
from torchx.util.io import read_conf_file
2525
from torchx.util.types import none_throws
@@ -59,7 +59,9 @@ class _Component:
5959

6060
class ComponentsFinder(abc.ABC):
6161
@abc.abstractmethod
62-
def find(self) -> List[_Component]:
62+
def find(
63+
self, validators: Optional[List[TorchxFunctionValidator]]
64+
) -> List[_Component]:
6365
"""
6466
Retrieves a set of components. A component is defined as a python
6567
function that conforms to ``torchx.specs.file_linter`` linter.
@@ -203,10 +205,12 @@ def _iter_modules_recursive(
203205
else:
204206
yield self._try_import(module_info.name)
205207

206-
def find(self) -> List[_Component]:
208+
def find(
209+
self, validators: Optional[List[TorchxFunctionValidator]]
210+
) -> List[_Component]:
207211
components = []
208212
for m in self._iter_modules_recursive(self.base_module):
209-
components += self._get_components_from_module(m)
213+
components += self._get_components_from_module(m, validators)
210214
return components
211215

212216
def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
@@ -221,7 +225,9 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
221225
else:
222226
return module
223227

224-
def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
228+
def _get_components_from_module(
229+
self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
230+
) -> List[_Component]:
225231
functions = getmembers(module, isfunction)
226232
component_defs = []
227233

@@ -230,7 +236,7 @@ def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
230236
module_path = os.path.abspath(module_path)
231237
rel_module_name = module_relname(module, relative_to=self.base_module)
232238
for function_name, function in functions:
233-
linter_errors = validate(module_path, function_name)
239+
linter_errors = validate(module_path, function_name, validators)
234240
component_desc, _ = get_fn_docstring(function)
235241

236242
# remove empty string to deal with group=""
@@ -255,13 +261,20 @@ def __init__(self, filepath: str, function_name: str) -> None:
255261
self._filepath = filepath
256262
self._function_name = function_name
257263

258-
def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
259-
linter_errors = validate(path, function_name)
264+
def _get_validation_errors(
265+
self,
266+
path: str,
267+
function_name: str,
268+
validators: Optional[List[TorchxFunctionValidator]],
269+
) -> List[str]:
270+
linter_errors = validate(path, function_name, validators)
260271
return [linter_error.description for linter_error in linter_errors]
261272

262-
def find(self) -> List[_Component]:
273+
def find(
274+
self, validators: Optional[List[TorchxFunctionValidator]]
275+
) -> List[_Component]:
263276
validation_errors = self._get_validation_errors(
264-
self._filepath, self._function_name
277+
self._filepath, self._function_name, validators
265278
)
266279

267280
file_source = read_conf_file(self._filepath)
@@ -284,7 +297,9 @@ def find(self) -> List[_Component]:
284297
]
285298

286299

287-
def _load_custom_components() -> List[_Component]:
300+
def _load_custom_components(
301+
validators: Optional[List[TorchxFunctionValidator]],
302+
) -> List[_Component]:
288303
component_modules = {
289304
name: load_fn()
290305
for name, load_fn in
@@ -303,11 +318,13 @@ def _load_custom_components() -> List[_Component]:
303318
# _0 = torchx.components.dist
304319
# _1 = torchx.components.utils
305320
group = "" if group.startswith("_") else group
306-
components += ModuleComponentsFinder(module, group).find()
321+
components += ModuleComponentsFinder(module, group).find(validators)
307322
return components
308323

309324

310-
def _load_components() -> Dict[str, _Component]:
325+
def _load_components(
326+
validators: Optional[List[TorchxFunctionValidator]],
327+
) -> Dict[str, _Component]:
311328
"""
312329
Loads either the custom component defs from the entrypoint ``[torchx.components]``
313330
or the default builtins from ``torchx.components`` module.
@@ -318,37 +335,43 @@ def _load_components() -> Dict[str, _Component]:
318335
319336
"""
320337

321-
components = _load_custom_components()
338+
components = _load_custom_components(validators)
322339
if not components:
323-
components = ModuleComponentsFinder("torchx.components", "").find()
340+
components = ModuleComponentsFinder("torchx.components", "").find(validators)
324341
return {c.name: c for c in components}
325342

326343

327344
_components: Optional[Dict[str, _Component]] = None
328345

329346

330-
def _find_components() -> Dict[str, _Component]:
347+
def _find_components(
348+
validators: Optional[List[TorchxFunctionValidator]],
349+
) -> Dict[str, _Component]:
331350
global _components
332351
if not _components:
333-
_components = _load_components()
352+
_components = _load_components(validators)
334353
return none_throws(_components)
335354

336355

337356
def _is_custom_component(component_name: str) -> bool:
338357
return ":" in component_name
339358

340359

341-
def _find_custom_components(name: str) -> Dict[str, _Component]:
360+
def _find_custom_components(
361+
name: str, validators: Optional[List[TorchxFunctionValidator]]
362+
) -> Dict[str, _Component]:
342363
if ":" not in name:
343364
raise ValueError(
344365
f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
345366
)
346367
filepath, component_name = name.split(":")
347-
components = CustomComponentsFinder(filepath, component_name).find()
368+
components = CustomComponentsFinder(filepath, component_name).find(validators)
348369
return {component.name: component for component in components}
349370

350371

351-
def get_components() -> Dict[str, _Component]:
372+
def get_components(
373+
validators: Optional[List[TorchxFunctionValidator]] = None,
374+
) -> Dict[str, _Component]:
352375
"""
353376
Returns all custom components registered via ``[torchx.components]`` entrypoints
354377
OR builtin components that ship with TorchX (but not both).
@@ -395,23 +418,25 @@ def get_components() -> Dict[str, _Component]:
395418
"""
396419

397420
valid_components: Dict[str, _Component] = {}
398-
for component_name, component in _find_components().items():
421+
for component_name, component in _find_components(validators).items():
399422
if len(component.validation_errors) == 0:
400423
valid_components[component_name] = component
401424
return valid_components
402425

403426

404-
def get_component(name: str) -> _Component:
427+
def get_component(
428+
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
429+
) -> _Component:
405430
"""
406431
Retrieves components by the provided name.
407432
408433
Returns:
409434
Component or None if no component with ``name`` exists
410435
"""
411436
if _is_custom_component(name):
412-
components = _find_custom_components(name)
437+
components = _find_custom_components(name, validators)
413438
else:
414-
components = _find_components()
439+
components = _find_components(validators)
415440
if name not in components:
416441
raise ComponentNotFoundException(
417442
f"Component `{name}` not found. Please make sure it is one of the "
@@ -428,7 +453,9 @@ def get_component(name: str) -> _Component:
428453
return component
429454

430455

431-
def get_builtin_source(name: str) -> str:
456+
def get_builtin_source(
457+
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
458+
) -> str:
432459
"""
433460
Returns a string of the the builtin component's function source code
434461
with all the import statements. Intended to be used to make a copy
@@ -446,7 +473,7 @@ def get_builtin_source(name: str) -> str:
446473
are optimized and formatting adheres to your organization's standards.
447474
"""
448475

449-
component = get_component(name)
476+
component = get_component(name, validators)
450477
fn = component.fn
451478
fn_name = component.name.split(".")[-1]
452479

torchx/specs/test/file_linter_test.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,21 @@ def test_syntax_error(self) -> None:
121121
content = "!!foo====bar"
122122
with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock:
123123
read_conf_file_mock.return_value = content
124-
errors = validate(self._path, "unknown_function")
124+
errors = validate(self._path, "unknown_function", None)
125125
self.assertEqual(1, len(errors))
126126
self.assertEqual("invalid syntax", errors[0].description)
127127

128128
def test_validate_varargs_kwargs_fn(self) -> None:
129129
linter_errors = validate(
130-
self._path,
131-
"_test_invalid_fn_with_varags_and_kwargs",
130+
self._path, "_test_invalid_fn_with_varags_and_kwargs", None
132131
)
133132
self.assertEqual(1, len(linter_errors))
134133
self.assertTrue(
135134
"Arg args missing type annotation", linter_errors[0].description
136135
)
137136

138137
def test_validate_no_return(self) -> None:
139-
linter_errors = validate(self._path, "_test_fn_no_return")
138+
linter_errors = validate(self._path, "_test_fn_no_return", None)
140139
self.assertEqual(1, len(linter_errors))
141140
expected_desc = (
142141
"Function: _test_fn_no_return missing return annotation or "
@@ -145,20 +144,32 @@ def test_validate_no_return(self) -> None:
145144
self.assertEqual(expected_desc, linter_errors[0].description)
146145

147146
def test_validate_incorrect_return(self) -> None:
148-
linter_errors = validate(self._path, "_test_fn_return_int")
147+
linter_errors = validate(self._path, "_test_fn_return_int", None)
149148
self.assertEqual(1, len(linter_errors))
150149
expected_desc = (
151150
"Function: _test_fn_return_int has incorrect return annotation, "
152151
"supported annotation: AppDef"
153152
)
154153
self.assertEqual(expected_desc, linter_errors[0].description)
155154

155+
def test_no_validators_has_no_validation(self) -> None:
156+
linter_errors = validate(self._path, "_test_fn_return_int", [])
157+
self.assertEqual(0, len(linter_errors))
158+
159+
linter_errors = validate(self._path, "_test_fn_no_return", [])
160+
self.assertEqual(0, len(linter_errors))
161+
162+
linter_errors = validate(
163+
self._path, "_test_invalid_fn_with_varags_and_kwargs", []
164+
)
165+
self.assertEqual(0, len(linter_errors))
166+
156167
def test_validate_empty_fn(self) -> None:
157-
linter_errors = validate(self._path, "_test_empty_fn")
168+
linter_errors = validate(self._path, "_test_empty_fn", None)
158169
self.assertEqual(0, len(linter_errors))
159170

160171
def test_validate_args_no_type_defs(self) -> None:
161-
linter_errors = validate(self._path, "_test_args_no_type_defs")
172+
linter_errors = validate(self._path, "_test_args_no_type_defs", None)
162173
print(linter_errors)
163174
self.assertEqual(2, len(linter_errors))
164175
self.assertEqual(
@@ -169,10 +180,7 @@ def test_validate_args_no_type_defs(self) -> None:
169180
)
170181

171182
def test_validate_args_no_type_defs_complex(self) -> None:
172-
linter_errors = validate(
173-
self._path,
174-
"_test_args_dict_list_complex_types",
175-
)
183+
linter_errors = validate(self._path, "_test_args_dict_list_complex_types", None)
176184
self.assertEqual(5, len(linter_errors))
177185
self.assertEqual(
178186
"Arg arg0 missing type annotation", linter_errors[0].description
@@ -210,7 +218,7 @@ def test_validate_docstring_no_docs(self) -> None:
210218
self.assertEqual(" ", param_desc["arg0"])
211219

212220
def test_validate_unknown_function(self) -> None:
213-
linter_errors = validate(self._path, "unknown_function")
221+
linter_errors = validate(self._path, "unknown_function", None)
214222
self.assertEqual(1, len(linter_errors))
215223
self.assertEqual(
216224
"Function unknown_function not found", linter_errors[0].description

0 commit comments

Comments
 (0)