Skip to content

Commit 0356728

Browse files
committed
update torch device allocation
1 parent 50588aa commit 0356728

File tree

10 files changed

+78
-53
lines changed

10 files changed

+78
-53
lines changed

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def get_tensor_shape(x):
7272

7373

7474
# initializers
75-
def zeros(shape, dtype=mstype.float32):
75+
def zeros(shape, dtype=mstype.float32, device = None):
7676
"""
7777
Creates a tensor with all elements set to zero.
7878
@@ -95,7 +95,7 @@ def zeros(shape, dtype=mstype.float32):
9595
return Tensor(arr, dtype=dtype)
9696

9797

98-
def ones(shape, dtype=mstype.float32):
98+
def ones(shape, dtype=mstype.float32, device = None):
9999
"""
100100
Creates a tensor with all elements set to ones.
101101
@@ -118,7 +118,7 @@ def ones(shape, dtype=mstype.float32):
118118
return Tensor(arr, dtype=dtype)
119119

120120

121-
def constant(value, dtype=mstype.float32, shape=None):
121+
def constant(value, dtype=mstype.float32, shape=None, device = None):
122122
"""
123123
Creates a constant tensor from a tensor-like object.
124124
@@ -413,7 +413,7 @@ def xavier_normal(shape, dtype, seed=None):
413413
return Tensor(arr, dtype=dtype)
414414

415415

416-
def Variable(initial_value, name, trainable=True):
416+
def Variable(initial_value, name, trainable=True, device = None):
417417
"""
418418
Creates a new variable with value initial_value.
419419
@@ -622,7 +622,7 @@ def concat(values, axis):
622622
return outputs
623623

624624

625-
def convert_to_tensor(value, dtype=None):
625+
def convert_to_tensor(value, dtype=None, device = None):
626626
"""
627627
Converts the given value to a Tensor.
628628

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_tensor_shape(x):
4545

4646

4747
# initializers
48-
def zeros(shape, dtype="float32"):
48+
def zeros(shape, dtype="float32", device = None):
4949
"""
5050
Creates a tensor with all elements set to zero.
5151
@@ -64,7 +64,7 @@ def zeros(shape, dtype="float32"):
6464
return pd.zeros(shape=shape, dtype=dtype)
6565

6666

67-
def ones(shape, dtype="float32"):
67+
def ones(shape, dtype="float32", device = None):
6868
"""
6969
Creates a tensor with all elements set to ones.
7070
@@ -83,7 +83,7 @@ def ones(shape, dtype="float32"):
8383
return pd.ones(shape=shape, dtype=dtype)
8484

8585

86-
def constant(value, dtype="float32", shape=None):
86+
def constant(value, dtype="float32", shape=None, device = None):
8787
"""
8888
Creates a constant tensor from a tensor-like object.
8989
@@ -241,7 +241,7 @@ def xavier_uniform(shape, dtype, seed=None):
241241
raise NotImplementedError
242242

243243

244-
def Variable(initial_value, name, trainable=None):
244+
def Variable(initial_value, name, trainable=None, device = None):
245245
"""
246246
Creates a new variable with value initial_value.
247247
@@ -425,7 +425,7 @@ def concat(values, axis=0):
425425
return pd.concat(values, axis)
426426

427427

428-
def convert_to_tensor(value, dtype=None):
428+
def convert_to_tensor(value, dtype=None, device = None):
429429
"""
430430
Converts the given value to a Tensor.
431431

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_tensor_shape(x):
8383

8484

