|
2 | 2 | __author__ = "TensorCircuit-NG Authors" |
3 | 3 | __creator__ = "refraction-ray" |
4 | 4 |
|
| 5 | +import importlib |
| 6 | +from typing import Any, List |
5 | 7 | from .utils import gpu_memory_share |
6 | 8 |
|
7 | 9 | gpu_memory_share() |
|
62 | 64 |
|
63 | 65 | FGSCircuit = FGSSimulator |
64 | 66 |
|
65 | | -try: |
66 | | - from . import keras |
67 | | - from .keras import KerasLayer, KerasHardwareLayer |
68 | | -except ModuleNotFoundError: |
69 | | - pass # in case tf is not installed |
| 67 | +# lazy imports for heavy frameworks |
| 68 | +# name: (module_relative_path, is_module) |
| 69 | +_lazy_imports = { |
| 70 | + "keras": (".keras", True), |
| 71 | + "KerasLayer": (".keras", False), |
| 72 | + "KerasHardwareLayer": (".keras", False), |
| 73 | + "torchnn": (".torchnn", True), |
| 74 | + "TorchLayer": (".torchnn", False), |
| 75 | + "TorchHardwareLayer": (".torchnn", False), |
| 76 | +} |
| 77 | + |
| 78 | + |
| 79 | +def __getattr__(name: str) -> Any: |
| 80 | + if name in _lazy_imports: |
| 81 | + path, is_module = _lazy_imports[name] |
| 82 | + module = importlib.import_module(path, __package__) |
| 83 | + if is_module: |
| 84 | + attr = module |
| 85 | + else: |
| 86 | + attr = getattr(module, name) |
| 87 | + globals()[name] = attr |
| 88 | + return attr |
| 89 | + raise AttributeError("module %s has no attribute %s" % (__name__, name)) |
| 90 | + |
| 91 | + |
| 92 | +def __dir__() -> List[str]: |
| 93 | + return sorted(set(globals().keys()).union(_lazy_imports.keys())) |
70 | 94 |
|
71 | | -try: |
72 | | - from . import torchnn |
73 | | - from .torchnn import TorchLayer, TorchHardwareLayer |
74 | | -except ModuleNotFoundError: |
75 | | - pass # in case torch is not installed |
76 | 95 |
|
77 | 96 | try: |
78 | 97 | import qiskit |
|
0 commit comments