sklearn.model_selection.GroupKFold
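Group K-fold cross-validation: sklearn.model_selection.GroupKFold(n_splits=3)
Parameters:
n_splits: the number of folds; must be at least 2 and no greater than the number of distinct groups. The default is 3 in the version shown here (scikit-learn 0.22 and later default to 5).
Note: samples from the same group never appear in both the training set and the test set of the same fold.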
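The note above can be checked directly. The following is a minimal sketch, using the toy data from this article's examples; the helper name `folds_are_group_disjoint` is hypothetical and not part of scikit-learn.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Toy data from this article: 12 samples, 6 groups of 2 samples each.
X = np.arange(24).reshape(12, 2)
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])


def folds_are_group_disjoint(splitter, X, y, groups):
    """Return True if no group occurs in both the train and test part of any fold.

    Hypothetical helper written for this illustration only.
    """
    for train_index, test_index in splitter.split(X, y, groups):
        shared = set(groups[train_index]) & set(groups[test_index])
        if shared:
            return False
    return True


disjoint = folds_are_group_disjoint(GroupKFold(n_splits=3), X, y, groups)
print(disjoint)
```

A plain KFold would not necessarily pass the same check on this data, because it ignores the groups array when splitting.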
① Balanced data: every group has the same number of samples
In [11]: from sklearn.model_selection import GroupKFold
    ...: import numpy as np
    ...: X = np.arange(24).reshape(12, 2)  # 12 samples, 2 features
    ...: y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
    ...: groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])  # 6 groups, 2 samples each
    ...: kf = GroupKFold(n_splits=6)
    ...: for train_index, test_index in kf.split(X, y, groups):
    ...:     print('train_index:%s , test_index: %s ' % (train_index, test_index))
    ...:     print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))
    ...:
train_index:[ 0 1 2 3 4 6 7 8 9 10] , test_index: [ 5 11]
train_groups:[1 2 3 4 5 1 2 3 4 5] , test_groups: [6 6]
train_index:[ 0 1 2 3 5 6 7 8 9 11] , test_index: [ 4 10]
train_groups:[1 2 3 4 6 1 2 3 4 6] , test_groups: [5 5]
train_index:[ 0 1 2 4 5 6 7 8 10 11] , test_index: [3 9]
train_groups:[1 2 3 5 6 1 2 3 5 6] , test_groups: [4 4]
train_index:[ 0 1 3 4 5 6 7 9 10 11] , test_index: [2 8]
train_groups:[1 2 4 5 6 1 2 4 5 6] , test_groups: [3 3]
train_index:[ 0 2 3 4 5 6 8 9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 6] , test_groups: [2 2]
train_index:[ 1 2 3 4 5 7 8 9 10 11] , test_index: [0 6]
train_groups:[2 3 4 5 6 2 3 4 5 6] , test_groups: [1 1]
② The sample count divides evenly by n_splits (12 samples, 4 folds), but group sizes are unequal
In [13]: from sklearn.model_selection import GroupKFold
    ...: import numpy as np
    ...: X = np.arange(24).reshape(12, 2)  # 12 samples, 2 features
    ...: y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
    ...: groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 7])  # groups 6 and 7 have 1 sample each
    ...: kf = GroupKFold(n_splits=4)
    ...: for train_index, test_index in kf.split(X, y, groups):
    ...:     print('train_index:%s , test_index: %s ' % (train_index, test_index))
    ...:     print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))
    ...:
train_index:[ 1 2 3 5 7 8 9 11] , test_index: [ 0 4 6 10]
train_groups:[2 3 4 6 2 3 4 7] , test_groups: [1 5 1 5]
train_index:[ 0 1 2 4 5 6 7 8 10] , test_index: [ 3 9 11]
train_groups:[1 2 3 5 6 1 2 3 5] , test_groups: [4 4 7]
train_index:[ 0 1 3 4 6 7 9 10 11] , test_index: [2 5 8]
train_groups:[1 2 4 5 1 2 4 5 7] , test_groups: [3 6 3]
train_index:[ 0 2 3 4 5 6 8 9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 7] , test_groups: [2 2]
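Although 12 samples would divide into 4 folds of 3, the output above shows test folds of 4, 3, 3, and 2 samples, because a group is never split across folds. A small sketch that tallies the test-fold sizes and the groups each fold holds, using the same groups array (the variable names `test_sizes` and `test_group_sets` are illustrative):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(24).reshape(12, 2)
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 7])

kf = GroupKFold(n_splits=4)
test_sizes = []       # number of samples in each test fold
test_group_sets = []  # distinct groups held out in each test fold
for _, test_index in kf.split(X, groups=groups):
    test_sizes.append(len(test_index))
    test_group_sets.append(set(groups[test_index].tolist()))

print(test_sizes)
print(test_group_sets)
```

Each of the 7 groups lands in exactly one test fold, so the fold sizes are determined by how whole groups can be packed together, not by the sample count alone.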
③ The sample count does not divide evenly by n_splits (12 samples, 5 folds), and group sizes are unequal
In [14]: from sklearn.model_selection import GroupKFold
    ...: import numpy as np
    ...: X = np.arange(24).reshape(12, 2)  # 12 samples, 2 features
    ...: y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
    ...: groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 3])  # group 3 has 3 samples, group 6 has 1
    ...: kf = GroupKFold(n_splits=5)
    ...: for train_index, test_index in kf.split(X, y, groups):
    ...:     print('train_index:%s , test_index: %s ' % (train_index, test_index))
    ...:     print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))
    ...:
train_index:[ 0 1 3 4 5 6 7 9 10] , test_index: [ 2 8 11]
train_groups:[1 2 4 5 6 1 2 4 5] , test_groups: [3 3 3]
train_index:[ 0 1 2 3 6 7 8 9 11] , test_index: [ 4 5 10]
train_groups:[1 2 3 4 1 2 3 4 3] , test_groups: [5 6 5]
train_index:[ 0 1 2 4 5 6 7 8 10 11] , test_index: [3 9]
train_groups:[1 2 3 5 6 1 2 3 5 3] , test_groups: [4 4]
train_index:[ 0 2 3 4 5 6 8 9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 3] , test_groups: [2 2]
train_index:[ 1 2 3 4 5 7 8 9 10 11] , test_index: [0 6]
train_groups:[2 3 4 5 6 2 3 4 5 3] , test_groups: [1 1]
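The fold sizes in the three outputs above are consistent with a greedy assignment: take the groups from largest to smallest and put each one into the fold that currently holds the fewest samples. The sketch below is an illustration of that idea under this assumption, not a copy of the library's code; `greedy_group_folds` is a hypothetical helper, and it is compared against GroupKFold only on fold sizes.

```python
import numpy as np
from sklearn.model_selection import GroupKFold


def greedy_group_folds(groups, n_splits):
    """Assign whole groups to folds, largest group first, always to the lightest fold.

    Hypothetical helper for illustration; returns (group -> fold index, fold sizes).
    """
    unique_groups, counts = np.unique(groups, return_counts=True)
    order = np.argsort(counts)[::-1]  # positions of groups, largest first
    fold_sizes = np.zeros(n_splits, dtype=int)
    group_to_fold = {}
    for i in order:
        lightest = int(np.argmin(fold_sizes))  # fold with the fewest samples so far
        fold_sizes[lightest] += counts[i]
        group_to_fold[int(unique_groups[i])] = lightest
    return group_to_fold, fold_sizes


groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 3])
group_to_fold, fold_sizes = greedy_group_folds(groups, n_splits=5)

# Test-fold sizes reported by scikit-learn for the same groups.
X = np.arange(24).reshape(12, 2)
sklearn_sizes = [len(test) for _, test in GroupKFold(n_splits=5).split(X, groups=groups)]

print(sorted(fold_sizes.tolist()), sorted(sklearn_sizes))
```

In example ③ this places the 3-sample group alone in one fold and pairs the 1-sample group with one of the 2-sample groups, which matches the test_groups printed above.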