NASNETのインポートとテスト

5051 ワード

NASNETのインポートとテスト


引用する


NASNETはこれまでで最も性能の良い画像分類ネットワークであり,tensorflow/modelsとyeephycho/nasnet-tensorflowは対応する訓練とテストコードを公表した.しかし、提供されるサンプルコードは1つのコマンドラインにすぎず、パッケージの程度が高すぎて、理解と自分でテストするのに不便です.できる限り少ないpythonコードを用いた自己テストを実現するために,本人はソースコードを調べ,GitHub上で関連問題を検索し,ついにこの問題を解決した.

インプリメンテーション

  • まず必要なライブラリ
  • をインポートする.
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #  import tensorflow warning
    
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    import tensorflow as tf
    import scipy.io
    import skimage.transform

    2.次にclone models/research/slim/サブディレクトリで、前の文書で簡単に説明した方法で、pythonシステムパスにファイルのパスを追加します.
    import sys
    nets_path = r'...\slim'
    sys.path.insert(0,nets_path)
    
    from nets.nasnet import nasnet
    
    slim = tf.contrib.slim

    3.slim/datasetsに既存のコードを利用してimageenet labelを生成する
    from datasets import imagenet
    labels = imagenet.create_readable_names_for_imagenet_labels()

    ここのlabelは1001クラスであり、通常は1000クラスであるべきであることに気づいた.これはinceptionなど多くのモデルに対して,その設計者がわざわざ「0:background」,すなわちすべての既存カテゴリに対応する数値を1つ加えたためである.これはlabelsの具体的なコンテンツ検証を表示することによって
     >> print(len(labels))
    1001
     >> print(labels)
    {0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus', 3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 4: 'tiger shark, Galeocerdo cuvieri', 5: 'hammerhead, hammerhead shark', 6: 'electric ray, crampfish, numbfish, torpedo',...

    imagenetの1000種類に対応するカテゴリはそれぞれ
    tench, Tinca tinca
    goldfish, Carassius auratus
    great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
    tiger shark, Galeocerdo cuvieri
    hammerhead, hammerhead shark
    electric ray, crampfish, numbfish, torpedo

    1001クラスは確かに0に背景クラスが多くなっただけであることがわかる.nasnetプリトレーニングモデルをインポートします.これは、mobileとlargeの2つの形式でnasnet checkpointファイルを事前にダウンロードする必要があります.次のmobileのインポート(largeモデルをインポートするには、すべてのmobileをlargeに変更する必要があります)
    ckpt_path = r'...
    asnet-a_mobile_04_10_2017\model.ckpt'
    tf.reset_default_graph() x = tf.placeholder(tf.float32,shape = [None,224,224,3],name = 'im') # mobile input shape (224,224,3),large (331,331,3), slim/nets/nasnet/nasnet.py mean = tf.constant([[[[ 123.68/255, 116.779/255, 103.939/255]]]],name = 'im_mean',dtype = tf.float32) # [0,1] x1 = tf.subtract(x,mean) slim.get_or_create_global_step() # with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()): net,endpoints = nasnet.build_nasnet_mobile(images = x1,num_classes = 1000 + 1) pass saver = tf.train.Saver() sess = tf.InteractiveSession() saver.restore(sess,ckpt_path) prob = tf.nn.softmax(net,axis = 1) y = tf.argmax(prob,axis = 1)

    5.テスト
    y0 = sess.run(y,{x:np.expand_dims(im,0)})[0]
    print(y0,labels[y0])

    まとめ


    実際にはslim/eval_image_classifier.pyに必要な情報がすべて与えられており、コードを辛抱強く読むだけで解決するのは難しくありません.