tfrecordデータの読み出し

4120 ワード

tfrecordデータセットがあれば、自然にデータセットから元の訓練データを解いて訓練を行い、tensorflowはtfrecordデータセットの読み取りを処理する方法を提供し、読み取り関数とマルチスレッドでデータを処理する方法を含む.
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def read_and_decode(filename):
    #"C:\\Python34\\tensorflow\\tfrecord_train_plane_1.tfrecords"
    filename_queue = tf.train.string_input_producer([filename], shuffle = False) #                     ,                     

    reader = tf.TFRecordReader()#    reader   TFRecord      
    _, serialized_example = reader.read(filename_queue)   #          ,        
    #batch = tf.train.batch(tensors=[serialized_example],batch_size=3)
    #         
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([3], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })  #    image label feature  
    image = tf.decode_raw(features['img_raw'], tf.uint8)#  decode_raw                
    image = tf.reshape(image, [128, 128, 1])#           128*128   ,1       
    image = (image-tf.reduce_min(image))/(tf.reduce_max(image)-tf.reduce_min(image))#        
    label = tf.cast(features['label'], tf.float32)#          
    return image, label
if __name__=="__main__":
    img, label = read_and_decode("C:\\tensorflowprogram\\tensorflow\\tfrecord_train_plane\\tfrecord_train_plane_128_330.tfrecords")
    count = 0
    #  tf.train.shuffle_batch            batch     ,[img,label]          ,batch_size            ,capacity          ,min_after_dequeue             
    img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=20, capacity=1060,min_after_dequeue = 30)                                           
    with tf.Session() as sess: #      
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord=tf.train.Coordinator()#  tf.train.Coordinator        ,     
        threads= tf.train.start_queue_runners(coord=coord)
         #    ,             
        for i in range (15):
            k,l=sess.run([img_batch,label_batch])
            print(type(k))
            print(k)
            print(l)
           #print(type(l))
            print(i)

        coord.request_stop()
        coord.join(threads)

テストの結果,このプログラムは合格し,ニューラルネットワークモデルを入力して訓練するためにデータを間欠的に読み取ることができる.