TensorFlow:モデルの組み立て、トレーニング、テスト

7577 ワード

ネットワークを訓練する際,一般的な流れは,順方向計算によりネットワークの出力値を求め,損失関数によりネットワーク誤差を計算した後,自動導出ツールにより勾配を計算し更新し,ネットワークの性能を間隔的に試験することである.このような一般的なトレーニングロジックは,Kerasが提供するモデルアセンブリとトレーニング高レベルインタフェースによって直接実現でき,簡潔で明確である.

文書ディレクトリ

  • 一、モデル組立
  • 二、モデルマッチング
  • 三、モデルテスト

  • 一、モデル組立


    Kerasには,keras.Modelkeras.layers.Layerの2つの比較的特殊なクラスがある.ここで、Layerクラスはネットワーク層の親クラスであり、重み値の追加、重み値リストの管理など、ネットワーク層の一般的な機能を定義します.Modelクラスはネットワークの親クラスで、Layerクラスの機能に加えて、モデルの保存、ロード、トレーニング、テストなどの便利な機能が追加されています.SequentialもModelのサブクラスであるため、Modelクラスのすべての機能を持つ.
    #  5  
    network = Sequential([layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)])
    network.build(input_shape=(None, 28*28))
    network.summary()
    

    ネットワークを作成した後、通常のプロセスは、ループ反復データセットを複数回繰り返し、バッチごとにトレーニングデータを生成し、順方向に計算し、損失関数を介して誤差値を計算し、自動的に勾配を計算し、ネットワークパラメータを更新することです.この部分の論理は非常に汎用的であるため、kerasにおいてcompile()およびfit()関数が提供され、上述の論理を容易に実現する.まずcompile関数でネットワークで使用するオプティマイザの対象、損失関数、評価指標などを指定します.
    #  , 
    from tensorflow.keras import optimizers,losses
    #  Adam  , 0.01; , Softmax
    network.compile(optimizer=optimizers.Adam(lr=0.01),
    loss=losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'] #  
    )
    

    私たちがcompile()関数で指定したオプティマイザ、損失関数などのパラメータも、私たちが自分で訓練する際に使用するパラメータであり、特別な点はありませんが、kerasはこの部分の常用論理を実現し、開発効率を高めています.

    二、モデルマッチング


    モデルアセンブリが完了すると、fit()関数を使用して、トレーニング対象のデータと検証用のデータセットを送信できます.
    #  train_db, val_db, 5  epochs, 2  epoch  
    #  history  
    history = network.fit(train_db, epochs=5, validation_data=val_db,
    validation_freq=2)
    

    ここでtrain_dbはtf.data.Datasetオブジェクトであり、Numpy Arrayタイプのデータを転送することもできる.epochsは訓練反復のepochs数を指定する.validation_dataは、検証(テスト)のためのデータセットと検証の頻度validation_を指定するfreq. 上記のコードを実行するとネットワークの訓練と検証の機能を実現し、fit関数は訓練過程のデータ記録historyを返し、history.historyは辞書の対象で、訓練中のloss、測定指標などの記録項目が含まれています.
    history.history #  
    
    compile&fit方式で実現されたコードは非常に簡潔で効率的であり,開発時間を大幅に削減していることがわかる.しかし,インタフェースが非常に高いため,柔軟性も低下し,使用するか否かはユーザが判断する必要がある.

    三、モデルテスト


    Modelベースクラスは、ネットワークの組み立てと訓練、検証を容易に行うだけでなく、予測とテストを非常に便利に行うことができます.検証とテストの違いについては、検証とテストをモデル評価の方法として理解するために、章をフィッティングして詳しく説明します.Model.predict(x)の方法でモデルの予測を完了します.
    #  batch  
    x,y = next(iter(db_test))
    print('predict x:', x.shape)
    out = network.predict(x) #  
    print(out) //  out  
    

    単純なテストモデルのパフォーマンスのみであれば、dbデータセット上のすべてのサンプルをModel.evaluate(db)でループテストし、パフォーマンス指標を印刷できます.
    network.evaluate(db_test) #