Python + OpenCV（cv2.dnn）インタフェースによる Caffe モデルの予測分類
2735 ワード
#import caffe
#import lmdb
import numpy as np
import cv2
#from caffe.proto import caffe_pb2
import os
import sys
import time
import id_deploy
import caffe
#caffe.set_mode_gpu()
#
def dirlist(path, allfile):
    """Recursively collect every non-directory entry below *path*.

    Entries are visited in ``os.listdir`` order. When a subdirectory is met,
    its contents are appended (depth-first) before the walk moves on to the
    next sibling entry.

    Args:
        path: Directory to scan.
        allfile: List that discovered paths are appended to in place.

    Returns:
        The same ``allfile`` list object, extended with the discovered paths.
    """
    def visit(directory):
        # Anything os.path.isdir does not report as a directory is collected.
        for name in os.listdir(directory):
            candidate = os.path.join(directory, name)
            if os.path.isdir(candidate):
                visit(candidate)
            else:
                allfile.append(candidate)

    visit(path)
    return allfile
# sys.setrecursionlimit(1000000)
#
def is_bgr_img(img):
    """Return True if *img* is an array with exactly three dimensions.

    Objects without a ``shape`` attribute (e.g. the ``None`` returned by
    ``cv2.imread`` for an unreadable file) yield False. Arrays whose shape is
    not 3-D (e.g. 2-D grayscale) also yield False instead of raising. The
    channel count (``shape[2]``) is not inspected.

    Args:
        img: Object to test, typically the result of ``cv2.imread``.

    Returns:
        bool: True only when ``img.shape`` has exactly three entries.
    """
    try:
        shape = img.shape
    except AttributeError:
        # None and other non-array objects have no shape.
        return False
    # Checking the length (instead of tuple-unpacking the shape) avoids a
    # ValueError for 2-D or 4-D arrays.
    return len(shape) == 3
# Output sub-folder for each class index: prediction i is written to dirs[i].
# NOTE(review): assumes the network's softmax has exactly len(dirs) outputs.
dirs = ['0_BM', '1_WD', '2_DO', '3_WJ', '4_WT', '5_YM', '6_CD', '7_SG', '8_TW', '9_WZ']
# NOTE(review): 'F:zfh20181012' has no separator after the drive letter, which
# Windows treats as relative to the current directory of drive F:. Path
# separators may have been lost here -- confirm the intended input folder.
imgnames = dirlist('F:zfh20181012', [])
# Root of the output tree: one sub-folder per class plus 'unkown'.
path = 'f:/out_zfh/'
caffe_model = 'AI_Stomach.caffemodel'  # trained Caffe weights
# Minimum top-1 softmax probability for a prediction to be accepted.
PROB_THRESHOLD = 0.5
# Network input size (width, height) passed to blobFromImage.
INPUT_SIZE = (128, 128)

if not imgnames:
    sys.exit('no images found to classify')
temp = imgnames[0]
print(temp)

# The network definition is held as a string in id_deploy; readNetFromCaffe
# reads it from disk, so it is written to a temporary file for the load.
deploy_temp_file = 'deepid_capsule_protxt.prototxt'
with open(deploy_temp_file, 'w') as handle_file:
    handle_file.write(id_deploy._deepid_capsule_protxt)
try:
    caffenet = cv2.dnn.readNetFromCaffe(deploy_temp_file, caffe_model)
finally:
    # Remove the temporary prototxt even if loading the model fails.
    os.remove(deploy_temp_file)

for imgname in imgnames:
    image = cv2.imread(imgname)
    print(imgname)
    if image is None:
        # cv2.imread returns None when the file cannot be read or decoded.
        # NOTE(review): the input file is then DELETED. This also removes
        # non-image files and anything OpenCV cannot open (unsupported
        # format, permissions) -- confirm this is intended.
        print('cannot read image, removing:', imgname)
        os.remove(imgname)
        continue
    # Build the network input blob: scale factor 1.0, resize to INPUT_SIZE,
    # mean (0, 0, 0) (no mean subtraction), swapRB=False keeps the BGR
    # channel order produced by cv2.imread, crop=False (no center crop).
    blob = cv2.dnn.blobFromImage(image, 1.0, INPUT_SIZE, (0, 0, 0), False, False)
    caffenet.setInput(blob)
    detections = caffenet.forward('softmax')
    # Presumably detections has shape (1, num_classes) -- confirm.
    prob = detections[0]
    order = int(prob.argmax())  # index of the most probable class
    prob_class = prob[order]
    print('the predict class is:', order)
    # Confident predictions go to their class folder, the rest to 'unkown'
    # (spelling kept so existing output trees stay compatible).
    subdir = dirs[order] if prob_class > PROB_THRESHOLD else 'unkown'
    imgpath = os.path.join(path, subdir)
    # makedirs also creates the output root `path` if it does not exist yet.
    os.makedirs(imgpath, exist_ok=True)
    # os.path.basename understands the platform's separators, unlike
    # splitting on '\\' only.
    outfile = os.path.join(imgpath, os.path.basename(imgname))
    if not cv2.imwrite(outfile, image):
        print('failed to write image:', outfile)
    cv2.imshow('cv2', image)
    # Keep only the low byte: some OpenCV builds/platforms set extra bits.
    k = cv2.waitKey(1) & 0xFF
    if k == 27:  # Esc: stop processing
        break
    if k == 32:  # Space: pause until any key is pressed
        cv2.waitKey()
cv2.destroyAllWindows()