PyTorchでクロスバリデーション


はじめに

Pytorch で Dataset を使用するときのクロスバリデーションのやり方を説明します。

Subsetを使用した分割

torch.utils.data.dataset.Subsetを使用するとインデックスを指定してDatasetを分割することが出来ます。これとscikit-learnのsklearn.model_selectionを組み合わせます。

train_test_split

sklearn.model_selection.train_test_splitを使用してインデックスをtrain_indexvalid_indexに分割し、Subsetを使用してDatasetを分割します。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split


dataset = get_dataset()

train_index, valid_index = train_test_split(range(len(dataset)), test_size=0.3)

batch_size = 16
train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset   = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

# ここに学習コード

KFoldクロスバリデーション

sklearn.model_selection.KFoldを使用してインデックスをtrain_indexvalid_indexに分割し、Subsetを使用してDatasetを分割します。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold


dataset = get_dataset()

batch_size = 16
kf = KFold(n_splits=3)

cv = 0
for _fold, (train_index, test_index) in enumerate(kf.split(X)):
    train_dataset = Subset(dataset, train_index)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataset   = Subset(dataset, valid_index)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

    for i in range(num_epochs):
        # ここに学習コード

    cv += valid_loss / kf.n_splits

クラス分類のDatasetであればdataset[:][1]とすればyの値を取得することができるはずなので、StratifiedKFoldもできるはずです。