19
19
from typing import Callable , Dict , Generator , List , Optional , Union
20
20
21
21
from torchx .specs import AppDef
22
- from torchx .specs .file_linter import get_fn_docstring , validate
22
+ from torchx .specs .file_linter import get_fn_docstring , TorchxFunctionValidator , validate
23
23
from torchx .util import entrypoints
24
24
from torchx .util .io import read_conf_file
25
25
from torchx .util .types import none_throws
@@ -59,7 +59,9 @@ class _Component:
59
59
60
60
class ComponentsFinder (abc .ABC ):
61
61
@abc .abstractmethod
62
- def find (self ) -> List [_Component ]:
62
+ def find (
63
+ self , validators : Optional [List [TorchxFunctionValidator ]]
64
+ ) -> List [_Component ]:
63
65
"""
64
66
Retrieves a set of components. A component is defined as a python
65
67
function that conforms to ``torchx.specs.file_linter`` linter.
@@ -203,10 +205,12 @@ def _iter_modules_recursive(
203
205
else :
204
206
yield self ._try_import (module_info .name )
205
207
206
- def find (self ) -> List [_Component ]:
208
+ def find (
209
+ self , validators : Optional [List [TorchxFunctionValidator ]]
210
+ ) -> List [_Component ]:
207
211
components = []
208
212
for m in self ._iter_modules_recursive (self .base_module ):
209
- components += self ._get_components_from_module (m )
213
+ components += self ._get_components_from_module (m , validators )
210
214
return components
211
215
212
216
def _try_import (self , module : Union [str , ModuleType ]) -> ModuleType :
@@ -221,7 +225,9 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
221
225
else :
222
226
return module
223
227
224
- def _get_components_from_module (self , module : ModuleType ) -> List [_Component ]:
228
+ def _get_components_from_module (
229
+ self , module : ModuleType , validators : Optional [List [TorchxFunctionValidator ]]
230
+ ) -> List [_Component ]:
225
231
functions = getmembers (module , isfunction )
226
232
component_defs = []
227
233
@@ -230,7 +236,7 @@ def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
230
236
module_path = os .path .abspath (module_path )
231
237
rel_module_name = module_relname (module , relative_to = self .base_module )
232
238
for function_name , function in functions :
233
- linter_errors = validate (module_path , function_name )
239
+ linter_errors = validate (module_path , function_name , validators )
234
240
component_desc , _ = get_fn_docstring (function )
235
241
236
242
# remove empty string to deal with group=""
@@ -255,13 +261,20 @@ def __init__(self, filepath: str, function_name: str) -> None:
255
261
self ._filepath = filepath
256
262
self ._function_name = function_name
257
263
258
- def _get_validation_errors (self , path : str , function_name : str ) -> List [str ]:
259
- linter_errors = validate (path , function_name )
264
+ def _get_validation_errors (
265
+ self ,
266
+ path : str ,
267
+ function_name : str ,
268
+ validators : Optional [List [TorchxFunctionValidator ]],
269
+ ) -> List [str ]:
270
+ linter_errors = validate (path , function_name , validators )
260
271
return [linter_error .description for linter_error in linter_errors ]
261
272
262
- def find (self ) -> List [_Component ]:
273
+ def find (
274
+ self , validators : Optional [List [TorchxFunctionValidator ]]
275
+ ) -> List [_Component ]:
263
276
validation_errors = self ._get_validation_errors (
264
- self ._filepath , self ._function_name
277
+ self ._filepath , self ._function_name , validators
265
278
)
266
279
267
280
file_source = read_conf_file (self ._filepath )
@@ -284,7 +297,9 @@ def find(self) -> List[_Component]:
284
297
]
285
298
286
299
287
- def _load_custom_components () -> List [_Component ]:
300
+ def _load_custom_components (
301
+ validators : Optional [List [TorchxFunctionValidator ]],
302
+ ) -> List [_Component ]:
288
303
component_modules = {
289
304
name : load_fn ()
290
305
for name , load_fn in
@@ -303,11 +318,13 @@ def _load_custom_components() -> List[_Component]:
303
318
# _0 = torchx.components.dist
304
319
# _1 = torchx.components.utils
305
320
group = "" if group .startswith ("_" ) else group
306
- components += ModuleComponentsFinder (module , group ).find ()
321
+ components += ModuleComponentsFinder (module , group ).find (validators )
307
322
return components
308
323
309
324
310
- def _load_components () -> Dict [str , _Component ]:
325
+ def _load_components (
326
+ validators : Optional [List [TorchxFunctionValidator ]],
327
+ ) -> Dict [str , _Component ]:
311
328
"""
312
329
Loads either the custom component defs from the entrypoint ``[torchx.components]``
313
330
or the default builtins from ``torchx.components`` module.
@@ -318,37 +335,43 @@ def _load_components() -> Dict[str, _Component]:
318
335
319
336
"""
320
337
321
- components = _load_custom_components ()
338
+ components = _load_custom_components (validators )
322
339
if not components :
323
- components = ModuleComponentsFinder ("torchx.components" , "" ).find ()
340
+ components = ModuleComponentsFinder ("torchx.components" , "" ).find (validators )
324
341
return {c .name : c for c in components }
325
342
326
343
327
344
_components : Optional [Dict [str , _Component ]] = None
328
345
329
346
330
- def _find_components () -> Dict [str , _Component ]:
347
+ def _find_components (
348
+ validators : Optional [List [TorchxFunctionValidator ]],
349
+ ) -> Dict [str , _Component ]:
331
350
global _components
332
351
if not _components :
333
- _components = _load_components ()
352
+ _components = _load_components (validators )
334
353
return none_throws (_components )
335
354
336
355
337
356
def _is_custom_component (component_name : str ) -> bool :
338
357
return ":" in component_name
339
358
340
359
341
- def _find_custom_components (name : str ) -> Dict [str , _Component ]:
360
+ def _find_custom_components (
361
+ name : str , validators : Optional [List [TorchxFunctionValidator ]]
362
+ ) -> Dict [str , _Component ]:
342
363
if ":" not in name :
343
364
raise ValueError (
344
365
f"Invalid custom component: { name } , valid template : `FILEPATH`:`FUNCTION_NAME`"
345
366
)
346
367
filepath , component_name = name .split (":" )
347
- components = CustomComponentsFinder (filepath , component_name ).find ()
368
+ components = CustomComponentsFinder (filepath , component_name ).find (validators )
348
369
return {component .name : component for component in components }
349
370
350
371
351
- def get_components () -> Dict [str , _Component ]:
372
+ def get_components (
373
+ validators : Optional [List [TorchxFunctionValidator ]] = None ,
374
+ ) -> Dict [str , _Component ]:
352
375
"""
353
376
Returns all custom components registered via ``[torchx.components]`` entrypoints
354
377
OR builtin components that ship with TorchX (but not both).
@@ -395,23 +418,25 @@ def get_components() -> Dict[str, _Component]:
395
418
"""
396
419
397
420
valid_components : Dict [str , _Component ] = {}
398
- for component_name , component in _find_components ().items ():
421
+ for component_name , component in _find_components (validators ).items ():
399
422
if len (component .validation_errors ) == 0 :
400
423
valid_components [component_name ] = component
401
424
return valid_components
402
425
403
426
404
- def get_component (name : str ) -> _Component :
427
+ def get_component (
428
+ name : str , validators : Optional [List [TorchxFunctionValidator ]] = None
429
+ ) -> _Component :
405
430
"""
406
431
Retrieves components by the provided name.
407
432
408
433
Returns:
409
434
Component or None if no component with ``name`` exists
410
435
"""
411
436
if _is_custom_component (name ):
412
- components = _find_custom_components (name )
437
+ components = _find_custom_components (name , validators )
413
438
else :
414
- components = _find_components ()
439
+ components = _find_components (validators )
415
440
if name not in components :
416
441
raise ComponentNotFoundException (
417
442
f"Component `{ name } ` not found. Please make sure it is one of the "
@@ -428,7 +453,9 @@ def get_component(name: str) -> _Component:
428
453
return component
429
454
430
455
431
- def get_builtin_source (name : str ) -> str :
456
+ def get_builtin_source (
457
+ name : str , validators : Optional [List [TorchxFunctionValidator ]] = None
458
+ ) -> str :
432
459
"""
433
460
Returns a string of the the builtin component's function source code
434
461
with all the import statements. Intended to be used to make a copy
@@ -446,7 +473,7 @@ def get_builtin_source(name: str) -> str:
446
473
are optimized and formatting adheres to your organization's standards.
447
474
"""
448
475
449
- component = get_component (name )
476
+ component = get_component (name , validators )
450
477
fn = component .fn
451
478
fn_name = component .name .split ("." )[- 1 ]
452
479
0 commit comments