Faster R-CNN で PR 曲線を描くための test_net.py のパラメータ変更

5144 ワード

参考ブログ: https://blog.csdn.net/hongxingabc/article/details/80064574 および https://blog.csdn.net/dlh_sycamore/article/details/86534712 。お二人のブロガーの共有に感謝します。本記事では主に test_net.py の一部パラメータの変更方法を説明します。皆さんのお役に立てば幸いです。まず、訓練したモデルでテストが正常に実行できることを確認してください。test_net.py で変更が必要な箇所は全部で 2 つです。まだ問題があれば、さらに議論しましょう。コード: https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5/issues/60
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# import _init_paths
from lib.utils.test import test_net
from lib.config import config as cfg
# from lib.config import cfg, cfg_from_file, cfg_from_list
from lib.datasets.factory import get_imdb
import argparse
import pprint
import time, os, sys

import tensorflow as tf
from lib.nets.vgg16 import vgg16
# from nets.resnet_v1 import resnetv1
# from nets.mobilenet_v1 import mobilenetv1

# Backbone and training-set keys used as the defaults of --net / --dataset.
demonet = 'vgg16'
dataset = 'pascal_voc'
# Checkpoint file name per backbone and training imdb name per dataset; the
# keys of these dicts are the only values accepted by --net / --dataset.
NETS = {'vgg16': ('vgg16.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',)}
# Checkpoint to evaluate. Replace this placeholder with the path of your own
# trained checkpoint; when it is empty the script falls back to the
# initial-weights path instead. The conventional training output location is
#   os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
tfmodel = r'D:\xxxx.ckpt'


def parse_args(argv=None):
  """Parse the command-line options of the Faster R-CNN test script.

  Args:
    argv: list of argument strings to parse. ``None`` (the default) reads
      ``sys.argv[1:]``, exactly like ``argparse.ArgumentParser.parse_args``.

  Returns:
    argparse.Namespace with the attributes cfg_file, model, weight,
    imdb_name, comp_mode, max_per_image, tag, set_cfgs, demo_net, dataset.
  """
  parser = argparse.ArgumentParser(description='Test a Fast R-CNN network')
  parser.add_argument('--cfg', dest='cfg_file',
                      help='optional config file', default=None, type=str)
  parser.add_argument('--model', dest='model',
                      help='model to test',
                      default=None, type=str)
  # The script reads args.weight when no checkpoint is configured, so the
  # option has to exist (without it that path raised AttributeError).
  parser.add_argument('--weight', dest='weight',
                      help='initial weights, used only when no model checkpoint is given',
                      default=None, type=str)
  parser.add_argument('--imdb', dest='imdb_name',
                      help='dataset to test',
                      default='voc_2007_test', type=str)
  parser.add_argument('--comp', dest='comp_mode', help='competition mode',
                      action='store_true')
  parser.add_argument('--num_dets', dest='max_per_image',
                      help='max number of detections per image',
                      default=100, type=int)
  parser.add_argument('--tag', dest='tag',
                      help='tag of the model',
                      default='', type=str)
  parser.add_argument('--set', dest='set_cfgs',
                      help='set config keys', default=None,
                      nargs=argparse.REMAINDER)
  parser.add_argument('--net', dest='demo_net',
                      help='Network to use [{}]'.format(' '.join(NETS)),
                      choices=sorted(NETS), default=demonet)
  # argparse does not validate defaults against `choices`, so the default
  # must be kept in sync with DATASETS by hand ('pascal_voc_0712' was not).
  parser.add_argument('--dataset', dest='dataset',
                      help='Trained dataset [{}]'.format(' '.join(DATASETS)),
                      choices=sorted(DATASETS), default=dataset)
  return parser.parse_args(argv)

def main():
  """Evaluate a trained Faster R-CNN checkpoint on an imdb test split.

  Parses the command line, builds the network in TEST mode, restores the
  checkpoint into a TensorFlow session and hands everything to ``test_net``
  (from lib.utils.test), which performs the actual evaluation.
  """
  args = parse_args()

  print('Called with args:')
  print(args)

  print('Using config:')
  pprint.pprint(cfg)

  # An explicit --model overrides the hard-coded module-level tfmodel.
  model_path = args.model or tfmodel
  # --weight may not be registered by parse_args; treat it as optional.
  weight = getattr(args, 'weight', None)

  # Result sub-directory name: "<tag>/<checkpoint or weights file stem>".
  if model_path:
    filename = os.path.splitext(os.path.basename(model_path))[0]
  elif weight:
    filename = os.path.splitext(os.path.basename(weight))[0]
  else:
    sys.exit('No model to test: pass --model <checkpoint> or set tfmodel.')
  filename = (args.tag or 'default') + '/' + filename

  imdb = get_imdb(args.imdb_name)
  imdb.competition_mode(args.comp_mode)

  tfconfig = tf.ConfigProto(allow_soft_placement=True)
  tfconfig.gpu_options.allow_growth = True

  # The with-block closes the session even when restoring or testing raises.
  with tf.Session(config=tfconfig) as sess:
    # Only the VGG16 backbone is wired up here (--net accepts only NETS keys).
    if args.demo_net == 'vgg16':
      net = vgg16()
    else:
      raise NotImplementedError(
          'unsupported network: {}'.format(args.demo_net))

    # The network needs "foreground classes + 1 (background)" outputs; take
    # it from the imdb instead of hard-coding it (this used to be 4).
    # NOTE(review): assumes imdb.num_classes counts __background__, as in the
    # py-faster-rcnn imdb class -- confirm for custom datasets.
    net.create_architecture(sess, mode="TEST", num_classes=imdb.num_classes,
                            tag='default',
                            anchor_scales=[8, 16, 32],
                            anchor_ratios=[0.5, 1, 2])

    if model_path:
      print('Loading model check point from {:s}'.format(model_path))
      saver = tf.train.Saver()
      saver.restore(sess, model_path)
      print('Loaded.')
    else:
      # No file is actually read here: the variables are only freshly
      # initialized, so the results reflect an untrained network.
      print('Loading initial weights from {:s}'.format(weight))
      sess.run(tf.global_variables_initializer())
      print('Loaded.')

    test_net(sess, net, imdb, filename, max_per_image=args.max_per_image)


if __name__ == '__main__':
  main()