Covid-19をDartsを使って予測してみた


時系列データの予測はFacebookが開発してくれたProphetがわかりやすくて長く愛用してきましたが、最近Prophetをsklearn-likeに使うことができて、他の時系列解析手法もラッピングしてくれているDartsというライブラリがあるのを知ったので、Covid-19のデータを用いてその使い勝手を確かめてみたので解説します。

Dartsとは


https://github.com/unit8co/darts
Dartsはスイスの企業が2020年6月に後悔したライブラリ。
ProphetやLSTMなどのDeeplearning、ARIMAなどの統計モデルも全てsklearnベースのAPIで扱えるライブラリでかなり便利そう。。。

インストール方法

インストールはpipでインストールできます。

pip install 'u8darts[all]'

[all]を付けずにpip install u8dartsでもインストールできますが、そうするとLSTMを実行する際のpytorchなどがインストールされないようでエラーになりました。とりあえず使用感を確かめたいくらいであればいらないかもしれません。

実行環境

OS: macOS ver11.1
CPU: core i5
メモリ: 16GB
python: 3.8.7
Darts: 0.5.0

データの準備

今回はCovid-19のデータを利用したと思います。
厚生労働省のサイトにダウンロードリンクが貼られてある日付別の陽性者数のデータをダウンロード。
https://www.mhlw.go.jp/content/pcr_positive_daily.csv
陽性者予測するにはPCR検査数なども使わないといけないようなーとは思いつつ、今回はDartsの検証なの単純に過去の陽性者数から未来の陽性者数を予測するモデルの構築をおこないます。

使ってみた

ライブラリのインストール

import warnings
warnings.simplefilter('ignore') #warningがたくさん出るので気になる人はやったほうがいい
import pandas as pd
import darts
from darts import TimeSeries #Dartsのデータ型変換モジュール
import matplotlib.pyplot as plt

データの読み込み

df = pd.read_csv('https://www.mhlw.go.jp/content/pcr_positive_daily.csv') #記事作成時点で1月14日までのデータがダウンロード可能

データの中身はこんな感じ。おしい!ちょうど1年分に1日足りず!

データ型の変換。

DartsはpandasのDataFrameからの変換をTimeSeriesモジュールで行ってくれる。今回は202012月01日以降を予測する形で行うようにする。このあたりがsklearnのAPIベースでかなり助かる。直感的にわかりやすい。

ts = TimeSeries.from_dataframe(df, time_col='日付', value_cols='PCR 検査陽性者数(単日)')
train, val = ts.split_after(pd.Timestamp('20201201'))

学習モデルの作成

かなり豊富に学習モデルを用意してくれている。
Deeplearning系の学習モデルはデータの変換にもうひと手間必要なのでそれ以外のモデルをfor文で実行。
あと何度も書いてますが、やっぱりsklearnベースは楽。fitとpredictという何度やったかわからない書き方でそのまま実行。

# modelのインポート
from darts.models import ExponentialSmoothing, NaiveSeasonal, NaiveDrift, Prophet, ARIMA
from darts.models import AutoARIMA, StandardRegressionModel, Theta, FFT

models = [ExponentialSmoothing(), 
          NaiveSeasonal(), 
          NaiveDrift(), 
          Prophet(daily_seasonality=True, yearly_seasonality=True), 
          Prophet(daily_seasonality=True, yearly_seasonality=True, weekly_seasonality=True),# 曜日によって検査数にばらつきがあるので週の周期性を見るバージョンもせっかくなので準備 
          ARIMA(), 
          AutoARIMA(), 
          StandardRegressionModel(), 
          Theta(), 
          FFT()]

for model in models:
    print(model.__str__())
    try: #実行したときにエラーになってしまうモデルもあったので回避用です
        model.fit(train) #sklearnのやり方
        prediction = model.predict(len(val))
        # 可視化による確認
        plt.figure(figsize=(12, 5))
        ts.split_after(pd.Timestamp('20201101')) [1].plot(label='actual', lw=1) #最初から表示すると大事な予測結果との乖離部分が見えにくかったので20201101からのプロット
        prediction.plot(label='forecast', lw=1)
        plt.legend()
        plt.xlabel('Day')
        plt.show()
    except Exception as e:
        print('error¥t :{}'.format(e))

実行結果

Exponetial smoothing

Naive seasonal model

Naive drift model

Prophet

Prophet(Weekly True)

ARIMA

Auto-ARIMA

Theta
errorでした!公式ドキュメントみないとだめですね。。。

FFT

結果はARIMAモデルとExponetial smoothingがうまく学習できてそうです。
ProphetのWeeklyはデフォルトでAutoになっているので結果変わらずでした。
実績のラインをみても4月以降は爆発的に上がってるんですね。これは緊急事態宣言出して収束させたのも納得です。

Deeplearningモデルを試してみた

LSTMを試してみます。
パラメーターは参考にさせていただいた記事に書いてあったパラメーターを使わせていただきました。記事はページの一番下に載せてあります。

データの加工

from darts.models import TCNModel, RNNModel
from darts.dataprocessing.transformers import Scaler
from darts.metrics import mape, r2_score
from darts.utils.missing_values import fill_missing_values

# データの準備。Scalerは0,1で正規化するsklearnラッパーらしいです。
scaler = Scaler()
train_tr = scaler.fit_transform(train)
val_tr = scaler.transform(val)
ts_tr = scaler.transform(ts)

LSTM

そこそこ時間かかります。

model = RNNModel(
    model='LSTM',
    output_length=1, # 出力(=予測)のタイムステップ数
    hidden_size=25, # RNNにおける隠れ状態の数
    n_rnn_layers=3, # RNNの隠れ層の数
    input_length=12, # Number of previous time stamps taken into account.(?分からなかった…)
    dropout=0.4,
    batch_size=16,
    n_epochs=400,
    optimizer_kwargs={'lr': 1e-3},
    log_tensorboard=True,
    random_state=42
)
model.fit(train_tr, val_training_series=val_tr, verbose=True)

実行結果確認

prediction = model.predict(len(val))
fig = plt.figure(figsize=(12, 5))
ts_tr_after10 = ts_tr.drop_before(pd.Timestamp('20201001'))
ts_tr_after10.plot(label='actual')
prediction.plot(label='forecast', color='red')
plt.legend()


精度は微妙ですね・・・

まとめ

Dartsはかなり便利に感じました。今後はないとは思いますが、元のProphetを使った時の精度と比べて精度が劣化していないかなどを調べて問題なさそうならDartsを使っていこうと思います。sklearn-likeなのでhyperparameterの探索とかガンガンやってみたいと思います。
あと、当たり前ですがデフォルトパラメーターでデータの工夫も全くせずに実行しただけなので精度は出ていないです。これはライブラリのせいではなく私がサボったせいです。あと、バックテストの仕組みとかも充実しているのでその辺りも試せたらいいなと思っています。

ご挨拶

今回緊急事態宣言も出て、家にこもっていないと行けなくなったので初めて記事を書いてみました。せっかくやり始めたので今後も少しずつ書けて行けたらなと思います。コードの修正部分など気軽にコメントいただけると嬉しいです。参考にさせていただきます。

参考記事

https://blog.ikedaosushi.com/entry/2020/08/25/003557
https://qiita.com/hironey/items/d1d8a80c8329d5d46c16