Caffe2 (4): Model Testing with squeezenet Classification


Loading and Testing a Caffe2 Model


Model Zoo
Here we use the squeezenet model as an example to classify the objects in an image.
Download the pre-trained model:
python -m caffe2.python.models.download -i squeezenet
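
The downloader drops two protobuf files, init_net.pb and predict_net.pb, into the Caffe2 models directory. A quick sanity check that they are where the rest of this post expects them - assuming the default location ~/caffe2/caffe2/python/models/squeezenet used in the configuration below - looks like this:

    import os

    MODEL_DIR = os.path.expanduser("~/caffe2/caffe2/python/models/squeezenet")
    for fname in ("init_net.pb", "predict_net.pb"):
        path = os.path.join(MODEL_DIR, fname)
        # report whether each downloaded protobuf file is in place
        print("%s: %s" % (path, "found" if os.path.exists(path) else "MISSING"))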

Loading the model:
  • Read the protobuf files:
    with open("init_net.pb") as f:
       init_net = f.read()
    with open("predict_net.pb") as f:
       predict_net = f.read()  
  • Use the workspace.Predictor function to load the blobs from the protobufs:
    p = workspace.Predictor(init_net, predict_net)
  • Run the net and get the results:
    results = p.run([img])
    results is returned as a multi-dimensional array of probabilities; each entry gives the probability that the object belongs to the corresponding class (a top-3 lookup sketch follows the complete code below).

  • Complete code

    # -------------------------------
    # Configuration
    # -------------------------------
    
    CAFFE2_ROOT = "~/caffe2"
    CAFFE_MODELS = "~/caffe2/caffe2/python/models"  # location of the downloaded model files
    
    from caffe2.proto import caffe2_pb2
    import numpy as np
    import skimage.io
    import skimage.transform
    from matplotlib import pyplot
    import os
    from caffe2.python import core, workspace
    import urllib2
    print("Required modules imported.")
    
    IMAGE_LOCATION =  "https://cdn.pixabay.com/photo/2015/02/10/21/28/flower-631765_1280.jpg"
    MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.npy', 227
    
    # codes - these help decipher the output, mapping AlexNet's object codes to a result like "tabby cat" or "lemon" depending on what's in the picture you submit to the neural network.
    # The list of output codes for the AlexNet models (also squeezenet)
    codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
    print("Config set!")
    
    
    # -------------------------------
    # Pre-processing image
    # -------------------------------
    def crop_center(img,cropx,cropy):
        y,x,c = img.shape
        startx = x//2-(cropx//2)
        starty = y//2-(cropy//2)    
        return img[starty:starty+cropy,startx:startx+cropx]
    
    def rescale(img, input_height, input_width):
        print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!")
        print("Model's input shape is %dx%d") % (input_height, input_width)
        aspect = img.shape[1]/float(img.shape[0])
        print("Orginal aspect ratio: " + str(aspect))
        if aspect > 1:
            # landscape orientation - wide image
            res = int(aspect * input_height)
            imgScaled = skimage.transform.resize(img, (input_width, res))
        elif aspect < 1:
            # portrait orientation - tall image
            res = int(input_width / aspect)
            imgScaled = skimage.transform.resize(img, (res, input_height))
        else:
            # square image
            imgScaled = skimage.transform.resize(img, (input_width, input_height))
        pyplot.figure()
        pyplot.imshow(imgScaled)
        pyplot.axis('on')
        pyplot.title('Rescaled image')
        print("New image shape:" + str(imgScaled.shape) + " in HWC")
        return imgScaled
    print "Functions set."
    
    # set paths and variables from model choice and prep image
    CAFFE2_ROOT = os.path.expanduser(CAFFE2_ROOT)
    CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
    
    # mean can be 128 or custom based on the model
    # gives better results to remove the colors found in all of the training images
    MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3])
    if not os.path.exists(MEAN_FILE):
        mean = 128
    else:
        mean = np.load(MEAN_FILE).mean(1).mean(1)
        mean = mean[:, np.newaxis, np.newaxis]
    print "mean was set to: ", mean
    
    INPUT_IMAGE_SIZE = MODEL[4]
    
    # make sure all of the files are around...
    if not os.path.exists(CAFFE2_ROOT):
        print("Houston, you may have a problem.")
    INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])
    print("INIT_NET = " + INIT_NET)
    PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])
    print("PREDICT_NET = " + PREDICT_NET)
    if not os.path.exists(INIT_NET):
        print(INIT_NET + " not found!")
    else:
        print("Found " + INIT_NET + " ...Now looking for " + PREDICT_NET)
        if not os.path.exists(PREDICT_NET):
            print("Caffe model file, " + PREDICT_NET + " was not found!")
        else:
            print("All needed files found! Loading the model in the next block.")
    
    # load the image, then rescale and center-crop it to the model's input size
    img = skimage.img_as_float(skimage.io.imread(IMAGE_LOCATION)).astype(np.float32)
    img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
    img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
    print("After crop: " + str(img.shape))
    pyplot.figure()
    pyplot.imshow(img)
    pyplot.axis('on')
    pyplot.title('Cropped')
    
    # switch to CHW
    img = img.swapaxes(1, 2).swapaxes(0, 1)
    pyplot.figure()
    for i in range(3):
        pyplot.subplot(1, 3, i+1)
        pyplot.imshow(img[i])
        pyplot.axis('off')
        pyplot.title('RGB channel %d' % (i+1))
    
    # switch to BGR
    img = img[(2, 1, 0), :, :]
    
    # remove mean for better results
    # the image was loaded as floats in [0, 1], so scale back to [0, 255] before subtracting
    img = img * 255 - mean
    
    # add batch size
    img = img[np.newaxis, :, :, :].astype(np.float32)
    print "NCHW: ", img.shape
    
    # -------------------------------
    # Load the model
    # -------------------------------
    with open(INIT_NET, "rb") as f:
        init_net = f.read()
    with open(PREDICT_NET, "rb") as f:
        predict_net = f.read()
    
    p = workspace.Predictor(init_net, predict_net)
    
    # -------------------------------
    # Run the net
    # -------------------------------
    # run the net and return prediction
    results = p.run([img])

    # turn it into something we can play with and examine - a multi-dimensional array
    results = np.asarray(results)
    print("results shape: " + str(results.shape))
    # results shape:  (1, 1, 1000, 1, 1)
    
    # -------------------------------
    # Process the results
    # -------------------------------
    results = np.delete(results, 1)
    index = 0
    highest = 0
    arr = np.empty((0,2), dtype=object)
    arr[:,0] = int(10)
    arr[:,1:] = float(10)
    for i, r in enumerate(results):
        # imagenet index begins with 1!
        i=i+1
        arr = np.append(arr, np.array([[i,r]]), axis=0)
        if (r > highest):
            highest = r
            index = i
    
    print("%s  ::  %s" % (index, highest))
    
    # lookup the code and return the result
    # top 3 results
    # sorted(arr, key=lambda x: x[1], reverse=True)[:3]
    
    # now we can grab the code list
    response = urllib2.urlopen(codes)
    
    # and lookup our result from the list
    for line in response:
        code, result = line.partition(":")[::2]
        if (code.strip() == str(index)):
            print result.strip()[1:-2]
    
    # output
    # 985  ::  0.979059
    # daisy
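
The commented-out "top 3 results" line in the script hints at ranking more than one class. Here is a minimal sketch of that idea, assuming the script above has already run (so p, img, codes, np and urllib2 are in scope) and that the code list is 0-indexed, which matches the 985 :: daisy output shown above:

    # squeeze the (1, 1, 1000, 1, 1) output down to a flat vector of 1000 class probabilities
    preds = np.squeeze(np.asarray(p.run([img])))
    # indices of the three largest probabilities, best first
    top3 = np.argsort(preds)[::-1][:3]

    # build an index -> label map from the same code list used above
    labels = {}
    for line in urllib2.urlopen(codes):
        code, _, label = line.partition(":")
        labels[code.strip()] = label.strip()[1:-2]  # drop the surrounding quotes and trailing comma

    for idx in top3:
        print("%d :: %f :: %s" % (idx, preds[idx], labels.get(str(idx), "unknown")))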

Reference

[1] Loading Pre-Trained Models