カスタムestimator

3874 ワード

Tensorflowは1.3バージョンから公式サポートの高層パッケージtf.estimatorを発売した.Estimators APIは、トレーニングモデル、テストモデル、および予測を生成する方法のセットを提供します.

カスタムモデル関数


Tensorflowはカスタムestimatorをサポートし、まずモデル関数model_を定義する必要があります.fn,関数にはfeatures,labels,mode,paramsの4つの入力がある.featuresはモデルの入力であり,labelsは予測の真実値modeの取値として3種類ある:tf.estimator.ModeKeys.TRAIN,tf.estimator.ModeKeys.EVALtf.estimator.ModeKeys.PREDICT,それぞれ訓練,検証,テストに対応する.modeの値により,現在どの段階に属しているかを判断できる.paramsは、learning rateなどのモデル関連のスーパーパラメータを含む辞書です.カスタム関数model_fn戻り値はtf.estimator.EstimatorSpecオブジェクトでなければなりません.
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None,
              prediction_hooks=None):

ただし、modeはモデルの使用パターンを表し、model_に対応するfnのパラメータmode;predictionsは、入力された特徴featuresから返される予測値を表す.lossは損失を表す.train_opは、モデルの損失を最小化するopを表す.eval_metric_opsは、モデルがevalの場合、追加の出力を必要とする指標を示す.export_outputsは、モデルをエクスポートするパスを示します.フック関数もいくつかあります.モデルが異なる場合、EstimatorSpecに必要なパラメータも異なります.modeがTRAINである場合、EstimatorSpecをインスタンス化するには、パラメータlossおよびtrain_opを設定する必要があり、modeがEVALである場合、パラメータlossを設定する必要があり、modeがPREDICTである場合、パラメータpredictionsを設定する必要があります.
def my_model(features, labels, mode, params):
    W = tf.Variable(tf.random_normal([1]), name="weight")
    b = tf.Variable(tf.zeros([1]), name="bias")
    predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    mean_loss = tf.metrics.mean(loss)
    metrics = {'mean_loss':mean_loss}
    if mode == tf.estimator.ModeKeys.EVAL:
        # eval_metric_ops` , eval 。
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics) 
    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step())
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

インスタンスestimator


最後に、tf.estimator.Estimatorをインスタンス化することによって、カスタムestimatorを得ることができる.
def __init__(self,
             model_fn: Any,
             model_dir: Any = None,
             config: Any = None,
             params: Any = None,
             warm_start_from: Any = None) -> Any

パラメータmodel_fnはカスタムモデル関数であり、model_dirはモデルのパラメータやモデル図などの内容を保存するために使用されます.warm_start_fromチェックポイントパスを指定し、checkpointをインポートしてトレーニングを開始します.warm_start_fromはtf.estimator.WarmStartSettingsでインスタンス化できます.
def __new__(cls,
            ckpt_to_initialize_from: Any,
            vars_to_warm_start: str = '.*',
            var_name_to_vocab_info: Any = None,
            var_name_to_prev_var_name: Any = None) -> _T
ckpt_to_initialize_fromcheckpointをロードするパスを指定し、vars_to_warm_startホットスタートが必要なパラメータを指定します.

コード#コード#


コードカスタムestimator

リファレンス

  • 深入学习のtensorflow工程化プロジェクト実戦.