【scikit-learnメモ】 同じ入力で別の値を予測する複数モデルを同時に作る
7766 ワード
はじめに
入力値は同一で異なる値を予測するモデルを同時並行的に作りたい要件がありました。
scikit-learnのライブラリでこの要件を満たすモデルを作れることがわかったので、その手順のメモです。
問題設定
無理矢理感があるのですが、boston datasetを使い、target(=PRICE)の値だけでなく、CRIM(犯罪発生率)の値も予測したいと仮にします。
予測モデルはSVRを使います。
案の定、CRIMの予測値はメタメタになってしまいましたが、その点は大目に見て下さい。
サンプルコード
以下に、コメント付きのサンプルコードを記載します。
動作確認はWatson Studio上のJupyter Notebookで行いました。
# データ準備
# 必要ライブラリのロード
import numpy as np
from sklearn.datasets import load_boston
import pandas as pd
from sklearn import preprocessing
# boston datasetの読み込み
boston = load_boston()
# 入力データをデータフレームに
df = pd.DataFrame(boston.data, columns=boston.feature_names)
# 正解データをpriceに
price = boston.target
# 入力データのうち、CRIMも正解データにするため、別変数に読み込む
crim = df['CRIM'].values
# 入力データから"CRIM"の列を落とす
df2 = df.drop('CRIM', axis=1)
# 入力データの正規化 (=X)
sc=preprocessing.StandardScaler()
sc.fit(df2)
X=sc.transform(df2)
# priceとcrimをつないで正解データとする (=y)
y = np.array([price, crim]).T
# モデル学習用データのサイズの確認
print(X.shape)
print(y.shape)
# 学習
# 必要ライブラリのロード
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR
# 多値予測モデルの生成
mor = MultiOutputRegressor(SVR(kernel='rbf', C=1e3, gamma=0.1))
# 学習
mor.fit(X, y)
# 予測の実施
p = mor.predict(X)
# 予測結果のサイズの確認
print(p.shape)
(おまけ) モデルをWatson Machine Learning上に登録
今回は、このモデルをWatson Studio上で動かしたい要件があったので、モデルの登録も行いました。手順は基本的にWatson Studioでscikit-learn機械学習モデルをWebサービス化すると同じなのですが、せっかくなので一緒に記載しておきます。
# 認証情報の設定
wml_credentials={
"apikey": "xxxx",
"instance_id": "xxxx",
"password": "xxxx",
"url": "https://us-south.ml.cloud.ibm.com",
"username": "xxxx"
}
# Watson ML clientオブジェクトの生成
from watson_machine_learning_client import WatsonMachineLearningAPIClient
client = WatsonMachineLearningAPIClient(wml_credentials)
# Watson ML にモデルの登録
published_model = client.repository.store_model(model=mor, meta_props={'name':'Multi Classifier Sample'}, \
training_data=X, training_target=y)
Author And Source
この問題について(【scikit-learnメモ】 同じ入力で別の値を予測する複数モデルを同時に作る), 我々は、より多くの情報をここで見つけました https://qiita.com/makaishi2/items/c9863123052dd79ed96b著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .