Skip to content

Commit f71e8ca

Browse files
committed
Update gpu testing function
1 parent 342ef3e commit f71e8ca

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

RawRefinery/application/ModelHandler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004
1414
from RawRefinery.application.dng_utils import convert_color_matrix, to_dng
1515
from RawRefinery.application.postprocessing import match_colors_linear
16-
from RawRefinery.application.utils import can_use_cuda
16+
from RawRefinery.application.utils import can_use_gpu
1717

1818
MODEL_REGISTRY = {
1919
"Tree Net Denoise": {
@@ -184,7 +184,7 @@ def __init__(self):
184184

185185
# Manage devices
186186
devices = {
187-
"cuda": can_use_cuda(),
187+
"cuda": can_use_gpu(),
188188
"mps": torch.backends.mps.is_available(),
189189
"cpu": lambda : True
190190
}

RawRefinery/application/utils.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
def can_use_gpu():
    """Return True if a usable CUDA device is present, else False.

    ``torch.cuda.is_available()`` alone can report True even when the
    device is unusable at runtime (e.g. the installed PyTorch build does
    not include kernels for the card's compute capability). Attempting a
    tiny allocation on the device catches those cases.

    Returns:
        bool: True only when a 1-element tensor can actually be
        allocated on ``cuda``; False otherwise.
    """
    if not torch.cuda.is_available():
        return False
    try:
        # Probe the device with a minimal allocation; the result is
        # intentionally discarded — success is the only signal we need.
        torch.zeros(1, device="cuda")
        return True
    except Exception:
        # Deliberate best-effort: any runtime failure (driver mismatch,
        # unsupported arch, OOM) means the GPU cannot be used.
        return False

0 commit comments

Comments
 (0)