8585
# initializers
86-
def zeros(shape, dtype='float32'):
86+
def zeros(shape, dtype='float32', device = None):
8787
"""
8888
Creates a tensor with all elements set to zero.
8989
@@ -93,6 +93,8 @@ def zeros(shape, dtype='float32'):
9393
a tuple of integers, or a 1-D Tensor of type int32.
9494
dtype : tensor or str
9595
The DType of an element in the resulting Tensor
96+
device : str or None
97+
create a tensor on 'cpu' or 'gpu', defautl is None.
9698
9799
Returns
98100
-------
@@ -109,7 +111,7 @@ def zeros(shape, dtype='float32'):
109111
return tf.zeros(shape=shape, dtype=dtype_str(dtype))
110112

111113

112-
def ones(shape, dtype='float32'):
114+
def ones(shape, dtype='float32', device = None):
113115
"""
114116
Creates a tensor with all elements set to ones.
115117
@@ -119,6 +121,8 @@ def ones(shape, dtype='float32'):
119121
a tuple of integers, or a 1-D Tensor of type int32.
120122
dtype : tensor or str
121123
The DType of an element in the resulting Tensor
124+
device : str or None
125+
create a tensor on 'cpu' or 'gpu', defautl is None.
122126
123127
Returns
124128
-------
@@ -135,7 +139,7 @@ def ones(shape, dtype='float32'):
135139
return tf.ones(shape=shape, dtype=dtype_str(dtype))
136140

137141

138-
def constant(value, dtype='float32', shape=None):
142+
def constant(value, dtype='float32', shape=None, device = None):
139143
"""
140144
Creates a constant tensor from a tensor-like object.
141145
@@ -147,6 +151,8 @@ def constant(value, dtype='float32', shape=None):
147151
The type of the elements of the resulting tensor.
148152
shape : tuple
149153
Optional dimensions of resulting tensor.
154+
device : str or None
155+
create a tensor on 'cpu' or 'gpu', defautl is None.
150156
151157
Returns
152158
-------
@@ -345,7 +351,7 @@ def xavier_uniform(shape, dtype='float32', seed=None):
345351
return tf.initializers.glorot_uniform(seed)(shape=shape, dtype=dtype_str(dtype))
346352

347353

348-
def Variable(initial_value, name, trainable=True):
354+
def Variable(initial_value, name, trainable=True, device = None):
349355
"""
350356
Creates a new variable with value initial_value.
351357
@@ -355,6 +361,8 @@ def Variable(initial_value, name, trainable=True):
355361
A Tensor, or Python object convertible to a Tensor
356362
name : str
357363
Optional name for the variable. Defaults to 'Variable' and gets uniquified automatically.
364+
device : str or None
365+
create a tensor on 'cpu' or 'gpu', defautl is None.
358366
Returns
359367
-------
360368
Variable
@@ -591,7 +599,7 @@ def concat(values, axis):
591599
return tf.concat(values, axis)
592600

593601

