Weighted Least Squares in scikit-learn


Reference

Preparation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
plt.rcParams['font.size']=15

def plt_legend_out(frameon=True):
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0, frameon=frameon)
from sklearn.linear_model import LinearRegression

OLS

ここでは、以下のデータを想定します。x=10の外れているデータ点を「外れ値」とみなすか、否かで対処が異なってきます。Ordinal Least Squaresモデルで無視するケース(WSL1)と組み込むケース(WSL2)を考えます。OLSですと、以下の通りx=10の点に引っ張られた回帰直線になります。

n = 11
x = np.arange(0,n,1)

np.random.seed(0)
y = x + np.random.randn(n)
y[len(y)-1] = y[len(y)-1]*2

plt.scatter(x,y,color='k')
plt.xlabel('x')
plt.ylabel('y')

reg = LinearRegression().fit(x.reshape(-1,1),y)
y_pr1 = reg.predict(x.reshape(-1,1))

plt.plot(x,y_pr1,label='pred')
plt_legend_out()
plt.show()

WLS1

x=10の点を「外れ値」とみなして、処理を進めます。まずは、残差を取ります。

diff = y - y_pr1
df_plot = pd.DataFrame({'x':x,'error':diff})
sns.barplot(data=df_plot,x='x',y='error',color='k')
plt.ylabel('$y_i-\hat{y_i}$')
plt.show()

(略)

sw = (np.max(np.abs(diff)) - np.abs(diff))**2
df_plot = pd.DataFrame({'x':x,'error':sw})
sns.barplot(data=df_plot,x='x',y='error',color='k')
plt.ylabel('$(argmax|y_i-\hat{y_i}|-|y_i-\hat{y_i}|)^2$')
plt.show()

上記のsample weightをOrdinal Least Squaresに組み込みます。x=10の点を無視した回帰直線となりました。

reg.fit(x.reshape(-1,1),y, sample_weight=sw)
y_pr2 = reg.predict(x.reshape(-1,1))
plt.plot(x,y_pr1,label='Ordinal Least Squares')
plt.plot(x,y_pr2,label='Weighted Least Squares')
plt_legend_out()
plt.scatter(x,y,color='k')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.4)

min = np.min(np.concatenate([y,y_pr1,y_pr2]))-1
max = np.max(np.concatenate([y,y_pr1,y_pr2]))+1

plt.subplot(1,2,1)
plt.scatter(y,y_pr1,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Ordinal Least Squares')

plt.subplot(1,2,2)
plt.scatter(y,y_pr2,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Weighted Least Squares')
plt.show()

WLS2

x=10の点を過剰にフィットさせたいケースを考えます。

df_plot = pd.DataFrame({'x':x,'y':y})
sns.barplot(data=df_plot,x='x',y='y',color='k')
plt.show()

reg.fit(x.reshape(-1,1),y, sample_weight=1/y)
y_pr2 = reg.predict(x.reshape(-1,1))
plt.plot(x,y_pr1,label='Ordinal Least Squares')
plt.plot(x,y_pr2,label='Weighted Least Squares')
plt_legend_out()
plt.scatter(x,y,color='k')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.4)

min = np.min(np.concatenate([y,y_pr1,y_pr2]))-1
max = np.max(np.concatenate([y,y_pr1,y_pr2]))+1

plt.subplot(1,2,1)
plt.scatter(y,y_pr1,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Ordinal Least Squares')

plt.subplot(1,2,2)
plt.scatter(y,y_pr2,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Weighted Least Squares')
plt.show()

さらにWeightを強く掛けます。

diff = y - y_pr1
df_plot = pd.DataFrame({'x':x,'error':np.abs(diff)*y})
sns.barplot(data=df_plot,x='x',y='error',color='k')
plt.ylabel('$y_i\cdot|y_i-\hat{y_i}|$')
plt.show()

reg.fit(x.reshape(-1,1),y, sample_weight=np.abs(diff)*y)
y_pr2 = reg.predict(x.reshape(-1,1))
plt.plot(x,y_pr1,label='Ordinal Least Squares')
plt.plot(x,y_pr2,label='Weighted Least Squares')
plt_legend_out()
plt.scatter(x,y,color='k')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.4)

min = np.min(np.concatenate([y,y_pr1,y_pr2]))-1
max = np.max(np.concatenate([y,y_pr1,y_pr2]))+1

plt.subplot(1,2,1)
plt.scatter(y,y_pr1,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Ordinal Least Squares')

plt.subplot(1,2,2)
plt.scatter(y,y_pr2,color='k')
plt.plot([min,max],[min,max],color='gray',lw=0.5)
plt.xlabel('exp')
plt.ylabel('pred')
plt.grid()
plt.xlim(min,max)
plt.ylim(min,max)
plt.title('Weighted Least Squares')
plt.show()

その他

>>> reg.predict(np.array(30).reshape(-1,1))
array([71.52016456])