diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index 1b649ffbc..0b52b0835 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -148,22 +148,51 @@ def get_precision():
 ### dtype name to torch.dtype mapping ###
 
 
+def get_cuda_architecture(device=None):
+    device_str = get_device_str(device)
+    if "cuda" in device_str and torch.cuda.is_available():
+        # Get the compute capability as (major, minor) tuple
+        capability = torch.cuda.get_device_capability(device)
+        return capability[0], capability[1]
+    else:
+        return 0, 0
+
+
+##########################################################################
+### dtype name to torch.dtype mapping ###
+
+
 def name_to_dtype(name, device):
+    device_str = get_device_str(device)
+    # if it's CUDA, the architecture level indicates whether we can use bf16
+    major, minor = get_cuda_architecture(device)
+
     if (name == "fast") or (name == "fast16"):
         # MacOS now supports bfloat16
         import platform
 
         if platform.processor() == "arm":
-            device = get_device_str(device)
             # ARM CPU is faster with float16, MPS with bf16 if supported
-            if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
+            if device_str == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
                 return torch.float16
-        return torch.bfloat16
+
+        # if it's not CUDA, we know it's bfloat16
+        if "cuda" not in device_str:
+            return torch.bfloat16
+
+        if major >= 9:
+            return torch.bfloat16
+        else:
+            return torch.float16
 
     try:
-        return name_to_dtype_dict[name]
+        dtype = name_to_dtype_dict[name]
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
+
+    if ("cuda" in device_str) and (dtype == torch.bfloat16) and (major < 9):
+        raise RuntimeError(f"target device {device_str} does not support the bfloat16 data type")
+    return dtype
 
 
 def allowable_dtype_names() -> List[str]:
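
A minimal usage sketch (not part of the patch), assuming a torchchat checkout with this change applied is importable and that "bf16" is one of the names in name_to_dtype_dict:

    import torch
    from torchchat.utils.build_utils import get_cuda_architecture, name_to_dtype

    if torch.cuda.is_available():
        # With this change, "fast" resolves to bfloat16 only when the GPU reports
        # compute capability >= 9.0; older GPUs fall back to float16.
        major, minor = get_cuda_architecture("cuda")
        print(f"compute capability: {major}.{minor}")
        print(name_to_dtype("fast", "cuda"))

        # Explicitly requesting bfloat16 on a GPU below capability 9.0 now raises
        # the new "does not support the bfloat16 data type" RuntimeError.
        try:
            print(name_to_dtype("bf16", "cuda"))
        except RuntimeError as err:
            print(err)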