【PyTorch】DataLoaderのshuffleとは


なんとなく利用していたDataLoader。さらになんとなく利用していた引数shuffle。本記事では引数shuffleにより、サンプル抽出がどのように変わるのかをコードともに残しておく。
下記の質問に回答できればスルーでOK。

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True)
Trainer.fit(model, dataloader)

Shuffle Trueの場合

  • dataloader定義時のみサンプルはシャッフルされる?
  • Trainer.fit実行すると、Epoch毎にサンプルはシャッフルされる?

結論

DataLoaderのshuffleは、データセットからサンプルを抽出する際の挙動を決める引数である。DataLoader定義時ではなく、DataLoaderが呼び出されるたびにサンプルはシャッフルされる。Trainer.fit実行すると、Epoch毎にDataLoaderが呼び出され、サンプルはシャッフルされる。

Shuffle Falseの場合

  • データセットの上から順番に、サンプルを抽出

Shuffle Trueの場合

  • データセットからランダムに、サンプルを抽出

詳細

ShuffleをTrueにすることで、すべてのバッチのサンプル抽出はランダムに行われる。
Trainer.fitで学習を進める際には、下記のようにDataLoaderを実装する。

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True)
Trainer.fit(model, dataloader)

私が勘違いしていたこと

(誤)

dataloader作成時に、サンプルはランダム抽出される。
Trainer.fitにはサンプル抽出後のdataloaderが代入されている。
そのため、Shuffle Trueでも、1epoch目と2epoch目のサンプルの組み合わせは同じ。

(正)

dataloaderは、呼び出されるたびにサンプルをランダムに抽出する。
Trainer.fit内部では、epochが変わるたびにdataloaderを呼び出し、サンプルをランダムに抽出している。
そのため、Shuffle Trueであれば、1epoch目と2epochでもサンプルの組み合わせは異なる。

実装

  1. 事前準備

  2. DataLoader検証
    2.1 Shuffle Falseの場合
    2.2 Shuffle Trueの場合

    • dataloader定義時のみサンプルはシャッフルされる? -> 呼び出し時に実行される
    • Trainer.fit実行すると、Epoch毎にサンプルはシャッフルされる? -> される

    2.3 おまけ drop_last

 事前準備

# ライブラリの読込
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

# サンプルデータの作成

#- 入力値: 3変数、11sample
x = torch.tensor([
                              [0, 0, 0],
                              [1, 1, 1],
                              [2, 2, 2], 
                              [3, 3, 3], 
                              [4, 4, 4], 
                              [5, 5, 5],
                              [6, 6, 6],
                              [7, 7, 7], 
                              [8, 8, 8], 
                              [9, 9, 9], 
                              [10, 10, 10]])

#- 目標値: 要素数11, 1次元ベクトル
t = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# datasetの作成
dataset = torch.utils.data.TensorDataset(x, t)

#バッチサイズ定義
batch_size  = 5

DataLoader検証

Shuffle Falseの場合

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果
0-4, 5-9と上から順番にサンプルが抽出されている。

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]

Shuffle Trueの場合

dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = True, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果
ランダムにサンプルが抽出されている。

[tensor([[3, 3, 3],
        [5, 5, 5],
        [4, 4, 4],
        [2, 2, 2],
        [7, 7, 7]]), tensor([3, 5, 4, 2, 7])]
[tensor([[6, 6, 6],
        [8, 8, 8],
        [0, 0, 0],
        [9, 9, 9],
        [1, 1, 1]]), tensor([6, 8, 0, 9, 1])]

学習の際の挙動、Epoch毎にミニバッチの内訳は異なる?

->異なる

for i in range(5):
#表示用のprint文
    print('##############')
    print('###  epoch{}  #'.format(i))
    print('##############\n')

# Trainer.fit内部の挙動
    dataloader_tmp = dataloader
    for tmp in iter(dataloader):
        print(tmp)

実行結果
Epoch毎にサンプルがランダムに抽出されている。
dataloaderがTrainer.fit内部で何度もよびだされているため。

##############
###  epoch0  #
##############

[tensor([[ 8,  8,  8],
        [ 3,  3,  3],
        [ 2,  2,  2],
        [10, 10, 10],
        [ 6,  6,  6]]), tensor([ 8,  3,  2, 10,  6])]
[tensor([[7, 7, 7],
        [9, 9, 9],
        [0, 0, 0],
        [1, 1, 1],
        [5, 5, 5]]), tensor([7, 9, 0, 1, 5])]


##############
###  epoch1  #
##############

[tensor([[1, 1, 1],
        [4, 4, 4],
        [7, 7, 7],
        [6, 6, 6],
        [0, 0, 0]]), tensor([1, 4, 7, 6, 0])]
[tensor([[10, 10, 10],
        [ 3,  3,  3],
        [ 9,  9,  9],
        [ 8,  8,  8],
        [ 5,  5,  5]]), tensor([10,  3,  9,  8,  5])]



##############
###  epoch2  #
##############

[tensor([[7, 7, 7],
        [4, 4, 4],
        [3, 3, 3],
        [9, 9, 9],
        [2, 2, 2]]), tensor([7, 4, 3, 9, 2])]
[tensor([[ 1,  1,  1],
        [ 8,  8,  8],
        [10, 10, 10],
        [ 5,  5,  5],
        [ 6,  6,  6]]), tensor([ 1,  8, 10,  5,  6])]


##############
###  epoch3  ##
##############

[tensor([[5, 5, 5],
        [2, 2, 2],
        [1, 1, 1],
        [7, 7, 7],
        [9, 9, 9]]), tensor([5, 2, 1, 7, 9])]
[tensor([[ 0,  0,  0],
        [ 4,  4,  4],
        [10, 10, 10],
        [ 8,  8,  8],
        [ 6,  6,  6]]), tensor([ 0,  4, 10,  8,  6])]

##############
###  epoch4  #
##############

[tensor([[ 3,  3,  3],
        [10, 10, 10],
        [ 0,  0,  0],
        [ 2,  2,  2],
        [ 7,  7,  7]]), tensor([ 3, 10,  0,  2,  7])]
[tensor([[8, 8, 8],
        [6, 6, 6],
        [4, 4, 4],
        [1, 1, 1],
        [9, 9, 9]]), tensor([8, 6, 4, 1, 9])]


おまけ(drop_last)

drop_last = False

# drop last False
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = False)
for tmp in iter(dataloader):
    print(tmp)

実行結果

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]
[tensor([[10, 10, 10]]), tensor([10])]

drop_last = True

# drop last True
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle = False, drop_last = True)
for tmp in iter(dataloader):
    print(tmp)

実行結果

[tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([0, 1, 2, 3, 4])]
[tensor([[5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]]), tensor([5, 6, 7, 8, 9])]

参照