【tensorflow画像読み出し方式】ローカルファイル名読み出しおよびurl方式読み出し

4354 ワード

何気なく巨牛の人工知能のチュートリアルを見つけて、思わず共有してあげました.教程は基礎がゼロで、分かりやすくて、しかもとても面白くてユーモアがあって、小説を読むようです!すごいと思って、みんなに分かち合いました.ここをクリックするとチュートリアルにジャンプできます.人工知能チュートリアル
画像の存在形式は、一般的にローカルフォルダxxxである.jpg.xxx.png.あるいはurl方式で、https://timgsa.baidu.com/timg?この2つに似ています.
ではtensorflowはどのようにしてこの2つの画像を読み取るのでしょうか.以下で説明します.
ローカルピクチャの読み出しはtfでよい.read_file()とtf.image.decode_jpeg()の2つの関数.あるいはtf.gfile.FastGFile()とtf.image.decode_jpeg()の2つの関数.
# -*- encoding=utf-8 -*-
# author:dongli


import matplotlib.pyplot as plt
import tensorflow as tf

#-------------------  1:      ----------------
#         
image_raw = tf.gfile.FastGFile('F:/img_spam/test/3.jpg', 'rb').read()
#    tf      
img = tf.image.decode_jpeg(image_raw,channels=3)  # Tensor

# -------------------  2:      ----------------

#image_value = tf.read_file('F:/img_spam/test/3.jpg')
#img = tf.image.decode_jpeg(image_value, channels=3)




with tf.Session() as sess:
    img_ = img.eval()
    print(img_.shape)

plt.figure(1)
plt.imshow(img_)
plt.show()


url方式のピクチャ読み出しはrequestsのみを用いる.get(image_url).contentとtf.image.decode_jpeg(image_data,channels=3,name=‘jpeg_reader’)でよい.
#------------------ url    ----------------------
import requests
import tensorflow as tf

image_url="https://timgsa.baidu.com/xxxxx.jpg"
image_data=requests.get(image_url).content
print(image_data)
img = tf.image.decode_jpeg(image_data, channels=3,name='jpeg_reader')
print(img)

モデルのロードと、画像の前処理コードの読み出し.
# -*- encoding=utf-8 -*-
# author:dongli

import tensorflow as tf




def load_graph(model_file):

    """
    :param model_file:       output_graph.pb
    :return:    
    """
    graph = tf.Graph()
    graph_def = tf.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
        with graph.as_default():
            tf.import_graph_def(graph_def)
    return graph

def load_labels(label_file):
    """
    :param label_file:         output_labels.txt
    :return:       
    """
    label = []
    proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label



def read_tensor_from_image_file(file_name, input_height=299, input_width=299,input_mean=0, input_std=255):

    """
    :param file_name:     
    :param input_height:      
    :param input_width:      
    :param input_mean:    
    :param input_std:    
    :return:          
    """
    file_reader = tf.read_file(file_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels = 3, name='png_reader')
    elif file_name.endswith(".gif"):
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader,name='gif_reader'))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels = 3,name='jpeg_reader')

    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    sess = tf.Session()
    result = sess.run(normalized)
    return result







def read_tensor_from_image_data(image_data, input_height=299, input_width=299,input_mean=0, input_std=255):

    """
    :param file_name:     
    :param input_height:      
    :param input_width:      
    :param input_mean:    
    :param input_std:    
    :return:     byte      
    """
    image_reader = tf.image.decode_jpeg(image_data,channels = 3,name='jpeg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    sess = tf.Session()
    result = sess.run(normalized)
    return result



tensorflow公式チュートリアル関数大全参照
https://www.w3cschool.cn/tensorflow_python/