diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index abda72d70..c9a9eae14 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +from enum import Enum import logging import os from pathlib import Path @@ -78,36 +79,33 @@ def set_backend(dso, pte): active_builder_args_pte = pte -def use_aoti_backend() -> bool: +class _Backend(Enum): + AOTI = 0, + EXECUTORCH = 1 + + +def _active_backend() -> _Backend: global active_builder_args_dso global active_builder_args_pte # eager == aoti, which is when backend has not been explicitly set if (not active_builder_args_dso) and not (active_builder_args_pte): - return True + return _Backend.AOTI if active_builder_args_pte and active_builder_args_dso: raise RuntimeError( "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" ) - return bool(active_builder_args_dso) + return _Backend.AOTI if active_builder_args_dso else _Backend.EXECUTORCH -def use_et_backend() -> bool: - global active_builder_args_dso - global active_builder_args_pte - - # eager == aoti, which is when backend has not been explicitly set - if not (active_builder_args_pte or active_builder_args_dso): - return False +def use_aoti_backend() -> bool: + return _active_backend() == _Backend.AOTI - if active_builder_args_pte and active_builder_args_dso: - raise RuntimeError( - "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" - ) - return bool(active_builder_args_pte) +def use_et_backend() -> bool: + return _active_backend() == _Backend.EXECUTORCH ########################################################################## @@ -142,9 +140,9 @@ def name_to_dtype(name, device): return torch.float16 return torch.bfloat16 - if name in name_to_dtype_dict: + try: return name_to_dtype_dict[name] - else: + except KeyError: raise RuntimeError(f"unsupported dtype name {name} specified") @@ -212,10 +210,7 @@ def canonical_path(path): def state_dict_device(d, device="cpu") -> Dict: - for key, weight in d.items(): - d[key] = weight.to(device=device) - - return d + return {key : weight.to(device=device) for (key, weight) in d.items()} ######################################################################### @@ -259,9 +254,9 @@ def get_device(device) -> str: return torch.device(device) -def is_cuda_or_cpu_device(device) -> bool: - return device == "" or str(device) == "cpu" or ("cuda" in str(device)) - - def is_cpu_device(device) -> bool: return device == "" or str(device) == "cpu" + + +def is_cuda_or_cpu_device(device) -> bool: + return is_cpu_device(device) or ("cuda" in str(device))