Djangoと組み合わせ、学習済みモデルでアップロード画像のクラスを予測する

6268 ワード

1. 前処理
(1) アップロードされた画像を 100×100 ピクセルにリサイズする
def prepicture(picname):
    """Resize ``./media/pic/<picname>`` to 100x100 pixels and save it back.

    The resized image is written to ``./media/pic/`` under the base name of
    ``picname``. For a bare file name, this overwrites the file that was read.

    Args:
        picname: file name of the image, relative to ``./media/pic/``.
    """
    # Close the source handle before writing over the same path. Pillow can
    # keep the file open (e.g. for multi-frame images), and overwriting an
    # open file fails on some platforms.
    with Image.open('./media/pic/' + picname) as img:
        new_img = img.resize((100, 100), Image.BILINEAR)
    new_img.save(os.path.join('./media/pic/', os.path.basename(picname)))

(2) 画像を NumPy 配列に変換する
def read_image2(filename):
    """Load ``./media/pic/<filename>`` as an RGB pixel array.

    Args:
        filename: file name of the image, relative to ``./media/pic/``.

    Returns:
        ``numpy.ndarray`` of the RGB-converted image (height x width x 3).
    """
    # The context manager guarantees the underlying file handle is released.
    with Image.open('./media/pic/' + filename) as img:
        return np.array(img.convert('RGB'))

2. モデルによる予測
def _build_cat_model():
    """Return the uncompiled CNN whose layer stack matches ``cat_weights.h5``.

    The layers must stay identical to the training-time model, otherwise
    ``load_weights`` cannot map the saved tensors onto them.
    """
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(100, 100, 3)))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(4, activation='softmax'))
    return model


def testcat(picname):
    """Classify the uploaded image ``./media/pic/<picname>`` with the cat CNN.

    Side effect: ``prepicture`` first overwrites the image in place with a
    100x100 resized copy.

    Args:
        picname: file name of the upload, relative to ``./media/pic/``.

    Returns:
        Index (0-3) of the highest-scoring of the 4 softmax outputs, as a
        NumPy integer. The caller uses it to look up a ``Catinfo`` row.
    """
    prepicture(picname)

    # Batch of one image: (1, 100, 100, 3), float32 scaled to [0, 1].
    x_test = np.expand_dims(read_image2(picname), axis=0).astype('float32') / 255

    # The model is rebuilt on every call, so drop the previous call's Keras
    # graph/session state first.
    # NOTE(review): under a multi-threaded server, concurrent requests share
    # this global state and may interfere with each other — confirm.
    keras.backend.clear_session()
    model = _build_cat_model()
    model.load_weights('./cat/cat_weights.h5')

    # Inference only: compile() is unnecessary here. The old SGD(lr=, decay=)
    # arguments are also rejected by newer TF optimizers. argmax over the
    # softmax output is what the removed Sequential.predict_classes returned.
    probabilities = model.predict(x_test)
    return np.argmax(probabilities, axis=-1)[0]

3. Djangoとの連携
ビュー（views.py）からモデルを呼び出して画像分類を行う
def _save_upload(uploaded, path):
    """Write all chunks of a Django uploaded file to ``path`` (binary, overwrite)."""
    with open(path, 'wb') as dest:
        for chunk in uploaded.chunks():
            dest.write(chunk)


def catinfo(request):
    """Classify an uploaded picture and render the matching ``Catinfo`` record.

    On POST, the file in ``request.FILES['pic1']`` is saved twice:

    - under ``<MEDIA_ROOT>/pic/`` — the copy that ``testcat`` resizes and
      classifies;
    - under ``./static/img/`` — kept at its original size, presumably for
      display via ``picname`` in ``info.html``.

    The predicted class index is used as the ``Catinfo`` primary key, and
    that record's fields are rendered into ``info.html``. Any other HTTP
    method gets a plain ``HttpResponse``.
    """
    if request.method != "POST":
        return HttpResponse("    !")

    upload = request.FILES['pic1']
    # NOTE(review): testcat/prepicture read from './media/pic/'. This only
    # works when MEDIA_ROOT resolves to ./media relative to the working dir.
    _save_upload(upload, '%s/pic/%s' % (settings.MEDIA_ROOT, upload.name))
    _save_upload(upload, './static/img/%s' % upload.name)

    num = testcat(upload.name)
    # NOTE(review): the original (garbled) comments suggest DB ids start at 1
    # and that predicted class 0 should map to id 4 (`if num == 0: num = 4`).
    # That remap is disabled, so class 0 looks up id=0. Confirm such a row
    # exists; otherwise .get() raises Catinfo.DoesNotExist.
    info = models.Catinfo.objects.get(id=num)
    return render(request, 'info.html', {
        'nameinfo': info.nameinfo,
        'feature': info.feature,
        'livemethod': info.livemethod,
        'feednn': info.feednn,
        'feedmethod': info.feedmethod,
        'picname': upload.name,
    })