Tensorflow 2.* ネットワークトレーニング(一)compile(optimizer,loss,metrics,loss_weights)


ネットワークを構築した後,損失関数,誤差方向伝搬最適化アルゴリズムなど,モデル訓練の関数を構成する必要がある.Tensorflow 2.* のcompileコンパイル関数には、分類の問題など、この機能が統合されています.一般的な形式は次のとおりです.
model.compile(optimizer='rmsprop',
				loss='categorical_crossentropy',
				metrics=['accuracy'])

文書ディレクトリ
  • tf.keras.Model.compile()
  • optimizerオプティマイザ
  • loss損失関数
  • metricsモニタリング指標
  • lossとmetricsの関係
  • loss_weights
  • 重み辞書
  • 重みリスト
  • 参照
  • tf.keras.Model.compile()
    optimizerオプティマイザ
    訓練データと損失関数に基づいてネットワークを更新するメカニズムは,Adam,RMSprop,SGDなどがよく用いられる.
    loss損失関数
    ネットワークは、トレーニングデータのパフォーマンス、すなわち、ネットワークがどのように正しい方向に進むかを測定します.BinaryCrossentropy,CategoricalCrossentropy,KLDivergenceなど
    metricsモニタリング指標
    訓練とテストの過程で監視する必要がある指標.AUC、Accuracy、BinaryAccuracy、BinaryCrossentropy、CategoricalCrossentropy、KLDivergence、Precisionなどが一般的です
    lossとmetricsの関係
  • lossおよびmetricsは、訓練中のモデルの予測性能を評価するために使用される.
  • optimizerはloss値に基づいて逆誤差方向伝播を行い、更新ネットワーク重み値を計算する.
  • metricsはネットワークの訓練過程に参加せず、1つの監視指標として、モデルの予測を直感的に表示するのに便利で、選択範囲はlossより多い.
  • , loss, , , metrics

  • loss_weights
    複数出力のネットワークモデルの場合、各出力に対応するloss関数がトレーニング中に占める割合を定義する必要がある場合は、loss_を設定する必要があります.weights.例えば、次のような3つの出力のネットワークモデルでは、各出力が異なる損失関数に対応すると、ネットワークトレーニング中にネットワークが損失関数の数値が最大の出力に傾き、他の2つの出力の最適化効果が不足するという深刻な損傷関数間の不均衡を招く.したがって、各損失値が最終損失への寄与に異なるサイズの重要性(重み値)を割り当てる必要がある.
    age_prediction = layers.Dense(1, name='age')(x)
    income_prediction = layers.Dense(nb_classes,activation='softmax', name='income')(x)
    gender_prediction = layers.Dense(1, activation='sigmoid', name='gender')(x)
    model = Model(posts_input,[age_prediction, income_prediction, gender_prediction])
    model.compile(optimizer='rmsprop',
    				loss={'age': 'mse',
    					'income': 'categorical_crossentropy',
    					'gender': 'binary_crossentropy'})
    

    重み付け辞書
    出力レイヤの名前に従ってディクショナリ割当権値を作成
    model.compile(loss_weight={'age': 0.25,
    							'income'1.,
    							'gender':10.})
    

    ウェイト値リスト
    またはModel出力レイヤリスト順に重み値を割り当てる
    model.compile(loss_weight=[loss_weights=[0.25, 1., 10.])
    

    リファレンス
    公式ドキュメント