MatplotlibでのrcParams使用

1792 ワード

CSの課程を独学し終わって、自分で宿題を書いて、画像の可視化の問題に出会って、rcParamsの使用に関連して、少し勉強します.
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest' #    interpolation style
plt.rcParams['image.cmap'] = 'gray'  
plt.rcParams['savefig.dpi'] = 300 #    
plt.rcParams['figure.dpi'] = 300 #   
#      :[6.0,4.0],    100,      600&400
#   dpi=200,      1200*800
#   dpi=300,      1800*1200
#   figsize                

ついでにCS作業でサンプルを可視化した関数を貼って
def VisualizeImage(X_train, y_train):
    """   

    :X_train:    
    :y_train:     
    
    """
    plt.rcParams['figure.figsize'] = (10.0, 8.0)  #       
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray' 
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(classes)
    samples_per_class = 8
    for y, cls in enumerate(classes):
        #               
        idxs = np.flatnonzero(y_train == y)
        #              8   ,replace  False             
        idxs = np.random.choice(idxs, samples_per_class, replace=False)
        #       8       
        for i, idx in enumerate(idxs):
            plt_idx = i * num_classes + y + 1
            #      
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            #     
            if i == 0:
                plt.title(cls)
    plt.show()