予測値の95%信頼区間をどうやって求めるか
はじめに
最近、EI戦略を用いるために標準偏差(σ)を求めることが多いため、どうやったら機械学習の予測値に対して標準偏差を求めることができるかを備忘録としてまとめます。
※使ったモデルにの理論的な解説は載せていません。あくまで、どうやって求めることが出来たのかの方法を記述していきます。
まだまだ、勉強が追い付いていない部分が多いため、間違っているところがあればご教示いただけるとありがたいです。
予測する関数
f(x) = x * np.sin(x) + noise
noise = np.random.randn(len(x))
学習に使う値
x = np.linspace(0,10,21)
y = f(x)
テストで使う値
x_test = np.linspace(0, 20, 10)
GPR(Gaussian Process Regression)を使う方法
GPRにおいては、値の予測を行うときに引数として、return_std=Trueを指定すれば標準偏差を得ることが出来ます。
kernel = ConstantKernel() * RBF() + WhiteKernel()
model = GaussianProcessRegressor(kernel=kernel, alpha=0)
model.fit(x, y2.ravel())
x_test = np.linspace(0, 20, 10).reshape(-1, 1)
y_pred, sigma = model.predict(x_test, return_std=True)
すると、下のような図を得ることが出来ます。(当然なのですが、)学習の範囲外の予測値に関しては信頼区間が非常に広くなっていることが確認できます。
RFR(RandomForestRegressor), ETR(ExtraTreesRegressor)を使う方法
まず、以下のように標準偏差を返すことが出来る関数を定義します。
def _return_std(X, trees, predictions, min_variance):
# This derives std(y | x) as described in 4.3.2 of arXiv:1211.0906
std = np.zeros(len(X))
for tree in trees:
var_tree = tree.tree_.impurity[tree.apply(X)]
# This rounding off is done in accordance with the
# adjustment done in section 4.3.3
# of http://arxiv.org/pdf/1211.0906v2.pdf to account
# for cases such as leaves with 1 sample in which there
# is zero variance.
var_tree[var_tree < min_variance] = min_variance
mean_tree = tree.predict(X)
std += var_tree + mean_tree ** 2
std /= len(trees)
std -= predictions ** 2.0
std[std < 0.0] = 0.0
std = std ** 0.5
return std
これを使うと、標準偏差を得ることが出来るため以下のような図を得ることが出来ます。
RFR
ETR
GBRの時と比べて、未学習領域の広さがだいぶ狭くなっていることが確認できます。
(過学習しにくい事と関係があるのかな...?)
GBR(GradientBoostingRegressor)のQuantileRegressionを用いる方法
GBRでは
loss = 'quantile'
を指定すると、任意の分位数を得ることが出来ます。
今回は、以下のようにモデルを宣言することで95%信頼区間を得ます。
model_low = GradientBoostingRegressor(loss='quantile', alpha=0.025, **params)
model_med = GradientBoostingRegressor(loss='quantile', alpha=0.5, **params)
model_high = GradientBoostingRegressor(loss='quantile', alpha=0.975, **params)
それぞれのモデルで学習を行い、予測値を得ると以下のような図を得ることが出来ます。
GBR
まとめ
この記事では機械学習(深層学習を除く)での、標準偏差の求め方をまとめました。
皆さんの、機械学習ライフの一助となれば幸いです。
参考文献
scikitlearn 1.7. Gaussian Processes
scikitlearn ExtraTreesRegressor
scikitlearn RandomForestRegressor
scikitlearn GradientBoostingRegressor
Author And Source
この問題について(予測値の95%信頼区間をどうやって求めるか), 我々は、より多くの情報をここで見つけました https://qiita.com/Shichi3/items/c8023edb113910631f0d著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .