Pytorch_geometric の前処理にて、複数の自作transformクラスを渡す方法


概要

グラフ関連のDeep Learningの研究をしているためPytorch_geometricにはお世話になっています。が、最近geometricのデータセットにアクセスするたびに複数の種類の前処理(transform)を行いたいと思って調べましたが意外と時間がかかったためメモ

実装方法

論文の引用関係データセットの'Cora'をインストールし、複数の前処理(transform)を行う場合、

from torch_geometric.datasets import Planetoid
import torchvision.transforms as transforms

transforms = transforms.Compose([TransA(), TransB(), ...])
dataset = Planetoid(root='./data', name='Cora', 
                    transform=transforms)

上記のように渡したい複数の自作transform(TransA, TransB, ...)をリストとしてComposeに渡してあげます。そのオブジェクトをgeometric.datasetsの引数transformに渡せば、その順に前処理してくれます。