Skip to content

Commit d62abe2

Browse files
authored
Refactor Torch Mocking in OneDiff (#351)
1 parent 2f1fb3d commit d62abe2

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

src/onediff/infer_compiler/transform/manager.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import torch
12
import oneflow as flow
23
from typing import Dict, List, Union
4+
from contextlib import contextmanager
35
from pathlib import Path
46
from ..import_tools import (
57
get_classes_in_package,
@@ -9,6 +11,24 @@
911
__all__ = ["transform_mgr"]
1012

1113

14+
@contextmanager
15+
def onediff_mock_torch():
16+
# Fixes check the 'version' error.
17+
attr_name = "__version__"
18+
restore_funcs = [] # Backup
19+
if hasattr(flow, attr_name) and hasattr(torch, attr_name):
20+
orig_flow_attr = getattr(flow, attr_name)
21+
restore_funcs.append(lambda: setattr(flow, attr_name, orig_flow_attr))
22+
setattr(flow, attr_name, getattr(torch, attr_name))
23+
24+
# https://docs.oneflow.org/master/cookies/oneflow_torch.html
25+
with flow.mock_torch.enable(lazy=True):
26+
yield
27+
28+
for restore_func in restore_funcs:
29+
restore_func()
30+
31+
1232
class TransformManager:
1333
def __init__(self):
1434
self._torch_to_oflow_cls_map = {}
@@ -17,8 +37,7 @@ def __init__(self):
1737
def load_class_proxies_from_packages(self, package_names: List[Union[Path, str]]):
1838
print_green(f"Loading modules: {package_names}")
1939
of_mds = {}
20-
# https://docs.oneflow.org/master/cookies/oneflow_torch.html
21-
with flow.mock_torch.enable(lazy=True):
40+
with onediff_mock_torch():
2241
for package_name in package_names:
2342
of_mds.update(get_classes_in_package(package_name))
2443
print_green(f"Loaded Mock Torch {len(of_mds)} classes: {package_names}")

0 commit comments

Comments
 (0)