Fixing RuntimeError: Expected object of type Variable[torch.FloatTensor] in the PyTorch tutorial's batch-normalization chapter


While working through the batch-normalization chapter of Liao Xingyun's (廖星雲) PyTorch tutorial, I ran into the following error:
File "C:/Users/demons/Desktop/trainingtorch/batch_normalization.py", line 25, in batch_norm_1d
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean

RuntimeError: Expected object of type Variable[torch.FloatTensor] but found type Variable[torch.cuda.FloatTensor] for argument #1 'other'
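The message means that the moving-average update mixes a CPU tensor (the first term, `moving_mean`) with a CUDA tensor (`x_mean`, computed from the GPU input). A likely cause, assuming the tutorial stores `moving_mean` and `moving_var` as plain tensor attributes rather than registered buffers, is that `net.cuda()` only converts a module's parameters and buffers. The sketch below uses two hypothetical classes, `PlainAttrBN` and `BufferBN`, to show the difference. On a machine without a GPU it casts to float64 instead, which `nn.Module` applies to exactly the same set of tensors.

```python
import torch
import torch.nn as nn


class PlainAttrBN(nn.Module):
    """Hypothetical module that keeps its running mean as a plain attribute."""

    def __init__(self, num_features):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Plain tensor attribute: nn.Module does not track it, so .cuda()/.to() skip it
        self.moving_mean = torch.zeros(num_features)


class BufferBN(nn.Module):
    """Hypothetical module that registers its running mean as a buffer."""

    def __init__(self, num_features):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Registered buffer: converted together with the parameters, saved in state_dict()
        self.register_buffer("moving_mean", torch.zeros(num_features))


def placement(t):
    """Return the (device, dtype) pair of a tensor."""
    return t.device, t.dtype


# Move to the GPU if there is one; otherwise cast to float64, which nn.Module
# applies to the same tensors (parameters and buffers) that .cuda() would move.
target = "cuda" if torch.cuda.is_available() else torch.float64

plain = PlainAttrBN(4).to(target)
buffered = BufferBN(4).to(target)

print("plain   :", placement(plain.gamma), placement(plain.moving_mean))
print("buffered:", placement(buffered.gamma), placement(buffered.moving_mean))
```

In the plain-attribute version `gamma` is converted but `moving_mean` is left behind, which is the same kind of mismatch the traceback reports.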

There are two ways to fix it.
(1) In utils.py, comment out the cuda()-related code and fix the indentation accordingly:
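# utils.py, inside train(): comment out every .cuda() call so the network and the data stay on the CPU
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    # if torch.cuda.is_available():
    #     net = net.cuda()
    for epoch in range(num_epochs):
        # ... (unchanged code omitted)
        for im, label in train_data:
            # if torch.cuda.is_available():
            #     im = Variable(im.cuda())  # (bs, 3, h, w)
            #     label = Variable(label.cuda())  # (bs, h, w)
            # else:
            im = Variable(im)
            label = Variable(label)
            # ... (forward pass, loss and backward pass unchanged)
        # ... (validation section: apply the same change)
        for im, label in valid_data:
            # if torch.cuda.is_available():
            #     im = Variable(im.cuda(), volatile=True)
            #     label = Variable(label.cuda(), volatile=True)
            # else:
            im = Variable(im, volatile=True)
            label = Variable(label, volatile=True)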
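Commenting the calls out works but hard-codes the CPU. A sketch of an alternative, assuming a recent PyTorch (0.4 or later, where `Variable` wrapping is no longer needed), is to choose the device in one place and move the model and every batch with `.to(device)`. Setting `device` to `"cpu"` then has the same effect as solution (1) without editing the loop. The function `run_epoch`, the toy `nn.Linear(8, 3)` model and the random data are hypothetical and not part of the tutorial's utils.py.

```python
import torch
import torch.nn as nn


def run_epoch(net, data, optimizer, criterion, device):
    """Train `net` for one pass over `data` on `device` and return the mean loss."""
    net.train()
    total_loss, batches = 0.0, 0
    for im, label in data:
        # Move each batch to the same device as the model
        im, label = im.to(device), label.to(device)
        loss = criterion(net(im), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batches += 1
    return total_loss / max(batches, 1)


torch.manual_seed(0)
# "cpu" here plays the role of commenting out the .cuda() calls
device = torch.device("cpu")
net = nn.Linear(8, 3).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data = [(torch.randn(16, 8), torch.randint(0, 3, (16,))) for _ in range(5)]
print(run_epoch(net, data, optimizer, criterion, device))
```

Note that switching `device` to CUDA does not by itself avoid the error above if the running statistics are plain tensor attributes; they would still stay on the CPU. For the validation loop, wrapping it in `with torch.no_grad():` is the replacement for `volatile=True` in PyTorch 0.4 and later.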

(2) Alternatively, in class multi_Network, change the batch_norm_1d call to:
x = batch_norm_1d(x.cpu(), self.gamma.cpu(), self.beta.cpu(), is_train, self.moving_mean.cpu(), self.moving_var.cpu()).cuda()

After that, the script ran successfully:
runfile('C:/Users/demons/Desktop/trainingtorch/batch_normalization.py', wdir='C:/Users/demons/Desktop/trainingtorch')
Reloaded modules: utils
Epoch 0. Train Loss: 0.302185, Train Acc: 0.912797, Valid Loss: 0.186934, Valid Acc: 0.947191, Time 00:00:05
Epoch 1. Train Loss: 0.169958, Train Acc: 0.951359, Valid Loss: 0.133628, Valid Acc: 0.962520, Time 00:00:05
Epoch 2. Train Loss: 0.129881, Train Acc: 0.962803, Valid Loss: 0.117917, Valid Acc: 0.965487, Time 00:00:05
Epoch 3. Train Loss: 0.106306, Train Acc: 0.969150, Valid Loss: 0.106132, Valid Acc: 0.968354, Time 00:00:05
Epoch 4. Train Loss: 0.090785, Train Acc: 0.973764, Valid Loss: 0.101401, Valid Acc: 0.971025, Time 00:00:05
Epoch 5. Train Loss: 0.081850, Train Acc: 0.975746, Valid Loss: 0.093533, Valid Acc: 0.971618, Time 00:00:05
Epoch 6. Train Loss: 0.072291, Train Acc: 0.978995, Valid Loss: 0.092226, Valid Acc: 0.972112, Time 00:00:05
Epoch 7. Train Loss: 0.065007, Train Acc: 0.980844, Valid Loss: 0.090979, Valid Acc: 0.972310, Time 00:00:05
Epoch 8. Train Loss: 0.059790, Train Acc: 0.981726, Valid Loss: 0.090877, Valid Acc: 0.973299, Time 00:00:05
Epoch 9. Train Loss: 0.054136, Train Acc: 0.984325, Valid Loss: 0.089308, Valid Acc: 0.974288, Time 00:00:05
   :
57.95331573486328
Variable containing:
-1.8784
 4.0507
 0.2430
 0.1976
-0.3430
-2.2162
 0.8868
-1.9118
-1.3165
 0.9459
[torch.FloatTensor of size 10]
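A note on solution (2): it keeps the running statistics up to date because `Tensor.cpu()` returns the original object when the tensor is already in CPU memory, so if `self.moving_mean` is a CPU tensor (as the error message suggests) the in-place `moving_mean[:] = ...` still writes into it. The cost is two device transfers per forward pass. A sketch of another option, assuming `moving_mean` and `moving_var` may be registered as buffers, lets `net.cuda()` move them together with the parameters. The `batch_norm_1d` below is a hypothetical reimplementation written for this example, not the tutorial's code; the class name `MultiNetwork`, the layer sizes (784, 100, 10) and `moving_momentum=0.9` are assumptions.

```python
import torch
import torch.nn as nn


def batch_norm_1d(x, gamma, beta, is_train, moving_mean, moving_var,
                  moving_momentum=0.9, eps=1e-5):
    """Hypothetical 1-D batch norm; x has shape (batch, features)."""
    if is_train:
        x_mean = x.mean(dim=0)
        x_var = ((x - x_mean) ** 2).mean(dim=0)
        x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
        # Update the running statistics in place, outside the autograd graph
        with torch.no_grad():
            moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
            moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
    else:
        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
    return gamma * x_hat + beta


class MultiNetwork(nn.Module):
    """Hypothetical MLP with a hand-written batch-norm layer."""

    def __init__(self, in_features=784, hidden=100, num_classes=10):
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden)
        self.layer2 = nn.Linear(hidden, num_classes)
        self.gamma = nn.Parameter(torch.ones(hidden))
        self.beta = nn.Parameter(torch.zeros(hidden))
        # Buffers follow net.cuda()/net.to(device), unlike plain tensor attributes
        self.register_buffer("moving_mean", torch.zeros(hidden))
        self.register_buffer("moving_var", torch.ones(hidden))

    def forward(self, x):
        x = self.layer1(x)
        x = batch_norm_1d(x, self.gamma, self.beta, self.training,
                          self.moving_mean, self.moving_var)
        return self.layer2(torch.relu(x))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MultiNetwork().to(device)
out = net(torch.randn(32, 784, device=device))  # training mode: running stats updated
print(out.shape, net.moving_mean.device)
```

With the statistics registered as buffers, no `.cpu()`/`.cuda()` round trip is needed in `forward`, and they are also saved by `state_dict()`.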