PyTorchのSoftMax交差エントロピー損失と勾配法


PyTorchではソフトMaxの交差エントロピー損失と入力勾配の計算を簡単に検証できます。
ソフトマックスについてクロスギentropyガイドの過程は、HEREを参照することができます。
例:

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

#  data   ,       
data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)

#       one-hot  
label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)
出力:

tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
# loss:     input's grad:     
tensor(1.4076) tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])
注意:
単一入力のためのlossとgrad

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)


label = Variable(torch.zeros((1, 3)))
#         label 1
label[0, 0] = 1
# label[0, 1] = 1
# label[0, 2] = 1
print(label)

# for batch loss = mean( -sum(Pj*logSj) )
# for one : loss = -sum(Pj*logSj)
loss = torch.mean(-torch.sum(label * torch.log(F.softmax(data, dim=1)), dim=1))

loss.backward()
print(loss, data.grad)
その出力:

#    :
lable: tensor([[ 1., 0., 0.]])
loss: tensor(2.4076) 
grad: tensor([[-0.9100, 0.2447, 0.6652]])

#    :
lable: tensor([[ 0., 1., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0900, -0.7553, 0.6652]])

#    :
lable: tensor([[ 0., 0., 1.]])
loss: tensor(0.4076) 
grad: tensor([[ 0.0900, 0.2447, -0.3348]])

"""
  :
       tensor([[ 1., 2., 3.]]) softmax       
tensor([[ 0.0900, 0.2447, 0.6652]])
              s[0, 0]-1, s[0, 1]-1, s[0, 2]-1      label         
"""
pytochが提供するソフトxとロゴ_ソフトマックス関係

#      softmax  
In[2]: import torch
 ...: import torch.autograd as autograd
 ...: from torch.autograd import Variable
 ...: import torch.nn.functional as F
 ...: import torch.nn as nn
 ...: import numpy as np
In[3]: data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0]]), requires_grad=True)
In[4]: data
Out[4]: tensor([[ 1., 2., 3.]])
In[5]: e = torch.exp(data)
In[6]: e
Out[6]: tensor([[ 2.7183,  7.3891, 20.0855]])
In[7]: s = torch.sum(e, dim=1)
In[8]: s
Out[8]: tensor([ 30.1929])
In[9]: softmax = e/s
In[10]: softmax
Out[10]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[11]: #     pytorch     softmax 
In[12]: org_softmax = F.softmax(data, dim=1)
In[13]: org_softmax
Out[13]: tensor([[ 0.0900, 0.2447, 0.6652]])
In[14]: org_softmax == softmax #       
Out[14]: tensor([[ 1, 1, 1]], dtype=torch.uint8)

#  log_softmax  
# log_softmax = log(softmax)
In[15]: _log_softmax = torch.log(org_softmax) 
In[16]: _log_softmax
Out[16]: tensor([[-2.4076, -1.4076, -0.4076]])
In[17]: log_softmax = F.log_softmax(data, dim=1)
In[18]: log_softmax
Out[18]: tensor([[-2.4076, -1.4076, -0.4076]])
公式に提供されたソフトマックスクロスエントロピーソリューションの結果

# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

data = Variable(torch.FloatTensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), requires_grad=True)
log_softmax = F.log_softmax(data, dim=1)

label = Variable(torch.zeros((3, 3)))
label[0, 2] = 1
label[1, 1] = 1
label[2, 0] = 1
print("lable: ", label)

#           
loss_fn = torch.nn.NLLLoss() # reduce=True loss.sum/batch & grad/batch
# NLLLoss   log_softmax, target  one-hot   label
loss = loss_fn(log_softmax, torch.argmax(label, dim=1))
loss.backward()
print("loss: ", loss, "
grad: ", data.grad) """ # loss_fn = torch.nn.CrossEntropyLoss() # the target label is NOT an one-hotted #CrossEntropyLoss #input: softmax nn.output, target one-hot label loss = loss_fn(data, torch.argmax(label, dim=1)) loss.backward() print("loss: ", loss, "
grad: ", data.grad) """ """
出力

lable: tensor([[ 0., 0., 1.],
    [ 0., 1., 0.],
    [ 1., 0., 0.]])
loss: tensor(1.4076) 
grad: tensor([[ 0.0300, 0.0816, -0.1116],
    [ 0.0300, -0.2518, 0.2217],
    [-0.3033, 0.0816, 0.2217]])
例の出力と比較すると、両者は同じであることがわかった。
以上のPyTorchのSoftMax交差エントロピー損失と勾配の使い方は小編が皆さんに共有している内容です。参考にしていただければと思います。よろしくお願いします。