sklearn手書きデータセットDataloader分batchトレーニング
2524 ワード
from sklearn.datasets import load_digits
from torch.utils.data import DataLoader
import numpy as np
digits = load_digits()
img = digits['images'] #(1797,8,8)
img = img[:,np.newaxis,:,:] #(1797,1,8,8)
dataloader = DataLoader(img, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
for i,img in enumerate(dataloader):
print(img.size()) ##(4,1,8,8)