torch.Tensor 後、日付情報の復元結果が微妙にずれる


サマリ

  • 日付情報の整数値をDL(torch.Tensor)に通したあと、復元すると微妙にずれる現象の概説
  • 微妙にズレた日付を復元するサンプルを提供

環境

  • jupyter notebook
    • たぶん、Colab でも動くと思う
  • pytorch: 1.7.0

現象

準備

import numpy
import pandas
import torch
index_date = pandas.date_range("2016-01-01", "2018-12-31")
index_date

DatetimeIndex(['2016-01-01', '2016-01-02', '2016-01-03', '2016-01-04',
               '2016-01-05', '2016-01-06', '2016-01-07', '2016-01-08',
               '2016-01-09', '2016-01-10',
               ...
               '2018-12-22', '2018-12-23', '2018-12-24', '2018-12-25',
               '2018-12-26', '2018-12-27', '2018-12-28', '2018-12-29',
               '2018-12-30', '2018-12-31'],
              dtype='datetime64[ns]', length=1096, freq='D')

df = pandas.DataFrame([])
df["time_index"] = index_date.astype(numpy.int64)
df.time_index[:5]

0    1451606400000000000
1    1451692800000000000
2    1451779200000000000
3    1451865600000000000
4    1451952000000000000
Name: time_index, dtype: int64

現象の確認

series/numpy のデータを to_datetime すると正しく復元できる

pandas.to_datetime(df.time_index)[:5]

0   2016-01-01
1   2016-01-02
2   2016-01-03
3   2016-01-04
4   2016-01-05
Name: time_index, dtype: datetime64[ns]

一方で、torch.Tensor で、テンソルに変換してから復元しようと、to_datetime をすると微妙にずれる(型変換時の精度誤差っぽい)

pandas.to_datetime(torch.Tensor(df.time_index))[:5]

DatetimeIndex(['2016-01-01 00:00:49.632313344',
               '2016-01-01 23:59:21.295093760',
               '2016-01-03 00:00:10.396827648',
               '2016-01-04 00:00:59.498561536',
               '2016-01-04 23:59:31.161341952'],
              dtype='datetime64[ns]', freq=None)

対策例

対策として、以下のような関数を作成する

import datetime


def to_date(ti: numpy.array):
    _ti = pandas.to_datetime(ti)
    ti_ser = pandas.Series(_ti, name="time_index")

    def _adjust_date(ts):
        dte = ts.to_pydatetime()
        dte += datetime.timedelta(hours=1)
        return datetime.datetime(year=dte.year, month=dte.month, day=dte.day, hour=0, minute=0, second=0)

    return ti_ser.apply(_adjust_date)

series に対しては、to_datetime と同じ結果になる

ti = to_date(df.time_index)
ti.head()

0   2016-01-01
1   2016-01-02
2   2016-01-03
3   2016-01-04
4   2016-01-05
Name: time_index, dtype: datetime64[ns]
(ti == pandas.to_datetime(df.time_index)).all()

True

torch.Tensor を通した後も、同じ結果になる

ti_restored = to_date(torch.Tensor(df.time_index))
ti_restored.head()

0   2016-01-01
1   2016-01-02
2   2016-01-03
3   2016-01-04
4   2016-01-05
Name: time_index, dtype: datetime64[ns]
(ti == ti_restored).all()

True

まとめ

  • 微妙にめんどくさいので、参考になれば幸いです