カスタムestimator
3874 ワード
Tensorflowは1.3バージョンから公式サポートの高層パッケージ
Tensorflowはカスタムestimatorをサポートし、まずモデル関数model_を定義する必要があります.fn,関数にはfeatures,labels,mode,paramsの4つの入力がある.featuresはモデルの入力であり,labelsは予測の真実値modeの取値として3種類ある:
ただし、
最後に、
パラメータ
コードカスタムestimator
深入学习のtensorflow工程化プロジェクト実戦.
tf.estimator
を発売した.Estimators APIは、トレーニングモデル、テストモデル、および予測を生成する方法のセットを提供します.カスタムモデル関数
Tensorflowはカスタムestimatorをサポートし、まずモデル関数model_を定義する必要があります.fn,関数にはfeatures,labels,mode,paramsの4つの入力がある.featuresはモデルの入力であり,labelsは予測の真実値modeの取値として3種類ある:
tf.estimator.ModeKeys.TRAIN
,tf.estimator.ModeKeys.EVAL
とtf.estimator.ModeKeys.PREDICT
,それぞれ訓練,検証,テストに対応する.modeの値により,現在どの段階に属しているかを判断できる.paramsは、learning rateなどのモデル関連のスーパーパラメータを含む辞書です.カスタム関数model_fn戻り値はtf.estimator.EstimatorSpec
オブジェクトでなければなりません. def __new__(cls,
mode,
predictions=None,
loss=None,
train_op=None,
eval_metric_ops=None,
export_outputs=None,
training_chief_hooks=None,
training_hooks=None,
scaffold=None,
evaluation_hooks=None,
prediction_hooks=None):
ただし、
mode
はモデルの使用パターンを表し、model_に対応するfnのパラメータmode;predictions
は、入力された特徴features
から返される予測値を表す.loss
は損失を表す.train_op
は、モデルの損失を最小化するopを表す.eval_metric_ops
は、モデルがevalの場合、追加の出力を必要とする指標を示す.export_outputs
は、モデルをエクスポートするパスを示します.フック関数もいくつかあります.モデルが異なる場合、EstimatorSpecに必要なパラメータも異なります.modeがTRAIN
である場合、EstimatorSpecをインスタンス化するには、パラメータloss
およびtrain_op
を設定する必要があり、modeがEVAL
である場合、パラメータloss
を設定する必要があり、modeがPREDICT
である場合、パラメータpredictions
を設定する必要があります.def my_model(features, labels, mode, params):
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
mean_loss = tf.metrics.mean(loss)
metrics = {'mean_loss':mean_loss}
if mode == tf.estimator.ModeKeys.EVAL:
# eval_metric_ops` , eval 。
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics)
assert mode == tf.estimator.ModeKeys.TRAIN
optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step())
train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step())
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
インスタンスestimator
最後に、
tf.estimator.Estimator
をインスタンス化することによって、カスタムestimatorを得ることができる.def __init__(self,
model_fn: Any,
model_dir: Any = None,
config: Any = None,
params: Any = None,
warm_start_from: Any = None) -> Any
パラメータ
model_fn
はカスタムモデル関数であり、model_dir
はモデルのパラメータやモデル図などの内容を保存するために使用されます.warm_start_from
チェックポイントパスを指定し、checkpointをインポートしてトレーニングを開始します.warm_start_fromはtf.estimator.WarmStartSettings
でインスタンス化できます.def __new__(cls,
ckpt_to_initialize_from: Any,
vars_to_warm_start: str = '.*',
var_name_to_vocab_info: Any = None,
var_name_to_prev_var_name: Any = None) -> _T
ckpt_to_initialize_from
checkpointをロードするパスを指定し、vars_to_warm_start
ホットスタートが必要なパラメータを指定します.コード#コード#
コードカスタムestimator