MatplotlibでのrcParams使用
1792 ワード
CSの課程を独学し終わって、自分で宿題を書いて、画像の可視化の問題に出会って、rcParamsの使用に関連して、少し勉強します.
ついでにCS作業でサンプルを可視化した関数を貼って
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest' # interpolation style
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['savefig.dpi'] = 300 #
plt.rcParams['figure.dpi'] = 300 #
# :[6.0,4.0], 100, 600&400
# dpi=200, 1200*800
# dpi=300, 1800*1200
# figsize
ついでにCS作業でサンプルを可視化した関数を貼って
def VisualizeImage(X_train, y_train):
"""
:X_train:
:y_train:
"""
plt.rcParams['figure.figsize'] = (10.0, 8.0) #
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
samples_per_class = 8
for y, cls in enumerate(classes):
#
idxs = np.flatnonzero(y_train == y)
# 8 ,replace False
idxs = np.random.choice(idxs, samples_per_class, replace=False)
# 8
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1
#
plt.subplot(samples_per_class, num_classes, plt_idx)
plt.imshow(X_train[idx].astype('uint8'))
plt.axis('off')
#
if i == 0:
plt.title(cls)
plt.show()