tf.estimator.Estimatorの使用
5137 ワード
tf.estimator.EstimatorはTFの比較的高度なインタフェースです.
最近bert予備訓練モデルを用いる際にtfを用いた.estimator.Estimator.このインタフェースを使用する場合、開発者が完了する作業が少なく、3つのステップがあります.
ステップ1、input_を設定fun,第2ステップmodel_を設定するfun、第3歩、訓練を開始します.
最初のステップのinput_funが完了した機能は、tfrecordファイルを読み取り、中身を解析してdatasetに戻るなど、データの入力準備です.あるいはオーディオや画像などのデータを読み取り、対応する結果を返し、現在のところdataset形式が望ましい.
ステップ2のモデル_funが完成する機能は、モデルを作成し(featureを入力predictを出力する)、lossを設定し、オプティマイザを設定し、結果をtfに返す.estimator.EstimatorSpec.(tf.estimator.Estimator Specとは何か、どのように設定するかは後述)
第3ステップの開始トレーニングは、パラメータの準備(例えば、学習率など、上記のステップ1-2で用いるパラメータ)であり、config(トレーニングモデルがモデルの保存経路を指定し、どのくらいの時間でモデルを保存するか、GPUを使用する場合)を設定し、状況に応じてestimatorの呼び出しを開始する.そうか?义齿predict.
ステップ1:input_fun
ステップ2:model_fun
ステップ3:main
==未完待機==
あとでhookなどの設定について更新します
参考文献:
https://zhuanlan.zhihu.com/p/129018863
https://zhuanlan.zhihu.com/p/106400162
https://www.jianshu.com/p/5495f87107e7
最近bert予備訓練モデルを用いる際にtfを用いた.estimator.Estimator.このインタフェースを使用する場合、開発者が完了する作業が少なく、3つのステップがあります.
ステップ1、input_を設定fun,第2ステップmodel_を設定するfun、第3歩、訓練を開始します.
最初のステップのinput_funが完了した機能は、tfrecordファイルを読み取り、中身を解析してdatasetに戻るなど、データの入力準備です.あるいはオーディオや画像などのデータを読み取り、対応する結果を返し、現在のところdataset形式が望ましい.
ステップ2のモデル_funが完成する機能は、モデルを作成し(featureを入力predictを出力する)、lossを設定し、オプティマイザを設定し、結果をtfに返す.estimator.EstimatorSpec.(tf.estimator.Estimator Specとは何か、どのように設定するかは後述)
第3ステップの開始トレーニングは、パラメータの準備(例えば、学習率など、上記のステップ1-2で用いるパラメータ)であり、config(トレーニングモデルがモデルの保存経路を指定し、どのくらいの時間でモデルを保存するか、GPUを使用する場合)を設定し、状況に応じてestimatorの呼び出しを開始する.そうか?义齿predict.
ステップ1:input_fun
def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
"""
, TFRecord batch_size batch
Args:
filenames: TFRecord
batch_size: batch_size
num_epochs: TFRecord , None,
perform_shuffle:
Returns:
tensor , batch
"""
def _parse_fn(record):
features = {
"label": tf.FixedLenFeature([], tf.int64),
"image": tf.FixedLenFeature([], tf.string),
}
parsed = tf.parse_single_example(record, features)
# image
image = tf.decode_raw(parsed["image"], tf.uint8)
image = tf.reshape(image, [28, 28])
# label
label = tf.cast(parsed["label"], tf.int64)
return {"image": image}, label
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
ステップ2:model_fun
def model_fn(features, labels, mode, params):
"""
:param features:
:param labels:
:param mode: 、
tf.estimator.ModeKeys.TRAIN tf.estimator.ModeKeys.EVAL tf.estimator.ModeKeys.PREDICT
:param params:
:return:
"""
# step1:
logits = create_model(features)
predict = tf.nn.softmax(logits, axis=-1)
# step2: loss、optimization
loss = get_loss(logits, labels)
train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)
# step3: mode, tf.estimator.EstimatorSpec
# For mode == ModeKeys.TRAIN: loss and train_op.
# For mode == ModeKeys.EVAL: loss.
# For mode == ModeKeys.PREDICT: predictions.
if mode == tf.estimator.ModeKeys.TRAIN:
# logging_hook / , , EarlyStopping,
# LearningRateScheduler, step / epoch / 。
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
training_hooks=[logging_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
eval_metric_ops=eval_metrics)
else:
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
predictions={"probabilities": predict})
return output_spec
ステップ3:main
def main_():
# 1.
params = {'lr', 0.0001}
# 2. config, ,
session_config = tf.ConfigProto(log_device_placement=False,
inter_op_parallelism_threads=0,
intra_op_parallelism_threads=0,
allow_soft_placement=True)
run_config = tf.estimator.RunConfig(model_dir=model_output_dir,
save_checkpoints_steps=5000,
keep_checkpoint_max=3,
session_config=session_config)
# 3.
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=params)
if do_train:
train_input_fn = input_fun(...)
estimator.train(input_fn=train_input_fn)
elif do_eval:
eval_input_fn = input_fun(...)
estimator.train(input_fn=eval_input_fn)
else:
predict_input_fn = input_fun(...)
estimator.train(input_fn=predict_input_fn)
==未完待機==
あとでhookなどの設定について更新します
参考文献:
https://zhuanlan.zhihu.com/p/129018863
https://zhuanlan.zhihu.com/p/106400162
https://www.jianshu.com/p/5495f87107e7