ローカルMNISTデータセット読み取り(最も明確で最も実用的な)コード

12106 ワード

多くの書籍やブログでコード事例を紹介している場合、使用するMNISTデータセットはコードの中で直接ダウンロードして使用されるため、一人一人の機器を考慮せずに直接実行できるが、データセットがダウンロードできない可能性があり、実行処理速度が遅いという弊害がある.
そこで、このブログでは、ローカルでダウンロードしたMNISTデータセットを解凍して使用するコードを提供します.必要に応じて、データを1次元配列に展開するかどうか、データ正規化、one-hot符号化のパラメータを与えることができ、トレーニングを容易にすることができます.
One-Hot符号化は分類変数をバイナリベクトルとして表す.これは、まず分類値を整数値にマッピングする必要があります.その後、各整数値はバイナリベクトルとして表され、整数のインデックスを除いてゼロ値であり、1としてマークされる.
import numpy as np
import os
import gzip
import pickle

#          ,data_folder   gz      ,      4   
# 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
# 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'

"""  MNIST   
Parameters
----------
normalize :            0.0~1.0
one_hot_label : 
    one_hot_label True    ,    one-hot    
    one-hot    [0,0,1,0,0,0,0,0,0,0]     
flatten :             

Returns
-------
(    ,     ), (    ,     )
"""

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784


def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1

    return T


def load_data(data_folder, normalize=True, flatten=True, one_hot_label=False):

    files = [
      'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
      't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]

    paths = []
    for fname in files:
        paths.append(os.path.join(data_folder,fname))

    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[1], 'rb') as imgpath:
        x_train = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(paths[3], 'rb') as imgpath:
        x_test = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

    if normalize:
        x_train = x_train.astype(np.float32) / 255.0  #      
        x_test = x_test.astype(np.float32) / 255.0

    if one_hot_label:
        y_train = _change_one_hot_label(y_train)
        y_test = _change_one_hot_label(y_test)

    if flatten:
        x_train = x_train.reshape(60000, 784)
        x_test = x_test.reshape(10000, 784)

    return (x_train, y_train), (x_test, y_test)


Reference
斎藤康毅『深度学習入門——pythonに基づく理論と実現』[M].2016
https://blog.csdn.net/AugustMe/article/details/90604473