MNISTのデータセットの簡易可視化


はじめに

機械学習のサンプルはMNISTの手書き文字認識をやることが多いんだけど、「機械学習する人はMNISTのデータがどういうものかくらい知ってるでしょ?」って感じで、データ構造の説明が無いことが多い。

ってなわけで、機械学習のド素人がChainerのサンプルで使われるMNISTの手書きデータセットがどういうものかよく見てみる。Chainerはインストール済みであること前提。

データ構造

まずはデータをロードしてみる。chainer.datasets.get_mnist()でデータセットが読み込まれる。そのtestをあれこれしてみる。

>>> import chainer
>>> train, test = chainer.datasets.get_mnist()
>>> test
<chainer.datasets.tuple_dataset.TupleDataset object at 0x2aaab303c910>
>>> len(test)
10000
>>> len(test[0])
2
>>> len(test[0][0])
784
>>> test[0][0]
(snip)
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)
>>> test[0][1]
7

上記から以下のようなことがわかる。

  • testは10000個のサンプルデータからなるTupleDatasetクラス
  • test[i]はi番目のサンプルデータ
  • test[i][0]はi番目のサンプルの784個のfloat32のデータ
  • test[i][1]は整数で、i番目の手書きサンプルの「正解」を表す

784 = 28**2 なので、28×28のグリッドのグレースケールのデータなんだろう。

可視化

というわけで、これを二値化して可視化してみる。とりあえず0番のデータ決め打ちで。

import chainer
n = 0
train, test = chainer.datasets.get_mnist()
s = 28
print test[n][1]
for i in range(1,28):
  for j in range(1,28):
    v =  test[n][0][j + i*s]
    if v > 0.1:
      print "*",
    else:
      print " ",
  print "\n"

最初に正解を、次に手書きデータを出力する。結果はこんな感じ。

$ python show_mnist.py 
7












          * * * * * *                                 

          * * * * * * * * * * * * * * * *             

          * * * * * * * * * * * * * * * *             

                      *   * * * *   * * *             

                                  * * *               

                                  * * *               

                                * * * *               

                              * * * *                 

                              * * *                   

                              * * *                   

                            * * *                     

                          * * * *                     

                          * * *                       

                        * * * *                       

                      * * * *                         

                    * * * *                           

                    * * * *                           

                  * * * * *                           

                  * * * * *                           

                  * * *                               





0番のデータは数字の「7」ですね。世の中の機械学習のサンプルはこれを入力にして、どの数字の可能性が高いかを判定している模様。