Pytoch損失関数nn.NLLLoss 2 d()用法説明


最近は顕著な星の検出にNLL損失関数を使用しました。
NLL関数には、自分でロゴとソフトポイントの確率値を計算してから入力することができます。
入力[batch_]size,chanel,h,w]

ターゲットsize,h,w)
入力されたターゲット行列は、各ピクセルがタイプでなければなりません。例を挙げます。最初のピクセルは0で、カテゴリが入力の第1チャンネルに属することを表しています。2番目の画素は0であり、カテゴリが入力の0番目のチャネルに属することを示しており、これに類推する。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x  ", x)
ここでxを入力して、[batch_]に変更します。size、chanel、h、w」のフォーマットです。
ソフト=nn.Softmax(dim=1)
ロゴソフト=nn.LogSoftmax(dim=1)
そして、ソフトポイント関数を使って各カテゴリの確率を計算します。ここでdim=1は1次元です。
に計算します。つまりチャンネル次元です。log softmaxは、Softmaxを計算した後、log値を計算しています。

手計算で栗を挙げます。一番目の元素です。

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])
label yを入力して、[batch_]に変更します。size,h,w]フォーマット

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)
関数を入力して、loss=0.7947を得ます。
を選択します
最初のlabel=1なら、loss=-1.133
第二のlabel=0なら、loss=-0.333

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223
同じです
注意:この関数は各ピクセルに対して平均します。各batchも平均します。ここには9つの画素があります。size
補足知識:PyTorch:NLLLoss 2 d
余計なことを言わないで、コードを見てください。

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)
以上のPytouch損失関数nn.NLLLoss 2 d()の使い方説明は小編が皆さんに共有した内容の全部です。参考にしてもらいたいです。どうぞよろしくお願いします。