交差検証 CombinatorialPurgedCV Python [備忘録]


背景

Kaggle [Jane Street Market Prediction] にてfinanceでよく利用される交差検証を試した.

ノイズの影響が大きく、LBも信用できないので交差検証が非常に大切だった.

概要

  • 赤:train
  • 青:validation

図は全体の1/4を検証データとし、trainとvalidationの間に10dayのギャップを設定している.

実装

簡単のため、1回の分割で検証に使うブロック数 N_COMB は 2 で固定している.

import numpy as np
import pandas as pd
import itertools
class CombinatorialPurgedKFold:
    """Iterate over combinatorial purged cross-validation splits of a DataFrame.

    The sorted unique values of the date column are cut into
    ``N_COMB * n_block`` contiguous blocks.  Every combination of ``N_COMB``
    blocks is used once as validation data, and every date outside those
    blocks forms the training data.  ``n_gap_day`` dates are trimmed from each
    edge of a validation block that borders training data, so the two sets
    never contain neighbouring dates.

    Iterating over an instance yields ``(train_df, validation_df)`` tuples,
    one per block combination; iteration can be restarted with ``iter()``.

    Attributes:
        n_block: Value passed to the constructor.
        n_gap_day: Number of dates trimmed from a validation block edge.
        uni_date: Sorted unique values of the date column.
        date_blocks: List of contiguous slices of ``uni_date``.
        valid_comb: List of block-index tuples, one per split.
    """

    # Number of blocks held out per split; ``splits`` unpacks exactly two.
    N_COMB = 2

    def __init__(self, data: pd.DataFrame, date_col_nm: str, n_block: int, n_gap_day: int = 5):
        """Prepare the date blocks and the list of validation combinations.

        Args:
            data: Frame to split; must contain the column ``date_col_nm``.
            date_col_nm: Name of the column whose values define the ordering.
            n_block: The dates are cut into ``N_COMB * n_block`` blocks.
            n_gap_day: Dates trimmed from validation edges that touch
                training data.

        Raises:
            ValueError: If ``n_block`` is smaller than 1, ``n_gap_day`` is
                negative, or the smallest block holds fewer than
                ``2 * n_gap_day`` dates.
        """
        if n_block < 1:
            raise ValueError(f"n_block must be >= 1, got {n_block}")
        if n_gap_day < 0:
            raise ValueError(f"n_gap_day must be >= 0, got {n_gap_day}")

        self.n_block = n_block
        self.n_gap_day = n_gap_day

        n_total_blocks = self.N_COMB * n_block
        self.uni_date = np.unique(data[date_col_nm].values)
        self.date_blocks = np.array_split(self.uni_date, n_total_blocks)
        self.valid_comb = list(itertools.combinations(range(n_total_blocks), self.N_COMB))

        # A block may lose a gap on both sides, so it needs at least two gaps
        # worth of dates.
        smallest_block = min(len(block) for block in self.date_blocks)
        if smallest_block < self.n_gap_day * 2:
            raise ValueError(
                f"smallest date block has {smallest_block} dates, "
                f"but at least {self.n_gap_day * 2} (2 * n_gap_day) are required"
            )

        self.__i = -1
        self.__data = data
        self.__date_col_nm = date_col_nm

    def __iter__(self):
        """Reset the position and return the instance itself as the iterator."""
        self.__i = -1
        return self

    def __next__(self):
        """Advance to the next combination and return its split."""
        self.__i += 1
        if not 0 <= self.__i < len(self.valid_comb):
            raise StopIteration()
        return self.splits()

    def splits(self):
        """Return ``(train_df, validation_df)`` for the current position.

        The position is the one maintained by ``__next__``; before the first
        ``next()`` call it is -1, which indexes the last combination.
        """
        val_idx1, val_idx2 = self.valid_comb[self.__i]
        block1 = self.date_blocks[val_idx1]
        block2 = self.date_blocks[val_idx2]
        gap = self.n_gap_day

        # Adjacent validation blocks share an inner border with no training
        # data in between, so nothing is trimmed there.
        adjacent = val_idx1 + 1 == val_idx2
        is_first = val_idx1 == 0
        is_last = val_idx2 == len(self.date_blocks) - 1

        # Slice ends are written as ``len - gap`` rather than ``-gap``:
        # ``-0`` is ``0`` and would empty the slice when the gap is zero.
        s1 = 0 if is_first else gap
        e1 = len(block1) if adjacent else len(block1) - gap
        s2 = 0 if adjacent else gap
        e2 = len(block2) if is_last else len(block2) - gap

        val_dates = np.concatenate([block1[s1:e1], block2[s2:e2]])

        # Training data excludes the whole validation blocks, trimmed gap
        # dates included.
        held_out = np.concatenate([block1, block2])
        tra_dates = self.uni_date[~np.isin(self.uni_date, held_out)]

        # Boolean masks instead of a string-built query: independent of how
        # the date values print and of the column name's spelling.
        date_values = self.__data[self.__date_col_nm]
        tra_df = self.__data[date_values.isin(tra_dates)]
        val_df = self.__data[date_values.isin(val_dates)]
        return tra_df, val_df

使用例

# NOTE(review): `train` is assumed to be a DataFrame with a 'date' column,
# defined before this snippet runs.
cpkf = CombinatorialPurgedKFold(
    data=train,
    date_col_nm='date',
    n_block=4,
    n_gap_day=10,
)

# Each iteration yields one (train, validation) pair of DataFrames.
for tr_df, va_df in cpkf:
    print(tr_df, va_df)

参考