CIFAR-10 データを読み込み、その一部を切り出して KNN・SVM・PCA で分析する

23106 ワード

# -*- coding: utf-8 -*-
"""
Created on Wed May 27 22:56:16 2020

@author: guangjie2333

"""


"""

   

"""
import numpy as np
import pickle
import matplotlib.pyplot as plt
import PIL.Image as image
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from sklearn.decomposition import PCA
from sklearn import svm
"""

    

"""

#      
def unpickle(file):
    """Load and return the object stored in a CIFAR-10 python-format pickle.

    Args:
        file: Path to the pickle file, e.g. ``'data_batch_1'`` or
            ``'batches.meta'``.

    Returns:
        The unpickled object. ``encoding='bytes'`` makes any Python 2 ``str``
        objects in the pickle come back as ``bytes``, which is why callers
        index the result with keys such as ``b'data'`` and ``b'labels'``
        (presumably the files were written by Python 2 -- confirm against
        the dataset source).

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file; only
        call this on trusted, locally downloaded data.
    """
    # The ``with`` statement closes the file on exit, so no explicit close().
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')



""""
   
"""   

def _row_to_image(row):
    """Turn one flat CIFAR-10 pixel row into an (H, W, 3) array for imshow.

    The row is reshaped to (3, 32, 32), i.e. three 32x32 planes, and the plane
    axis is moved last: plane 0 becomes channel 0 (red), plane 1 channel 1
    (green), plane 2 channel 2 (blue). A trailing channel axis is the layout
    ``Axes.imshow`` expects for RGB data.

    Args:
        row: Sequence or array of 3 * 32 * 32 = 3072 pixel values.

    Returns:
        numpy.ndarray of shape (32, 32, 3); a view when ``row`` is an array.
    """
    return np.asarray(row).reshape(3, 32, 32).transpose(1, 2, 0)


def _first_indices_per_class(labels, n_classes, per_class):
    """Return the indices of the first ``per_class`` samples of each class.

    Args:
        labels: 1-D sequence of integer class ids.
        n_classes: Classes ``0 .. n_classes - 1`` are looked up.
        per_class: Maximum number of indices kept per class.

    Returns:
        List of length ``n_classes``. Entry ``c`` is an int array of the
        positions in ``labels`` equal to ``c``, in ascending order, truncated
        to ``per_class``. It is shorter (possibly empty) if the class is rarer.
    """
    labels = np.asarray(labels)
    return [np.flatnonzero(labels == c)[:per_class] for c in range(n_classes)]


if __name__ == '__main__':

    # ------- Load the CIFAR-10 batches from the working directory -------
    batches = unpickle('batches.meta')
    # unpickle() uses encoding='bytes', so both the keys and the label names
    # are bytes objects.
    name = batches.get(b'label_names')
    # Decoded copies for display; matplotlib would render bytes as "b'...'".
    label_names = [n.decode('utf-8') if isinstance(n, bytes) else str(n)
                   for n in name]

    train_batches = [unpickle('data_batch_%d' % k) for k in range(1, 6)]
    test_batch = unpickle('test_batch')

    # Stack the five training batches along axis 0 (one row per sample).
    x_train = np.concatenate([batch[b'data'] for batch in train_batches],
                             axis=0)
    y_train = np.concatenate([batch.get(b'labels') for batch in train_batches],
                             axis=0)

    x_test = test_batch.get(b'data')
    y_test = test_batch.get(b'labels')

    # ------- Show sample images: one column per class -------
    classifle = 10    # number of classes -> grid columns
    picture_num = 5   # images shown per class -> grid rows

    # (width, height) chosen to fit a 10-column x 5-row grid of square images.
    plt.figure(figsize=(10, 5))
    samples = _first_indices_per_class(y_train, classifle, picture_num)
    for classplot, indices in enumerate(samples):
        # A class with fewer than picture_num samples just leaves empty cells
        # instead of pushing later classes into the wrong rows.
        for row, idx in enumerate(indices):
            sub = plt.subplot(picture_num, classifle,
                              row * classifle + classplot + 1)
            if row == 0:
                # Only the top image of each column carries the class name.
                sub.set_title(label_names[classplot])
            sub.axis('off')
            sub.imshow(_row_to_image(x_train[idx]))

    # ------- Keep only a small subset so KNN/SVM run quickly -------
    n_subset = 500
    x_train = x_train[:n_subset]
    y_train = y_train[:n_subset]
    x_test = x_test[:n_subset]
    y_test = y_test[:n_subset]

    # ------- 1-nearest-neighbour on raw pixels -------
    neigh = KNeighborsClassifier(n_neighbors=1)
    neigh.fit(x_train, y_train)
    y_test_predict = neigh.predict(x_test)
    print("Accuracy1", metrics.accuracy_score(y_test, y_test_predict))

    # ------- PCA to 2 components, fitted on the training subset only -------
    pca = PCA(n_components=2)
    pca.fit(x_train)
    x_train_reduction = pca.transform(x_train)
    x_test_reduction = pca.transform(x_test)

    # ------- 1-nearest-neighbour on the PCA features -------
    knn_pca = KNeighborsClassifier(n_neighbors=1)
    knn_pca.fit(x_train_reduction, y_train)
    y_test_predict = knn_pca.predict(x_test_reduction)
    print("Accuracy2", metrics.accuracy_score(y_test, y_test_predict))

    # ------- RBF-kernel SVM on the PCA features -------
    # NOTE(review): the PCA features are computed from unscaled pixel values,
    # so gamma=0.5 may be very large relative to their spread. Consider
    # standardizing them or using gamma='scale' before comparing with KNN.
    clf = svm.SVC(kernel='rbf', C=1000, gamma=0.5)
    clf.fit(x_train_reduction, y_train)
    y_test_predict = clf.predict(x_test_reduction)
    print("Accuracy3", metrics.accuracy_score(y_test, y_test_predict))

    # The sample-image figure is only displayed once show() is called. It is
    # called last so the accuracies are printed before a (possibly blocking)
    # window opens.
    plt.show()

    # NOTE(review): the author's closing remarks comparing the PCA, KNN and
    # SVM results were garbled in this copy of the file and are not recoverable.