|
11 | 11 |
|
12 | 12 | from enum import Enum |
13 | 13 | from pathlib import Path |
14 | | -from typing import Any, Callable, Dict, List, Tuple |
| 14 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 |
|
@@ -77,31 +77,39 @@ def unpack_packed_weights( |
77 | 77 | def set_backend(dso, pte, aoti_package): |
78 | 78 | global active_builder_args_dso |
79 | 79 | global active_builder_args_pte |
| 80 | + global active_builder_args_aoti_package |
80 | 81 | active_builder_args_dso = dso |
81 | 82 | active_builder_args_aoti_package = aoti_package |
82 | 83 | active_builder_args_pte = pte |
83 | 84 |
|
84 | 85 |
|
85 | 86 | class _Backend(Enum): |
86 | | - AOTI = (0,) |
| 87 | + AOTI = 0 |
87 | 88 | EXECUTORCH = 1 |
88 | 89 |
|
89 | 90 |
|
90 | | -def _active_backend() -> _Backend: |
| 91 | +def _active_backend() -> Optional[_Backend]: |
91 | 92 | global active_builder_args_dso |
92 | 93 | global active_builder_args_aoti_package |
93 | 94 | global active_builder_args_pte |
94 | 95 |
|
95 | | - # eager == aoti, which is when backend has not been explicitly set |
96 | | - if (not active_builder_args_pte) and (not active_builder_args_aoti_package): |
97 | | - return True |
| 96 | + args = ( |
| 97 | + active_builder_args_dso, |
| 98 | + active_builder_args_pte, |
| 99 | + active_builder_args_aoti_package, |
| 100 | + ) |
| 101 | + |
| 102 | + # Return None, as default |
| 103 | + if not any(args): |
| 104 | + return None |
98 | 105 |
|
99 | | - if active_builder_args_pte and active_builder_args_aoti_package: |
| 106 | + # Catch more than one arg |
| 107 | + if sum(map(bool, args)) > 1: |
100 | 108 | raise RuntimeError( |
101 | | - "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" |
| 109 | + "Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!" |
102 | 110 | ) |
103 | 111 |
|
104 | | - return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH |
| 112 | + return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI |
105 | 113 |
|
106 | 114 |
|
107 | 115 | def use_aoti_backend() -> bool: |
|
0 commit comments