DNNにおけるClassificationの種類まとめ


概要

"AI"とかいうやつの課員用教育ネタの続き
深層学習の出口レイヤを変えると簡単に問題を切り替えられることを自習させようかと…
備忘録も兼ねて残すこととした

  • 実施期間: 2020年12月
  • 環境:Keras

ここで説明する各モデルの出口の全結合レイヤについて記述する
それより前のレイヤ構成はケースバイケース

Number of Nodes : ノード数 [1]
Activation : 活性関数 [2]
Loss : ロスの算出方法 [3]
Accuracy metrics: Fit中の精度評価指標 [4]
Output: 出力の型

Kerasだと、それぞれ下記の引数がそれらとなる
Optimizerはadamである必要はない
Dense, compileはそれぞれオフィシャルサイト参照のこと

model.add(Dense([1], activation='[2]'))
model.compile(loss='[3]', optimizer='adam‘, metrics=['[4]'])

Linear Regression

連続値である目的変数yを予測するモデルで線形回帰問題と呼ばれる

Linear regression

基本形がこれ
モデルにXを入力すると予測値のy_hat(ここでは0.539)が出力される
KerasではDenseの引数にActivationを指定しなければデフォルトでLinearとなる

Multi-output regression

入力Xに対して複数の連続値目的変数yがあるモデルがこれ
出口の全結合レイヤのNode数はそのyの数(ここでは3つ)でTrainingさせておく

Classification

Logistic regressionの一種がClassification
離散値である目的変数yを予測するモデルで分類問題と呼ばれる
モデル出口とyの型以外はRegressionと変わらず、ビビるに及ばない
なお、Logistic regressionとClassificationの違いはココで議論されているので興味があれば参照
ここでは便宜上、出口レイヤがTrueやFalseを出力しているが、実際には連続値が出力される
それをユーザがif文などでTrue/Falseに読み替えなければならない

Binary classification

モデルにXを入力すると予測値のy_hat(ここでは1(True))が出力される
yが連続値から離散値になり、出口の全結合レイヤの引数がやや変わっただけ
あとは上述のRegressionと同じ
不良品か否かを当てるようなモデル

Multilabel binary classification

Multi-output regressionとほぼ同じ
実際は各Nodeから出力されるのは連続値で、これを0(False)か1(True)に丸める
ただ小生はHeuristicに、例えば0.3以上ならTrueとか閾値は柔軟にしてよいと思っている
Training時のy(例えば[1, 1, 0])を各Nodeで予測するので、それぞれの出力は 0 <y_hat< 1 となる

Multiclass classification

これだけ少し毛色が違う
Activationがsoftmaxとなっており、全Nodeからの出力中Trueは1つである
実際は各Nodeから出力されるのは確率値となり、従い全Nodeの出力合計は1.0になる
最大のものを1(True)、それ以外を0(False)とする
犬、猫、人の画像からそのどれかを当てるようなモデルに使う

まとめ

モデル # of nodes Activation Loss Accuracy metrics Output
Regression 1 Linear MSE, etc. accuracy 連続値
Multi-output regression 複数 Linear MSE, etc. accuracy 連続値
Binary classification 1 sigmoid binary_crossentropy binary_accuracy 離散値(True/False)
Multilabel binary classification 複数 sigmoid binary_crossentropy binary_accuracy 離散値(複数True/False)
Multiclass classification 複数 softmax categorical_crossentropy categorical_accuracy 離散値(単数True/False)

“Class”と”Label”の違いについて

下図の場合、Classは3つありLabelは[犬, 猫, 人]

以上