focal lossのいくつかの実装バージョン(Keras/Tensorflow)


仕事の中でfocal lossを使って出会った1つのバグから起源して、私は注意深く複数の頼りになるfocal lossの説明と実現バージョンを勉強します
テストを通じて、私はこのような奇妙な現象を発見しました.ほとんどのバージョンのfocal lossが同じ入力に対して計算したlossを実現するのは違います.
綿密な比較と思考を通じて、正しいfocal lossの実現方法を3つまとめ、コードを共有しました.
完全なコードは私のgithubコードライブラリAI-toolboxに整理され、コードスタンプはここにあります.
focal lossとは
focal lossはネットワークRetinaNetとともに提案された驚くべき損失関数paperダウンロードであり,主に正負のサンプル割合の深刻な傾きによるモデルの訓練が困難な問題を解決することを目的としている.
ここでfocal lossについてよく知っていると仮定して、簡単に式を振り返って、focal lossの定義は以下の通りです:式の中でγ {\gamma} γおよびα {\alpha} α2つの調整可能なスーパーパラメータです.
γ {\gamma} γの意味は、モデルがよりよく予測できるサンプルの損失の重みを弱め、モデルをより難しいhard caseの学習に集中させる役割を果たすことをよりよく理解しています.
そしてα t {\alpha}_t αtの定義、原文の中の表現は:
For notational convenience, we define αt analogously to how we defined pt
つまり、α t {\alpha}_t αtの定義はp t p_と同様であるt ptの定義.その役割は、カテゴリ間の重みをバランスさせることです.
ここで補足すると、ネット上で見つけられる様々な異なるバージョンのfocal lossが実現され、分岐は基本的にここに現れています.focal lossは、最初は目標検出に伴ってある領域が物体or背景(二分類問題)であると判断するため、focal lossを用いてより一般化された問題(例えば、多分類問題、多ラベル予測問題)を解決する場合、α t {\alpha}_t αtがどのように定義されるかは分岐し,どちらが絶対的に正統であるかは言い難い.異なる定義は損失関数の異なる機能を与え,異なる問題に対応できるからである.
私がまとめた3つの実装バージョンを見てみましょう.
focal loss for binary classification
二分類バージョン向けfocal loss実装
def binary_focal_loss(gamma=2, alpha=0.25):
    """
    Binary form of focal loss.
             focal loss
    
    focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
        where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = tf.constant(gamma, dtype=tf.float32)

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        y_true shape need be (None,1)
        y_pred need be compute after sigmoid
        """
        y_true = tf.cast(y_true, tf.float32)
        alpha_t = y_true*alpha + (K.ones_like(y_true)-y_true)*(1-alpha)
    
        p_t = y_true*y_pred + (K.ones_like(y_true)-y_true)*(K.ones_like(y_true)-y_pred) + K.epsilon()
        focal_loss = - alpha_t * K.pow((K.ones_like(y_true)-p_t),gamma) * K.log(p_t)
        return K.mean(focal_loss)
    return binary_focal_loss_fixed

本損失関数を使用する前に、各サンプルをsigmoidを使用して0-1の数にマッピングし、二分類の確率を表すと仮定します.
この関数をkerasで損失関数として使用するには、モデルのコンパイル時にfocal lossとして損失関数を指定するだけです.
model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=optimizer)

