【python】tensorflowフレームワーク下sess.run()読み出しデータが詰まっている---ソリューション

13465 ワード

最近tensorflowフレームワークの下でコードをデバッグしたときにsessに遭遇した.run()データが詰まっている場合、検索は多くの方法を試みてやっと解決策を見つけ、同じ問題に直面したパートナーができるだけ早く問題を解決することを望んでいる.
問題の説明:img,label=sess.run([image, labels])
プログラムは改行まで実行してデータの読み取りを行う時、間違いを報告しないで、しかしプログラムカードはこの文で、下へ実行することができません
解決策:with tf.session()as sess:次の文を部分的に追加
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)

詳細:
イメージ-画像保存パスリスト(stringタイプ);Labels-ラベルリスト(int 32タイプ);
データ量が大きすぎるため、画像データに対しては通常バッチ処理、すなわちbatch_を一度に処理する方式を採用している.size(最小1)枚の画像をとり、CPUの負担を軽減し、プログラムの正常な動作を保証します.
このプロセスでは、データが経験した2回の変換(元のタイプstring/int 32をnumpyと記す):
numpy — tensor — numpy
上記の問題は、2回目の変換(tensorからnumpyに移行)時に発生する
(1)numpyからtensorコードへの移行
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    '''
    Args:
        image: list type
        label: list type
        image_W: image width
        image_H: image height
        batch_size: batch size
        capacity: the maximum elements in queue
    Returns:
        image_batch: 4D tensor [batch_size, width, height, 3], dtype=tf.float32
        label_batch: 1D tensor [batch_size], dtype=tf.int32
    '''
    # image_W, image_H, :             
    #   batch_size:  batch       
    # capacity:        

    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])  # read img from a queue

    # step2:     ,             ,    jpeg,    png 。
    image = tf.image.decode_png(image_contents, channels=3)
    # image = tf.image.resize_images(image,(image_W,image_H),0)

    # step3:     ,       、  、  、      ,          。
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity)

    #     label,   [batch_size]
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, label_batch

ネットワークトレーニング中にfeed_を与えるためdictのデータ型はすでに確定しているので(これがプレースホルダplaceholderがやったこと)、batch_images, batch_labels = sess.run([train_batch,train_label_batch])の目的はtensorを元のタイプのデータとして再読み込みし、得られた(batch_images,batch_labels)はfeed_に直接与えることができるdict.
(2)tensorからnumpyコードへ
	train, train_label = read_data.get_files(train_dir)

    train_batch, train_label_batch = read_data.get_batch(train,
                                                             train_label,
                                                             IMG_W,
                                                             IMG_H,
                                                             BATCH_SIZE,
                                                             CAPACITY)
    with tf.Session(config=config) as sess:
    # with tf.Session(config=tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)) as sess:
        # init = tf.initialize_all_variables()
        init = tf.global_variables_initializer()
        sess.run(init)
     	coord = tf.train.Coordinator()
		thread = tf.train.start_queue_runners(sess, coord)
  
        writer = tf.summary.FileWriter(save_dir, sess.graph)
        sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

        print("start_sess")

        for step in range(MaxStep):

            batch_images, batch_labels = sess.run([train_batch, train_label_batch])
     
            batch_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
          

            #    D    
            _, summary_str = sess.run([d_optim, d_sum],
                                      feed_dict={images: batch_images,
                                                 z: batch_z,
                                                 y: batch_labels})

[train_batch,train_label_batch]はtensorで表され、imageに格納されているのはすべてのピクチャパスであり、ここではキュー構造を用いて完了する.
  • 実行ブロックでstartというキューを先に起動して動作させなければ、
  • を正常に読み出すことができません.
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess, coord)
    

    これもこの問題の核心である.