
Commit bfab3f1

recognize and issue error if GPU does not support bf16
Addresses #1298, which causes models to fail on T4 (and other GPUs below architecture level 9.0), by selecting an alternate dtype when possible and issuing an error otherwise.
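The core of the check, shown here as a standalone sketch rather than the patch itself: `torch.cuda.get_device_capability` and `torch.cuda.is_bf16_supported` are standard PyTorch calls, the 9.0 threshold mirrors the policy used in this commit, and `pick_fast16_dtype` is a hypothetical helper name, not part of torchchat.

```python
import torch

def pick_fast16_dtype(device: str = "cuda") -> torch.dtype:
    """Hypothetical helper mirroring this commit's policy for 'fast16'."""
    if not torch.cuda.is_available():
        # Non-CUDA backends keep the previous bfloat16 default.
        return torch.bfloat16
    major, _minor = torch.cuda.get_device_capability(device)
    # The patch gates bfloat16 on compute capability >= 9.0; a T4 reports (7, 5).
    return torch.bfloat16 if major >= 9 else torch.float16

if __name__ == "__main__":
    print(pick_fast16_dtype())
    if torch.cuda.is_available():
        # PyTorch's own runtime probe, shown only for comparison.
        print(torch.cuda.is_bf16_supported())
```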
1 parent 9480258 commit bfab3f1

File tree

1 file changed: +33 −4 lines

torchchat/utils/build_utils.py

Lines changed: 33 additions & 4 deletions
@@ -148,22 +148,51 @@ def get_precision():
 ### dtype name to torch.dtype mapping ###
 
 
+def get_cuda_architecture(device=None):
+    device_str = get_device_str(device)
+    if "cuda" in device_str and torch.cuda.is_available():
+        # Get the compute capability as (major, minor) tuple
+        capability = torch.cuda.get_device_capability(device)
+        return capability[0], capability[1]
+    else:
+        return 0, 0
+
+
+##########################################################################
+### dtype name to torch.dtype mapping ###
+
+
 def name_to_dtype(name, device):
+    device_str = get_device_str(device)
+    # if it's CUDA, the architecture level indicates whether we can use bf16
+    major, minor = get_cuda_architecture(device)
+
     if (name == "fast") or (name == "fast16"):
         # MacOS now supports bfloat16
         import platform
 
         if platform.processor() == "arm":
-            device = get_device_str(device)
             # ARM CPU is faster with float16, MPS with bf16 if supported
-            if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
+            if device_str == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
                 return torch.float16
-        return torch.bfloat16
+
+        # if it's not CUDA, we know it's bfloat16
+        if "cuda" not in device_str:
+            return torch.bfloat16
+
+        if major >= 9:
+            return torch.bfloat16
+        else:
+            return torch.float16
 
     try:
-        return name_to_dtype_dict[name]
+        dtype = name_to_dtype_dict[name]
     except KeyError:
         raise RuntimeError(f"unsupported dtype name {name} specified")
+
+    if ("cuda" in device_str) and (dtype == torch.bfloat16) and (major < 9):
+        raise RuntimeError(f"target device {device_str} does not support the bfloat16 data type")
+    return dtype
 
 
 def allowable_dtype_names() -> List[str]:
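An illustrative call site for the new behavior. This is a hypothetical usage sketch, not part of the commit: the import path is assumed from the file shown above, "bf16" is assumed to be a valid key in name_to_dtype_dict, and the exact error text may differ.

```python
# Hypothetical usage sketch; assumes a CUDA build of torchchat on a pre-9.0 GPU
# such as a T4, and that "bf16" is an accepted dtype name.
from torchchat.utils.build_utils import get_cuda_architecture, name_to_dtype

major, minor = get_cuda_architecture("cuda")
print(f"compute capability: {major}.{minor}")  # e.g. 7.5 on a T4, 0.0 without CUDA

# "fast16" now degrades gracefully: float16 below architecture level 9.0,
# bfloat16 otherwise.
print(name_to_dtype("fast16", "cuda"))

# Explicitly requesting bfloat16 on an unsupported GPU now raises instead of
# silently building a model that fails later (the situation in #1298).
try:
    name_to_dtype("bf16", "cuda")
except RuntimeError as err:
    print(err)
```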
