PytorchのDatasetを層状にfoldingする


既存のDatasetをN-foldするときに、先頭から順に1,2,...N,1,2,...,N,1,2,...,N,1,2,...とデータセットを分割する。時系列データを分割するときに、特定の月や季節にデータが偏らないようにするときに使用する。

from torch.utils.data import Dataset


class LayeredFoldWrapper(Dataset):
    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        self.valid = valid
        self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
        self.train_index = list(set(range(len(dataset))) - set(self.valid_index))

    def __len__(self):
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        return self.dataset.__getitem__(self._get_index_list(self.valid)[i])

    def _valid_index(self, N, n_splits, fold):
        """
        N: 全データの数
        n_splits: foldのスプリットの数
        fold: 各foldを指定する値 0<=fold<=n_splits-1
        """
        assert(0<=fold<=n_splits-1)
        return range(n_splits - fold - 1, N+1, n_splits)

    def _get_index_list(self, valid):
        if valid:
            return self.valid_index
        else:
            return self.train_index