focal loss for multi categoryバージョン1
多分類問題または多ラベル問題に対するfocal loss実現1.
前述したように、ネット上の異なる実装バージョンではα t {\alpha}_t αtの定義には一定の相違がある
私たちがα t {\alpha}_t αtが異なるカテゴリ/ラベルの重みを制御する場合、実装コードは以下の通りである.
def multi_category_focal_loss1(alpha, gamma=2.0):
    """
    focal loss for multi category of multi label problem
                 focal loss
    alpha        /     ,             
              /        ,           loss
    Usage:
     model.compile(loss=[multi_category_focal_loss1(alpha=[1,2,3,2], gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    alpha = tf.constant(alpha, dtype=tf.float32)
    #alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)
    #alpha = tf.constant_initializer(alpha)
    gamma = float(gamma)
    def multi_category_focal_loss1_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
        ce = -tf.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.matmul(tf.multiply(weight, ce), alpha)
        loss = tf.reduce_mean(fl)
        return loss
    return multi_category_focal_loss1_fixed

注意してください.α {\alpha} α配列として指定します.配列のサイズは、各カテゴリに対応する重みを表すカテゴリの個数と一致する必要があります.
データセットの異なるカテゴリ/ラベルの間に傾きがある場合は、lossとしてこの関数を適用してみてください.
コア関数copyを出て簡単なテストをして検証しますα {\alpha} αバランスカテゴリ間のウェイトの有効性.
import os
from keras import backend as K
import tensorflow as tf
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def multi_category_focal_loss1(y_true, y_pred):
    epsilon = 1.e-7
    gamma = 2.0
    #alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)
    alpha = tf.constant([[1],[1],[1],[1],[1]], dtype=tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
    ce = -tf.log(y_t)
    weight = tf.pow(tf.subtract(1., y_t), gamma)
    fl = tf.matmul(tf.multiply(weight, ce), alpha)
    loss = tf.reduce_mean(fl)
    return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 0, 0, 0, 0]])
Y_pred = np.array([[0.3, 0.99, 0.8, 0.97, 0.85], [0.9, 0.05, 0.1, 0.09, 0]], dtype=np.float32)
print(K.eval(multi_category_focal_loss1(Y_true, Y_pred)))

5つの出力のマルチlabel予測問題を処理していると仮定し、上記の例に従って、私たちのモデルが最初のlabelに対して他のラベルの予測よりも悪いと仮定します(これは、最初のlabelが現れる確率が小さく、損失を計算するときに発言権がないためかもしれません).
上のコードの演算結果は1.2347984です
我々はα {\alpha} α最初のlabelの重みを調整して、α {\alpha} α次のように変更します.
alpha = tf.constant([[2],[1],[1],[1],[1]], dtype=tf.float32)

再実行すると、損失は2.462384に増大し、損失関数が最初のカテゴリの重みを拡大することに成功したことを示し、モデルは最初のlabelの正確な予測をより重視します.
focal loss for multi categoryバージョン2
多分類問題または多ラベル問題に対するfocal loss実現2.
私たちがα t {\alpha}_t αtは真値yを制御するtrueが1 or 0の場合の重み
すなわちy=1の場合の重みはα {\alpha} α, y=0の場合の重みは1−であるα 1-{\alpha} 1−α
実装コードは次のとおりです.
def multi_category_focal_loss2(gamma=2., alpha=.25):
    """
    focal loss for multi category of multi label problem
                 focal loss
    alpha    y_true 1/0    
        1    alpha, 0    1-alpha
            ,       ,           loss
           (            1),   alpha  
           (            0,          ,          )
           alpha  ,         1。
    Usage:
     model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    gamma = float(gamma)
    alpha = tf.constant(alpha, dtype=tf.float32)

    def multi_category_focal_loss2_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
    
        alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
        ce = -tf.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
        loss = tf.reduce_mean(fl)
        return loss
    return multi_category_focal_loss2_fixed

注意してください.α {\alpha} α配列として指定します.配列のサイズは、各カテゴリに対応する重みを表すカテゴリの個数と一致する必要があります.
モデルがフィットしていない場合、学習に困難がある場合は、lossとしてこの関数を適用してみてください.
モデルが過激すぎる場合(常に1を予測する傾向にある)、alphaを小さくしようとする
モデルが「怠け者」すぎる場合(常に0を予測する傾向があるか、一定の定数が有効な特徴を学んでいないことを示す)、alphaを大きくしてモデルが1を予測することを奨励しようとします.
同様に,コア関数copyを出て簡単なテストを行い,検証した.α {\alpha} αバランス0-1の重みの有効性.
import os
from keras import backend as K
import tensorflow as tf
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def multi_category_focal_loss2_fixed(y_true, y_pred):
    epsilon = 1.e-7
    gamma=2.
    alpha = tf.constant(0.5, dtype=tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

    alpha_t = y_true*alpha + (tf.ones_like(y_true)-y_true)*(1-alpha)
    y_t = tf.multiply(y_true, y_pred) + tf.multiply(1-y_true, 1-y_pred)
    ce = -tf.log(y_t)
    weight = tf.pow(tf.subtract(1., y_t), gamma)
    fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
    loss = tf.reduce_mean(fl)
    return loss
Y_true = np.array([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])
Y_pred = np.array([[0.9, 0.99, 0.8, 0.97, 0.85], [0.9, 0.95, 0.91, 0.99, 1]], dtype=np.float32)
print(K.eval(multi_category_focal_loss2_fixed(Y_true, Y_pred)))

5つの出力のマルチlabel予測問題を処理していると仮定します
上記の例によれば、今回私たちが直面した問題は、すべてのラベルが高い確率で1が現れると仮定すると、私たちのモデルは投機的に巧みな方法を発見し、各結果を1と予測すると、小さなlossが得られ、モデルは深刻な不適合を得ることができる.
上記のコードの演算結果は0.093982555であり,我々が予想したように損失は大きくなく,モデルの成功収束に影響を及ぼすことは明らかである.
我々はα {\alpha} αモデル出力1の重みを抑え、α {\alpha} α次のように変更します.
alpha = tf.constant(0.25, dtype=tf.float32)

再動作し,損失は0.14024596に増大し,損失関数がこの投機挙動の損失を増幅することに成功したことを示した.
参考文献
focal loss paper KerasカスタムLoss関数Kerasカスタム複雑なloss関数github:focal-loss-kerasインプリメンテーション1 github:focal-loss-kerasインプリメンテーション2 kaggle kernel:FocalLoss for Keras Focal Loss理解アプリケーション:Multi-class classification with focal loss for imbalanced datasets