File tree Expand file tree Collapse file tree 2 files changed +13
-5
lines changed Expand file tree Collapse file tree 2 files changed +13
-5
lines changed Original file line number Diff line number Diff line change 1- import time
2- import os .path as osp
1+ import argparse
32import itertools
3+ import os .path as osp
4+ import time
45
5- import argparse
6- import wget
76import torch
7+ import wget
88from scipy .io import loadmat
9-
109from torch_scatter import scatter_add
10+
1111from torch_sparse .tensor import SparseTensor
1212
1313short_rows = [
@@ -62,6 +62,9 @@ def time_func(func, x):
6262 try :
6363 if torch .cuda .is_available ():
6464 torch .cuda .synchronize ()
65+ elif torch .backends .mps .is_available ():
66+ import torch .mps
67+ torch .mps .synchronize ()
6568 t = time .perf_counter ()
6669
6770 if not args .with_backward :
@@ -77,6 +80,9 @@ def time_func(func, x):
7780
7881 if torch .cuda .is_available ():
7982 torch .cuda .synchronize ()
83+ elif torch .backends .mps .is_available ():
84+ import torch .mps
85+ torch .mps .synchronize ()
8086 return time .perf_counter () - t
8187 except RuntimeError as e :
8288 if 'out of memory' not in str (e ):
Original file line number Diff line number Diff line change 1616devices = [torch .device ('cpu' )]
1717if torch .cuda .is_available ():
1818 devices += [torch .device ('cuda:0' )]
19+ if torch .backends .mps .is_available ():
20+ devices += [torch .device ('mps' )]
1921
2022
2123def tensor (x : Any , dtype : torch .dtype , device : torch .device ):
You can’t perform that action at this time.
0 commit comments