594-
def convert_to_tensor(value, dtype=None):
602+
def convert_to_tensor(value, dtype=None, device = None):
595603
"""
596604
Converts the given value to a Tensor.
597605
@@ -601,6 +609,8 @@ def convert_to_tensor(value, dtype=None):
601609
An object whose type has a registered Tensor conversion function.
602610
dtype : optional
603611
Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
612+
device : str or None
613+
create a tensor on 'cpu' or 'gpu', defautl is None.
604614
605615
Returns
606616
-------

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_tensor_shape(x):
5252

5353

5454
# initializers
55-
def zeros(shape, dtype=None):
55+
def zeros(shape, dtype=None, device = None):
5656
"""
5757
Creates a tensor with all elements set to zero.
5858
@@ -68,11 +68,14 @@ def zeros(shape, dtype=None):
6868
A Tensor with all elements set to zero.
6969
7070
"""
71+
if device == 'cpu':
72+
device = torch.device('cpu')
73+
elif device == 'gpu':
74+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
75+
return torch.zeros(size=shape, dtype=dtype, device = device)
7176

72-
return torch.zeros(size=shape, dtype=dtype)
7377

74-
75-
def ones(shape, dtype=None):
78+
def ones(shape, dtype=None, device = None):
7679
"""
7780
Creates a tensor with all elements set to ones.
7881
@@ -88,11 +91,14 @@ def ones(shape, dtype=None):
8891
A Tensor with all elements set to zero.
8992
9093
"""
91-
92-
return torch.ones(size=shape, dtype=dtype)
94+
if device == 'cpu':
95+
device = torch.device('cpu')
96+
elif device == 'gpu':
97+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
98+
return torch.ones(size=shape, dtype=dtype, device = device)
9399

94100

95-
def constant(value, dtype=None, shape=None):
101+
def constant(value, dtype=None, shape=None, device =None):
96102
"""
97103
Creates a constant tensor from a tensor-like object.
98104
@@ -110,8 +116,11 @@ def constant(value, dtype=None, shape=None):
110116
A Constant Tensor.
111117
112118
"""
113-
114-
w = torch.empty(size=shape, dtype=dtype)
119+
if device == 'cpu':
120+
device = torch.device('cpu')
121+
elif device == 'gpu':
122+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
123+
w = torch.empty(size=shape, dtype=dtype, device = device)
115124
return torch.nn.init.constant_(w, value)
116125

117126

@@ -426,7 +435,7 @@ def concat(values, axis=0):
426435
return torch.cat(values, axis)
427436

428437

429-
def convert_to_tensor(value, dtype=None):
438+
def convert_to_tensor(value, dtype=None, device = None):
430439
"""
431440
Converts the given value to a Tensor.
432441
@@ -443,14 +452,18 @@ def convert_to_tensor(value, dtype=None):
443452
"""
444453
if isinstance(dtype, str):
445454
dtype = _dtypeDict[dtype]
446-
return torch.tensor(value, dtype=dtype)
455+
if device == 'cpu':
456+
device = torch.device('cpu')
457+
elif device == 'gpu':
458+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
459+
return torch.tensor(value, dtype=dtype, device = device)
447460

448461

449462
def convert_to_numpy(value):
450463
try:
451464
return value.numpy()
452465
except:
453-
return value.detach().numpy()
466+
return value.cpu().detach().numpy()
454467

455468

456469
def sqrt(x):
@@ -1505,7 +1518,7 @@ def tanh(x):
15051518
A Tensor. Has the same type as x.
15061519
"""
15071520

1508-
return F.tanh(x)
1521+
return torch.tanh(x)
15091522

15101523

15111524
def any(x, axis=None, keepdims=False):

