TensorFlow:モデルの組み立て、トレーニング、テスト
7577 ワード
ネットワークを訓練する際,一般的な流れは,順方向計算によりネットワークの出力値を求め,損失関数によりネットワーク誤差を計算した後,自動導出ツールにより勾配を計算し更新し,ネットワークの性能を間隔的に試験することである.このような一般的なトレーニングロジックは,Kerasが提供するモデルアセンブリとトレーニング高レベルインタフェースによって直接実現でき,簡潔で明確である.
一、モデル組立 二、モデルマッチング 三、モデルテスト
Kerasには,
ネットワークを作成した後、通常のプロセスは、ループ反復データセットを複数回繰り返し、バッチごとにトレーニングデータを生成し、順方向に計算し、損失関数を介して誤差値を計算し、自動的に勾配を計算し、ネットワークパラメータを更新することです.この部分の論理は非常に汎用的であるため、kerasにおいて
私たちがcompile()関数で指定したオプティマイザ、損失関数などのパラメータも、私たちが自分で訓練する際に使用するパラメータであり、特別な点はありませんが、kerasはこの部分の常用論理を実現し、開発効率を高めています.
モデルアセンブリが完了すると、
ここでtrain_dbは
Modelベースクラスは、ネットワークの組み立てと訓練、検証を容易に行うだけでなく、予測とテストを非常に便利に行うことができます.検証とテストの違いについては、検証とテストをモデル評価の方法として理解するために、章をフィッティングして詳しく説明します.
単純なテストモデルのパフォーマンスのみであれば、dbデータセット上のすべてのサンプルを
文書ディレクトリ
一、モデル組立
Kerasには,
keras.Model
とkeras.layers.Layer
の2つの比較的特殊なクラスがある.ここで、Layerクラスはネットワーク層の親クラスであり、重み値の追加、重み値リストの管理など、ネットワーク層の一般的な機能を定義します.Modelクラスはネットワークの親クラスで、Layerクラスの機能に加えて、モデルの保存、ロード、トレーニング、テストなどの便利な機能が追加されています.SequentialもModelのサブクラスであるため、Modelクラスのすべての機能を持つ.# 5
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
ネットワークを作成した後、通常のプロセスは、ループ反復データセットを複数回繰り返し、バッチごとにトレーニングデータを生成し、順方向に計算し、損失関数を介して誤差値を計算し、自動的に勾配を計算し、ネットワークパラメータを更新することです.この部分の論理は非常に汎用的であるため、kerasにおいて
compile()
およびfit()
関数が提供され、上述の論理を容易に実現する.まずcompile関数でネットワークで使用するオプティマイザの対象、損失関数、評価指標などを指定します.# ,
from tensorflow.keras import optimizers,losses
# Adam , 0.01; , Softmax
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'] #
)
私たちがcompile()関数で指定したオプティマイザ、損失関数などのパラメータも、私たちが自分で訓練する際に使用するパラメータであり、特別な点はありませんが、kerasはこの部分の常用論理を実現し、開発効率を高めています.
二、モデルマッチング
モデルアセンブリが完了すると、
fit()
関数を使用して、トレーニング対象のデータと検証用のデータセットを送信できます.# train_db, val_db, 5 epochs, 2 epoch
# history
history = network.fit(train_db, epochs=5, validation_data=val_db,
validation_freq=2)
ここでtrain_dbは
tf.data.Dataset
オブジェクトであり、Numpy Arrayタイプのデータを転送することもできる.epochsは訓練反復のepochs数を指定する.validation_dataは、検証(テスト)のためのデータセットと検証の頻度validation_を指定するfreq. 上記のコードを実行するとネットワークの訓練と検証の機能を実現し、fit関数は訓練過程のデータ記録historyを返し、history.historyは辞書の対象で、訓練中のloss、測定指標などの記録項目が含まれています.history.history #
compile&fit
方式で実現されたコードは非常に簡潔で効率的であり,開発時間を大幅に削減していることがわかる.しかし,インタフェースが非常に高いため,柔軟性も低下し,使用するか否かはユーザが判断する必要がある.三、モデルテスト
Modelベースクラスは、ネットワークの組み立てと訓練、検証を容易に行うだけでなく、予測とテストを非常に便利に行うことができます.検証とテストの違いについては、検証とテストをモデル評価の方法として理解するために、章をフィッティングして詳しく説明します.
Model.predict(x)
の方法でモデルの予測を完了します.# batch
x,y = next(iter(db_test))
print('predict x:', x.shape)
out = network.predict(x) #
print(out) // out
単純なテストモデルのパフォーマンスのみであれば、dbデータセット上のすべてのサンプルを
Model.evaluate(db)
でループテストし、パフォーマンス指標を印刷できます.network.evaluate(db_test) #