TensorFlow学習ノート-tf.estimator
2710 ワード
tf.estimator.Estimator
Estimator classトレーニングとTFモデルのテスト.
Estimator
オブジェクトは、model_fn
によって指定するモデルをカプセル化し、入力および他のスーパーパラメータを与え、opsに戻ってtraining、evaluation or preditionを実行する.すべての出力(checkpoints,event files,etc.を含む)はmodel_dir
に書き込まれる.ツールバーの
着信
model_fn
、model_fn
パラメータnamed"config"def model_fn(features, labels, mode, config)
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None # model_fn
)
evaluate
トレーニングモデルの評価
evaluate(
input_fn, # , features labels
steps=None,
hooks=None, # List of SessionRunHook subclass instances
checkpoint_path=None, # if none, model_dir latest checkpoint
name=None
)
export_savemodel
は、SavedModel export_savedmodel(
export_dir_base, #
serving_input_receiver_fn, # ServingInputReceiver
assets_extra=None,
as_text=False,
checkpoint_path=None
)
get_variable_names()は、モデル内のすべての変数名のリストを返します.
model_dir
において最近保存されたcheckpoint predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
訓練データを与えた後にmodelを訓練する
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)