tensorlayerx/dataflow/dataloader.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,11 @@ def __init__(
6161
batch_sampler=None,
6262
num_workers=0,
6363
collate_fn=None,
64-
# pin_memory = False,
6564
time_out=0,
6665
worker_init_fn=None,
67-
#multiprocessing_context=None,
6866
prefetch_factor=2,
6967
persistent_workers=False,
7068
):
71-
# assert isinstance(dataset, Dataset), "dataset should be subclass of tensorlayerx.dataflow.Dataset"
7269
self.dataset = dataset
7370
assert num_workers >= 0, "num_workers should be a non_negative integer"
7471
if num_workers == 0 and prefetch_factor != 2:
@@ -77,10 +74,8 @@ def __init__(
7774
raise ValueError('persistent_workers option needs num_workers > 0')
7875
self.num_workers = num_workers
7976
self.prefetch_factor = prefetch_factor
80-
# self.pin_memory = pin_memory
8177
self.time_out = time_out
8278
self.worker_init_fn = worker_init_fn
83-
#self.multiprocessing_context = multiprocessing_context
8479
if isinstance(dataset, IterableDataset):
8580
self._dataset_kind = _DatasetKind.Iter
8681
if shuffle is not False:

tensorlayerx/metrics/torch_metric.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def update(self, y_pred, y_true):
5151
y_true = torch.argmax(y_true, dim=-1, keepdim=True)
5252
correct = y_pred == y_true
5353
correct = correct.to(torch.float32)
54-
correct = correct.numpy()
54+
correct = correct.cpu().numpy()
5555
num_samples = np.prod(np.array(correct.shape[:-1]))
5656
num_corrects = correct[..., :self.topk].sum()
5757
self.total = num_corrects
@@ -78,12 +78,12 @@ def __init__(
7878

7979
def update(self, y_pred, y_true):
8080
if isinstance(y_true, torch.Tensor):
81-
y_true = y_true.numpy()
81+
y_true = y_true.cpu().numpy()
8282
elif not isinstance(y_pred, np.ndarray):
8383
raise TypeError("The y_true must be a numpy array or Tensor.")
8484

8585
if isinstance(y_pred, torch.Tensor):
86-
y_pred = y_pred.numpy()
86+
y_pred = y_pred.cpu().numpy()
8787
elif not isinstance(y_pred, np.ndarray):
8888
raise TypeError("The y_pred must be a numpy array or Tensor.")
8989

@@ -131,12 +131,12 @@ def __init__(self):
131131

132132
def update(self, y_pred, y_true):
133133
if isinstance(y_true, torch.Tensor):
134-
y_true = y_true.numpy()
134+
y_true = y_true.cpu().numpy()
135135
elif not isinstance(y_pred, np.ndarray):
136136
raise TypeError("The y_true must be a numpy array or Tensor.")
137137

138138
if isinstance(y_pred, torch.Tensor):
139-
y_pred = y_pred.numpy()
139+
y_pred = y_pred.cpu().numpy()
140140
elif not isinstance(y_pred, np.ndarray):
141141
raise TypeError("The y_pred must be a numpy array or Tensor.")
142142

@@ -169,12 +169,12 @@ def __init__(self):
169169

170170
def update(self, y_pred, y_true):
171171
if isinstance(y_true, torch.Tensor):
172-
y_true = y_true.numpy()
172+
y_true = y_true.cpu().numpy()
173173
elif not isinstance(y_pred, np.ndarray):
174174
raise TypeError("The y_true must be a numpy array or Tensor.")
175175

176176
if isinstance(y_pred, torch.Tensor):
177-
y_pred = y_pred.numpy()
177+
y_pred = y_pred.cpu().numpy()
178178
elif not isinstance(y_pred, np.ndarray):
179179
raise TypeError("The y_pred must be a numpy array or Tensor.")
180180

@@ -209,7 +209,7 @@ def acc(predicts, labels, topk=1):
209209
y_true = torch.argmax(labels, dim=-1, keepdim=True)
210210
correct = y_pred == y_true
211211
correct = correct.to(torch.float32)
212-
correct = correct.numpy()
212+
correct = correct.cpu().numpy()
213213
num_samples = np.prod(np.array(correct.shape[:-1]))
214214
num_corrects = correct[..., :topk].sum()
215215
total = num_corrects

tensorlayerx/model/core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ class Model:
5959
>>> class Net(Module):
6060
>>> def __init__(self):
6161
>>> super(Net, self).__init__()
62-
>>> self.conv = tlx.nn.Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), in_channels=5, name='conv2d')
62+
>>> self.conv = tlx.nn.Conv2d(out_channels=32, kernel_size=(3, 3), stride=(2, 2), in_channels=5, name='conv2d')
6363
>>> self.bn = tlx.nn.BatchNorm2d(num_features=32, act=tlx.ReLU)
6464
>>> self.flatten = tlx.nn.Flatten()
65-
>>> self.fc = tlx.nn.Dense(n_units=12, in_channels=32*224*224) # padding=0
65+
>>> self.fc = tlx.nn.Linear(out_features=12, in_features=32*224*224) # padding=0
6666
>>>
6767
>>> def construct(self, x):
6868
>>> x = self.conv(x)
@@ -434,8 +434,11 @@ def th_train(
434434

435435
train_loss, train_acc, n_iter = 0, 0, 0
436436
for X_batch, y_batch in train_dataset:
437+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
437438
network.set_train()
438-
439+
X_batch = X_batch.to(device)
440+
y_batch = y_batch.to(device)
441+
network.to(device)
439442
output = network(X_batch)
440443
loss = loss_fn(output, y_batch)
441444

0 commit comments

Comments
 (0)