# sklearn の手書き数字データセットを DataLoader でミニバッチに分割してトレーニング


from sklearn.datasets import load_digits
from torch.utils.data import DataLoader
import numpy as np

# Load the scikit-learn 8x8 handwritten-digit images and iterate over them in
# mini-batches with a PyTorch DataLoader.
digits = load_digits()

# digits['images'] has shape (1797, 8, 8). Insert a channel axis so every sample
# is (1, 8, 8) and batches come out as (N, C, H, W) = (4, 1, 8, 8).
# Cast to float32 once here: the digit images are float64, and DataLoader's
# default collate would otherwise produce torch.float64 batches that do not
# match float32 model parameters during training.
img = digits['images'][:, np.newaxis, :, :].astype(np.float32)  # (1797, 1, 8, 8)

# shuffle=True reshuffles the samples on every pass over the loader;
# drop_last=True discards the final incomplete batch (1797 % 4 == 1 sample),
# so each of the 449 batches holds exactly 4 images.
# NOTE(review): only images are batched here; training a classifier will also
# need the labels in digits['target'] (e.g. paired via a TensorDataset).
dataloader = DataLoader(img, batch_size=4, shuffle=True, num_workers=0, drop_last=True)

# Use a distinct loop name so `img` keeps referring to the full dataset after
# the loop (reusing `img` here would rebind it to the last batch).
for batch in dataloader:
    print(batch.size())  # torch.Size([4, 1, 8, 8])