Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions torchchat/utils/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,51 @@ def get_precision():
### dtype name to torch.dtype mapping ###


def get_cuda_architecture(device=None):
    """Return the CUDA compute capability of ``device`` as ``(major, minor)``.

    Args:
        device: Device specifier; it is passed to ``get_device_str`` to
            obtain its string form and to
            ``torch.cuda.get_device_capability`` to query the hardware.

    Returns:
        A ``(major, minor)`` tuple of ints. ``(0, 0)`` is returned when the
        device string does not contain ``"cuda"`` or when CUDA is not
        available in this process.
    """
    on_cuda = "cuda" in get_device_str(device)
    if not (on_cuda and torch.cuda.is_available()):
        # Sentinel for "no CUDA architecture": lower than any real capability.
        return 0, 0

    # The compute capability is reported as a (major, minor) tuple.
    major, minor = torch.cuda.get_device_capability(device)
    return major, minor


##########################################################################
### dtype name to torch.dtype mapping ###


def name_to_dtype(name, device):
    """Resolve a dtype name to a ``torch.dtype`` suitable for ``device``.

    The names ``"fast"`` and ``"fast16"`` select a 16-bit floating point type
    based on the platform and target device; every other name is looked up
    in ``name_to_dtype_dict``.

    Args:
        name: dtype name, either ``"fast"``/``"fast16"`` or a key of
            ``name_to_dtype_dict``.
        device: Target device specifier, converted to a string with
            ``get_device_str``.

    Returns:
        The selected ``torch.dtype``.

    Raises:
        RuntimeError: If ``name`` is not a known dtype name, or if bfloat16
            is requested explicitly for a CUDA device whose reported
            compute capability major version is below 9.
    """
    device_str = get_device_str(device)
    is_cuda = "cuda" in device_str
    # If it's CUDA, the architecture level indicates whether we can use bf16.
    # get_cuda_architecture reports (0, 0) when CUDA is not usable, so such
    # devices are treated as not supporting bf16 below.
    # NOTE(review): the threshold of 9 is kept as written; PyTorch's own
    # torch.cuda.is_bf16_supported() may accept older architectures -- confirm.
    major, _ = get_cuda_architecture(device)

    if name in ("fast", "fast16"):
        # MacOS now supports bfloat16
        import platform

        if platform.processor() == "arm":
            # ARM CPU is faster with float16, MPS with bf16 if supported
            if device_str == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
                return torch.float16

        # If it's not CUDA, we know it's bfloat16.
        if not is_cuda:
            return torch.bfloat16

        return torch.bfloat16 if major >= 9 else torch.float16

    try:
        dtype = name_to_dtype_dict[name]
    except KeyError as err:
        raise RuntimeError(f"unsupported dtype name {name} specified") from err

    # Reject an explicit bf16 request up front rather than failing later.
    if is_cuda and dtype == torch.bfloat16 and major < 9:
        raise RuntimeError(
            f"target device {device_str} does not support the bfloat16 data type"
        )
    return dtype


def allowable_dtype_names() -> List[str]:
Expand Down
Loading