tf.estimator.Estimatorの使用


tf.estimator.EstimatorはTFの比較的高度なインタフェースです.
最近bert予備訓練モデルを用いる際にtfを用いた.estimator.Estimator.このインタフェースを使用する場合、開発者が完了する作業が少なく、3つのステップがあります.
ステップ1、input_を設定fun,第2ステップmodel_を設定するfun、第3歩、訓練を開始します.
最初のステップのinput_funが完了した機能は、tfrecordファイルを読み取り、中身を解析してdatasetに戻るなど、データの入力準備です.あるいはオーディオや画像などのデータを読み取り、対応する結果を返し、現在のところdataset形式が望ましい.
ステップ2のモデル_funが完成する機能は、モデルを作成し(featureを入力predictを出力する)、lossを設定し、オプティマイザを設定し、結果をtfに返す.estimator.EstimatorSpec.(tf.estimator.Estimator Specとは何か、どのように設定するかは後述)
第3ステップの開始トレーニングは、パラメータの準備(例えば、学習率など、上記のステップ1-2で用いるパラメータ)であり、config(トレーニングモデルがモデルの保存経路を指定し、どのくらいの時間でモデルを保存するか、GPUを使用する場合)を設定し、状況に応じてestimatorの呼び出しを開始する.そうか?义齿predict.
 
ステップ1:input_fun
def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
        , TFRecord          batch_size batch
    Args:
        filenames: TFRecord  
        batch_size: batch_size  
        num_epochs:  TFRecord        ,   None,           
        perform_shuffle:     

    Returns:
        tensor   ,  batch   
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

 
ステップ2:model_fun
def model_fn(features, labels, mode, params):
    """
    :param features:
    :param labels:
    :param mode:     、         
    tf.estimator.ModeKeys.TRAIN    tf.estimator.ModeKeys.EVAL  tf.estimator.ModeKeys.PREDICT
    :param params:             
    :return:
    """
    # step1:     
    logits = create_model(features)
    predict = tf.nn.softmax(logits, axis=-1)

    # step2:   loss、optimization 
    loss = get_loss(logits, labels)
    train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)

    # step3:   mode,        tf.estimator.EstimatorSpec
    # For mode == ModeKeys.TRAIN:        loss and train_op.
    # For mode == ModeKeys.EVAL:          loss.
    # For mode == ModeKeys.PREDICT:        predictions.
    if mode == tf.estimator.ModeKeys.TRAIN:
        # logging_hook     /     ,         ,            EarlyStopping,
        #        LearningRateScheduler,       step  /      epoch  /           。
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=eval_metrics)
    else:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={"probabilities": predict})

    return output_spec

 
ステップ3:main
def main_():
    # 1.      
    params = {'lr', 0.0001}

    # 2.   config,           ,       
    session_config = tf.ConfigProto(log_device_placement=False,
                                    inter_op_parallelism_threads=0,
                                    intra_op_parallelism_threads=0,
                                    allow_soft_placement=True)
    run_config = tf.estimator.RunConfig(model_dir=model_output_dir,
                                        save_checkpoints_steps=5000,
                                        keep_checkpoint_max=3,
                                        session_config=session_config)

    # 3.     
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=params)

    if do_train:
        train_input_fn = input_fun(...)
        estimator.train(input_fn=train_input_fn)
    
    elif do_eval:
        eval_input_fn = input_fun(...)
        estimator.train(input_fn=eval_input_fn)
        
    else:
        predict_input_fn = input_fun(...)
        estimator.train(input_fn=predict_input_fn)


 
 
==未完待機==
あとでhookなどの設定について更新します
参考文献:
https://zhuanlan.zhihu.com/p/129018863
https://zhuanlan.zhihu.com/p/106400162
https://www.jianshu.com/p/5495f87107e7