TensorFlowでのmnistデータセットの操作と可視化

3396 ワード

from tensorflow.examples.tutorials.mnist import input_data

まず、データセットをインターネットでダウンロードする必要があります.
mnsit = input_data.read_data_sets(train_dir='./MNIST_DATA', one_hot=True)
    #   MNIST_DATA, ,  mnist  

トレーニングセットとテストセットの区分:
X_train, y_train = mnist.train.images, mnist.train.labels
        #   X_train   numpy    ,(55000, 784)
X_test, y_test = mnist.test.images, mnist.test.labels
        # (10000, 784)
X_valid, y_valid = mnist.valid.images, mnist.valid.labels
        # (5000, 784)

もちろん反復形式で一定batch_sizeデータの読み出し:
mnist.train.next_batch(100)
  • mnist.train.next_batch()⇒画像データ、画像データに対応するカテゴリ情報の2つの値を返します.
    >> X_batch, y_batch = mnist.train.next_batch(100)
    >> X_batch.shape
    (100, 784)
    >> y_batch.shape
    (100, 10)                 # one hot  

  • 1.可視化

    # images:9*(28*28)   numpy.ndarray
    # y_  
    def plot_mnist_3_3(images, y_, y=None):
        assert images.shape[0] == len(y_)
        fig, axes = plt.subplots(3, 3)
        for i, ax in enumerate(axes.flat):
            ax.imshow(images[i].reshape(image_shp), cmap='binary')
            if y is None:
                xlabel = 'True: {}'.format(y_[i])
            else:
                xlabel = 'True: {0}, Pred: {1}'.format(y_[i], y[i])
            ax.set_xlabel(xlabel)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()