[ハンマー]クロス検証とグリッド検索
検証セット
モデルを評価してスーパーパラメータ調整を行うと、トレーニングセットから取り出したデータセットはテストセットを使用しなくなります.
クロス検証
トレーニングセットを複数のフォルダに分割し、1つのフォルダが検証セットの役割を果たし、残りのフォルダでモデルトレーニングを行います.これは、すべてのロッドについて検証スコアを取得し、平均する方法である.
グリッドサーチ(Grid Search)
スーパーパラメータ探索の自動化ツール.参照するパラメータをリストした後、クロス検証を行い、最適な検証ポイントのパラメータの組合せを選択します.最後に,これらのパラメータの組合せを用いて最終モデルを訓練した.
ランダムサーチ
連続パラメータの値を参照するときに便利です.検索する値を直接リストするのではなく、サンプリング可能な確率分布オブジェクトを渡します.指定したサンプリング回数に基づいてクロス検証を行うため、検索量を調整できます.
cross_validate()
クロス検証を実行する関数.
GridSearchCV
クロス検証によるスーパーパラメータ検索.ベストモデルを見つけたら、トレーニングセット全体を使って最終モデルを訓練します.
RandomizedSearchCV
クロス検証によりランダムスーパーパラメータ検索を行います.ベストモデルを見つけたら、トレーニングセット全体を使って最終モデルを訓練します.
検証セット
import pandas as pd
wine = pd.read_csv('https://bit.ly/wine-date')
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(data, target, test_size=0.2, random_state=42)
sub_input,val_input,sub_target,val_target=train_test_split(train_input,train_target,test_size=0.2,random_state=42)
import pandas as pd
wine = pd.read_csv('https://bit.ly/wine-date')
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(data, target, test_size=0.2, random_state=42)
sub_input,val_input,sub_target,val_target=train_test_split(train_input,train_target,test_size=0.2,random_state=42)
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state=42)
dt.fit(sub_input, sub_target)
print(dt.score(sub_input, sub_target))
print(dt.score(val_input, val_target))
クロス検証 from sklearn.model_selection import cross_validate
scores=cross_validate(dt,train_input,train_target)
print(scores)
from sklearn.model_selection import cross_validate
scores=cross_validate(dt,train_input,train_target)
print(scores)
import numpy as np
print(np.mean(scores['test_score']))
from sklearn.model_selection import StratifiedKFold
scores=cross_validate(dt,train_input,train_target,cv=StratifiedKFold())
print(np.mean(scores['test_score']))
splitter=StratifiedKFold(n_splits=10,shuffle=True,random_state=42)
scores=cross_validate(dt,train_input,train_target,cv=splitter)
print(np.mean(scores['test_score']))
グリッドサーチ(1)
from sklearn.model_selection import GridSearchCV
params={'min_impurity_decrease':[0.0001,0.0002,0.0003,0.0004,0.0005]}
from sklearn.model_selection import GridSearchCV
params={'min_impurity_decrease':[0.0001,0.0002,0.0003,0.0004,0.0005]}
gs=GridSearchCV(DecisionTreeClassifier(random_state=42),params,n_jobs=-1)
gs.fit(train_input,train_target)
dt=gs.best_estimator_
print(dt.score(train_input,train_target))
print(gs.best_params_)
print(gs.cv_results_['mean_test_score'])
best_index=np.argmax(gs.cv_results_['mean_test_score'])
print(gs.cv_results_['params'][best_index])
グリッド検索(2)
params = {'min_impurity_decrease': np.arange(0.0001, 0.001, 0.0001),
'max_depth': range(5, 20, 1),
'min_samples_split': range(2, 100, 10)
}
params = {'min_impurity_decrease': np.arange(0.0001, 0.001, 0.0001),
'max_depth': range(5, 20, 1),
'min_samples_split': range(2, 100, 10)
}
gs=GridSearchCV(DecisionTreeClassifier(random_state=42),params,n_jobs=-1)
gs.fit(train_input,train_target)
print(gs.best_params_)
print(np.max(gs.cv_results_['mean_test_score']))
ランダムサーチ from scipy.stats import uniform,randint
params = {'min_impurity_decrease': uniform(0.0001, 0.001),
'max_depth': randint(20, 50),
'min_samples_split': randint(2, 25),
'min_samples_leaf': randint(1, 25),
}
from scipy.stats import uniform,randint
params = {'min_impurity_decrease': uniform(0.0001, 0.001),
'max_depth': randint(20, 50),
'min_samples_split': randint(2, 25),
'min_samples_leaf': randint(1, 25),
}
from sklearn.model_selection import RandomizedSearchCV
gs = RandomizedSearchCV(DecisionTreeClassifier(random_state=42), params, n_iter=100, n_jobs=-1, random_state=42)
gs.fit(train_input, train_target)
print(gs.best_params_)
print(np.max(gs.cv_results_['mean_test_score']))
dt=gs.best_estimator_
print(dt.score(test_input,test_target))
참고문헌: 혼공머신
Reference
この問題について([ハンマー]クロス検証とグリッド検索), 我々は、より多くの情報をここで見つけました https://velog.io/@oooops/교차-검증과-그리드-서치テキストは自由に共有またはコピーできます。ただし、このドキュメントのURLは参考URLとして残しておいてください。
Collection and Share based on the CC Protocol