Pytouchにおけるtoch.nn.Softmaxのdimパラメータの使い方について説明します。


Pytouchにおけるtoch.nn.Softmaxのdimパラメータの使用意味
多次元tenssorに関しては、ソフトmaxのパラメータdimに対しては、常に迷っています。以下の例で説明します。

import torch.nn as nn
m = nn.Softmax(dim=0)
n = nn.Softmax(dim=1)
k = nn.Softmax(dim=2)
input = torch.randn(2, 2, 3)
print(input)
print(m(input))
print(n(input))
print(k(input))
出力:
input
tenssor([[[0.555、-0.6964、1.0446]
[0.634,1.6969,0.7158]
[[1.0092,0.421,-0.8928]
[0.0344,0.9723,0.4328]]
dim=0
tenssor([[0.860,0.956,0.8741]
[0.452,0.7180,0.703]
[[0.140,0.7044,0.0.0.259]
[0.548,0.820,0.4297]]
dim=0の場合は、0次元でsum=1となります。
[0][0]、[0]+[1][0]=0.860+0.6940=1
[0][0][1]+[0][1]=0.956+0.7044=1
……
dim=1
tenssor([[0.4782,0.0736,0.815]
[0.218,0.9264,0.4185]
[[0.7261,0.321,0.2009]
[0.739,0.6979,0.7901]]
dim=1の場合は、第1次元でsum=1となります。
[0][0][0]+[0][1][0]=0.4782+0.218=1
[0][0][1]+[0][1]=0.0736+0.9264=1
……
dim=2
tenssor([[0.381,0.1848,0.572]
[0.766,0.6915,0.919]
[[0.697,0.2888,0.0925]
[0.983,0.5665,0.953]]
dim=2の場合は、2次元でsum=1となります。
[0][0]+[0]、[0]+[0]、[0]、[0]=0.381+0.488+0.572=1.0001(四捨五入問題)
[0][1][0]+[0][1]+[0]、[1]、[2]=0.766+0.6915+0.919=1
……
223のテンソルを図で表します。
在这里插入图片描述
多分類問題toch.nn.Softmaxの使用
なぜこの問題について話しますか?私は仕事中に意味分割予測出力特徴図の個数が16で、いわゆる16分類問題に遭遇したからです。
各チャンネルの画素の値の大きさは、画素がチャンネルに属する種類の大きさを表しているので、一枚の図に異なる色で表示するためには、toch.nn.Softmaxの使用を学ばなければならない。
まず簡単な例を見て、(3、4、4)と出力すれば、4 x 4の特徴図が3枚あります。

import torch
img = torch.rand((3,4,4))
print(img)
出力:
tenssor([[0.0413,0.8728,0.8926,0.0693]
[0.4072,0.0302,0.9248,0.676]
[0.4699,0.9197,0.333,0.4809]
[0.377,0.7673,0.612,0.533]
[[0.4940,0.7996,0.513,0.8016]
[0.157,0.8323,0.9944,0.127]
[0.35,0.4343,0.8123,0.184]
[0.8246,0.6931,0.329,0.730]
[[0.0661,0.1905,0.4490,0.7484]
[0.4013,0.1488,0.145,0.8838]
[0.0083,0.0.559,0.0141,0.8998]
[0.8673,0.308,0.8808,0.0532]]
全部で三つの特徴図を見ることができます。それぞれの特徴図に対応する値が大きいほど、この特徴図に属する対応クラスの確率が大きいと説明しています。

import torch.nn as nn
sogtmax = nn.Softmax(dim=0)
img = sogtmax(img)
print(img)
出力:
tenssor([[0.780,0.4107,0.4251,0.979]
[0.368,0.297,0.3901,0.3747]
[0.4035,0.4396,0.293,0.967]
[0.2,402,0.4008,0.323,0.4285]
[0.4371,0.817,0.322,0.4117]
[0.726,0.52122,0.4182,0.20206]
[0.323,0.2706,0.4832,0.522]
[0.3778,0.348,0.449,0.3208]
[[0.849,0.2076,0.728,0.3904]
[0.367,0.581,0.917,0.4317]
[0.543,0.898,0.175,0.4511]
[0.344,0.4278,0.686]]
上記のコードは、特徴マップごとに対応する位置のピクセル値をSoftmax関数処理し、図中に赤い位置を加算し、=1を加算し、同じ理屈で青の位置を加算します。
Softmax関数は、元の特徴マップの各ピクセルの値を対応次元(ここでdim=0、つまり第一次元)で計算し、それを0~1の間に処理し、サイズを一定にしていることを見ました。

print(torch.max(img,0))
出力:
toch.return_types.max(
values=tenssor([0.4371,0.4107,0.4251,0.4117]
[0.3618,0.52122,0.4182,0.4317]
[0.4035,0.4396,0.4832,0.4511]
[0.880,0.4008,0.4278,0.4285])
indices=tenssor([1,0,0,1]
[0,1,1,2]
[0,0,1,2]
[2,0,2,0])
ここで3 x 4 x 4は1 x 4 x 4になり、対応する位置の値は画素が各チャネルの最大値に対応し、indicesは対応する分類であることがわかる。
上の流れがよく分かりましたので、簡単に処理します。
具体的なケースを見て、ここでアウトプットのサイズは16 x 416 x 416.

output = torch.tensor(output)
sm = nn.Softmax(dim=0)
output = sm(output)
mask = torch.max(output,0).indices.numpy()
 
#       RGB   ,      
rgb_img = np.zeros((output.shape[1], output.shape[2], 3))
for i in range(len(mask)):
    for j in range(len(mask[0])):
        if mask[i][j] == 0:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 255
        if mask[i][j] == 1:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 0
        if mask[i][j] == 2:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 180
        if mask[i][j] == 3:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 255
        if mask[i][j] == 4:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 180
        if mask[i][j] == 5:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 0
        if mask[i][j] == 6:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 180
        if mask[i][j] == 7:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 255
        if mask[i][j] == 8:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
        if mask[i][j] == 9:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
        if mask[i][j] == 10:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 255
        if mask[i][j] == 11:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 180
        if mask[i][j] == 12:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 255
        if mask[i][j] == 13:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 180
        if mask[i][j] == 14:
            rgb_img[i][j][0] = 0
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 255
        if mask[i][j] == 15:
            rgb_img[i][j][0] = 0
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
 
cv2.imwrite('output.jpg', rgb_img)
最後に保存された図は次の通りです。

以上は個人の経験ですので、参考にしていただければと思います。