diff --git a/semisupervised/codes/train.py b/semisupervised/codes/train.py index 5bd39bc..79b05a0 100755 --- a/semisupervised/codes/train.py +++ b/semisupervised/codes/train.py @@ -186,7 +186,8 @@ def train(epoches): trainer.optimizer.zero_grad() k = 10 temp = torch.zeros([k, target_q.shape[0], target_q.shape[1]], dtype=target_q.dtype) - temp = temp.cuda() + if opt['cuda']: + temp = temp.cuda() for i in range(k): temp[i,:,:] = trainer.predict_noisy(inputs_q) target_predict = temp.mean(dim = 0)