1+ import torch
12import oneflow as flow
23from typing import Dict , List , Union
4+ from contextlib import contextmanager
35from pathlib import Path
46from ..import_tools import (
57 get_classes_in_package ,
911__all__ = ["transform_mgr" ]
1012
1113
14+ @contextmanager
15+ def onediff_mock_torch ():
16+ # Fixes check the 'version' error.
17+ attr_name = "__version__"
18+ restore_funcs = [] # Backup
19+ if hasattr (flow , attr_name ) and hasattr (torch , attr_name ):
20+ orig_flow_attr = getattr (flow , attr_name )
21+ restore_funcs .append (lambda : setattr (flow , attr_name , orig_flow_attr ))
22+ setattr (flow , attr_name , getattr (torch , attr_name ))
23+
24+ # https://docs.oneflow.org/master/cookies/oneflow_torch.html
25+ with flow .mock_torch .enable (lazy = True ):
26+ yield
27+
28+ for restore_func in restore_funcs :
29+ restore_func ()
30+
31+
1232class TransformManager :
1333 def __init__ (self ):
1434 self ._torch_to_oflow_cls_map = {}
@@ -17,8 +37,7 @@ def __init__(self):
1737 def load_class_proxies_from_packages (self , package_names : List [Union [Path , str ]]):
1838 print_green (f"Loading modules: { package_names } " )
1939 of_mds = {}
20- # https://docs.oneflow.org/master/cookies/oneflow_torch.html
21- with flow .mock_torch .enable (lazy = True ):
40+ with onediff_mock_torch ():
2241 for package_name in package_names :
2342 of_mds .update (get_classes_in_package (package_name ))
2443 print_green (f"Loaded Mock Torch { len (of_mds )} classes: { package_names } " )
0 commit comments