TensorFlow学習ノート-tf.estimator

2710 ワード

  • tfestimatorEstimator
  • 属性
  • 方法

  • tf.estimator.Estimator


    Estimator classトレーニングとTFモデルのテスト.Estimatorオブジェクトは、model_fnによって指定するモデルをカプセル化し、入力および他のスーパーパラメータを与え、opsに戻ってtraining、evaluation or preditionを実行する.すべての出力(checkpoints,event files,etc.を含む)はmodel_dirに書き込まれる.

    ツールバーの

  • config
    着信model_fnmodel_fnパラメータnamed"config"
  • model_dir
  • model_fn The model_fn with following signature: def model_fn(features, labels, mode, config)
  • params

  • 方法

  • __init__
  • __init__(
        model_fn,
        model_dir=None,
        config=None,
        params=None #  model_fn 
    )
  • evaluate

  • トレーニングモデルの評価
    evaluate(
        input_fn, #  , features labels
        steps=None,
        hooks=None, # List of SessionRunHook subclass instances
        checkpoint_path=None, # if none,  model_dir latest checkpoint
        name=None
    )
  • export_savemodelは、SavedModel
  • としてinference graphを導出する.
    export_savedmodel(
        export_dir_base, #  
        serving_input_receiver_fn, #  ServingInputReceiver 
        assets_extra=None,
        as_text=False,
        checkpoint_path=None
    )
  • get_variable_names
    get_variable_names()は、モデル内のすべての変数名のリストを返します.
  • get_variable_value(name)変数nameに基づいてvalue
  • を返す
  • latest_checkpoint()は、model_dirにおいて最近保存されたcheckpoint
  • を見つける.
  • predict所与のfeaturesに基づいて予測
  • を生成する.
    predict(
        input_fn,
        predict_keys=None,
        hooks=None,
        checkpoint_path=None
    )
  • train

  • 訓練データを与えた後にmodelを訓練する
    train(
        input_fn,
        hooks=None,
        steps=None,
        max_steps=None,
        saving_listeners=None
    )