Skip to content

Commit 8fb61d2

Browse files
committed
fix device
1 parent 1a3d9a2 commit 8fb61d2

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,7 @@ def __init__(self, data_format):
21822182
self.data_format = data_format
21832183

21842184
def __call__(self, input, weight):
2185-
weight = weight.to(input.device)
2185+
# weight = weight.to(input.device)
21862186
return torch.prelu(input, weight)
21872187

21882188

tensorlayerx/dataflow/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,22 @@ def default_collate_torch(batch):
7878
data = batch[0]
7979
data_type = type(data)
8080
import torch
81-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
81+
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
8282
if isinstance(data, torch.Tensor):
8383
batch = torch.stack(batch, 0)
84-
batch = batch.to(device)
84+
# batch = batch.to(device)
8585
return batch
8686
elif isinstance(data, np.ndarray):
8787
batch = np.stack(batch, axis=0)
8888
batch = torch.as_tensor(batch)
89-
batch = batch.to(device)
89+
# batch = batch.to(device)
9090
return batch
9191
elif isinstance(data, numbers.Number):
9292
batch = torch.as_tensor(batch)
93-
batch = batch.to(device)
93+
# batch = batch.to(device)
9494
return batch
9595
elif isinstance(data, (str, bytes)):
96-
batch = batch.to(device)
96+
# batch = batch.to(device)
9797
return batch
9898
elif isinstance(data, collections.abc.Mapping):
9999
return {key: default_collate_torch([d[key] for d in batch]) for key in data}

tensorlayerx/model/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,8 @@ def th_train(
429429
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
430430
print_freq, test_dataset
431431
):
432-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
433-
network = network.to(device)
432+
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
433+
# network = network.to(device)
434434
for epoch in range(n_epoch):
435435
start_time = time.time()
436436

tensorlayerx/model/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,20 @@ def __call__(self, data, label):
177177
class TrainOneStepWithTH(object):
178178

179179
def __init__(self, net_with_loss, optimizer, train_weights):
180-
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
180+
# self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
181181
self.net_with_loss = net_with_loss
182-
self.net_with_loss = self.net_with_loss.to(self.device)
182+
# self.net_with_loss = self.net_with_loss.to(self.device)
183183
self.optimizer = optimizer
184184
self.train_weights = train_weights
185185

186186
def __call__(self, data, label):
187-
if isinstance(data, dict):
188-
for k, v in data.items():
189-
if isinstance(v, torch.Tensor):
190-
data[k] = v.to(self.device)
191-
elif isinstance(data, torch.Tensor):
192-
data = data.to(self.device)
193-
label = label.to(self.device)
187+
# if isinstance(data, dict):
188+
# for k, v in data.items():
189+
# if isinstance(v, torch.Tensor):
190+
# data[k] = v.to(self.device)
191+
# elif isinstance(data, torch.Tensor):
192+
# data = data.to(self.device)
193+
# label = label.to(self.device)
194194
loss = self.net_with_loss(data, label)
195195
grads = self.optimizer.gradient(loss, self.train_weights)
196196
self.optimizer.apply_gradients(zip(grads, self.train_weights))

0 commit comments

Comments
 (0)