pytorchはどのようにGCN関連のネットワークを構築します

8962 ワード

前言:私の研究分野は交通の面で、科学研究をするにはGCNに関するネットワークを構築する必要があります.例えばGCN-GAN【1】は、GCNの重みに基づいてネットワークを完了する【2】と、これらのネットワークに基づく新しいGCNネットワークフレームワーク.しかし、pytorchを使ってGCNネットワークを構築しているインターネット上の資料を検索すると、github上の無解釈コードとここ数年発表された論文だけで、詳細な説明がある資料は少なく、GCNの実戦に素早く入門するには大きな敷居があります.これに鑑みて、数日の実戦探索を経て、私は自分の関門をpytorchを用いてGCN型ネットワークを構築した経験を共有するここで、論文とコードがどのように対応するかを説明し、よく見られるGCNとpytorchについて説明する.geometricの資料をまとめて説明し、フレームワークコードを出して詳細な注釈を行い、GCNの実戦入門を使用するのに少し助けを提供し、自分で後で調べるのに便利です.
一.既存資料の説明
1.githubサイトhttps://github.com/rusty1s/pytorch_geometricここ数年発表された論文のソースコードpytorchの実現です.
ディレクトリexamplesは、酵素構造に基づいて図の分類、auto-encoderなど、一般的な論文で言及されているデータセットの実験です.
後で中のenzymesでtopk_pool.py(topk poolプール化方式を用いて酵素の構造に基づいて分類し,論文Graph U-Netsで述べたプール化方式である)コードを例に詳細に説明する
examplesの下のすべてのpyコードの対応実験は私も完全に対応していませんが、githubの下の論文で述べた実験で、使うときは自分で探してもいいです.
  • pytorchの下にgeometricパッケージがあり、GCNネットワークを構築するために使用されています.中国語のドキュメントアドレスは:
  • です.
    https://pytorch-geometric.readthedocs.io/en/latest/index.html
    上記のgithubの論文のソースコードは直接このサイトにジャンプしますが、このドキュメントの欠点はいくつかのパラメータの説明が分かりにくいことです.多くのペンを持っていて、出力が何なのか、どのように使うのか分かりません.このドキュメントのソースコードと対応する論文を結びつけて見る必要があります.
    このサイトのいくつかの方法の下には、対応する論文とソースコードがリストされています.
    二.コード詳細
    次はgithubの中examplesの下のenzymesでtopk_pool.pyコードを例に,論文とソースコードを組み合わせてパラメータ設定,次元変化などの説明を行う.
    (1)まず用いたパケットとデータロード部であり、ロードされたデータはTUDataset内の酵素の構造データであり、そのうちbatch_size=80はロットごとに80枚の図があることを示し、これは分類の問題であるため、対応するy値は各図のカテゴリ値である
    import os.path as osp
    import torch import torch.nn.functional as F from torch_geometric.datasets import TUDataset from torch_geometric.data import DataLoader from torch_geometric.nn import GraphConv, TopKPooling from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
    path = osp.join(osp.dirname(osp.realpath(‘file’)), ‘…’, ‘data’, ‘ENZYMES’) dataset = TUDataset(path, name=‘ENZYMES’) dataset = dataset.shuffle() #print(“dataset.num_classes=”,dataset.num_classes) n = len(dataset)//10 test_dataset=dataset[:n]#この表示はtrain_をテストするために10分の1を取り出します.dataset=dataset[n:]#の10分の9はtest_を訓練するために使用されます.loader=DataLoader(test_dataset,batch_size=80)#は、ロットtrain_を表します.loader=DataLoader(train_dataset,batch_size=80)(2)DataLoader関数は自動的にデータを10部に分割し、反復器を返します.次のコードはtrain_loaderの内容を印刷してみて、説明します.
    for data in train_loader:
    print("data=",data)
    print("data.batch=",data.batch)
    print("data.batch.shape=",data.batch.shape)
    print("data.x.shape=",data.x.shape)
    print("data.num_features=",data.num_features)
    print("
    ")

    このサイクルの最初の2回のサイクルの結果は次のとおりです.
    data= Batch(batch=[2468], edge_index=[2, 9540], x=[2468, 3], y=[80]) data.batch= tensor([ 0, 0, 0, …, 79, 79, 79]) data.batch.shape= torch.Size([2468]) data.x.shape= torch.Size([2468, 3]) data.num_features= 3
    data= Batch(batch=[2588], edge_index=[2, 9854], x=[2588, 3], y=[80]) data.batch= tensor([ 0, 0, 0, …, 79, 79, 79]) data.batch.shape= torch.Size([2588]) data.x.shape= torch.Size([2588, 3]) data.num_features=3 dataデータには少なくとも4つの部分が含まれていることがわかります.それぞれbatchで、次元は1*2468次元です.このbatchの80枚の図には2468のノードがあります.batchという変数の値は、このノードが80枚の図のどの図に対応しているかを示しています.batchの値の印刷からも0-79の80の値が見えます.
    edge_Indexには隣接マトリクス情報が格納されていますので、具体的にどのように構築するかはリンク【3】を参照してください.
    xは入力されたフィーチャーマトリクスで、次元は2468*3で、合計2468ノードを表し、各ノードに3つのフィーチャーがあります.
    yの次元は1*80であり、このbatchの80枚の図の各図に対応する分類カテゴリを示す
    (3)最後にネットワーク全体の構築と訓練のコードである.
    class Net(torch.nn.Module): def init(self): super(Net, self).init()
        self.conv1 = GraphConv(dataset.num_features, 128)#num_features        , 3
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
    
        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, dataset.num_classes)
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        #print("data.batch=",data.batch)
        print("raw_x.shape=",x.shape)
        x = F.relu(self.conv1(x, edge_index))
        print("conv1_x.shape=",x.shape)
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        print("x.shape=",x.shape)
        
        x_gmp=gmp(x, batch)
        x_gap=gap(x, batch)
        print("x_gmp.shape=",x_gmp.shape)
        print("x_gap.shape=",x_gap.shape)
        x1 = torch.cat([x_gmp, x_gap], dim=1)#       ,      
        print("x1.shape=",x1.shape)
    
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
        print("x2.shape=",x2.shape)
    
        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
        print("x3.shape=",x3.shape)
    
        x = x1 + x2 + x3
        print("x+.shape=",x.shape)
    
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.log_softmax(self.lin3(x), dim=-1)
        
        print("final_x.shape",x.shape)
        return x
    

    device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’) model = Net().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    def train(epoch): model.train()
    loss_all = 0
    for data in train_loader:
        print("raw_data.y.shape",data.y.shape)
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        print("data.y.shape=",data.y.shape)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        optimizer.step()
        print("
    ") return loss_all / len(train_dataset)

    def test(loader): model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
        print("
    ") return correct / len(loader.dataset)

    for epoch in range(1, 2): loss = train(epoch) train_acc = test(train_loader) test_acc=test(test_loader)print('Epoch:{:03 d},Loss:{:.5f},Train Acc:{:.5f},Test Acc:{:.5f}’.format(epoch,loss,train_acc,test_acc))print("")部分の印刷結果は次のとおりです.
    raw_data.y.shape torch.Size([80]) raw_x.shape= torch.Size([2468, 3]) conv1_x.shape= torch.Size([2468, 128]) x.shape= torch.Size([2004, 128]) x_gmp.shape= torch.Size([80, 128]) x_gap.shape= torch.Size([80, 128]) x1.shape= torch.Size([80, 256]) x2.shape= torch.Size([80, 256]) x3.shape= torch.Size([80, 256]) x+.shape= torch.Size([80, 256]) final_x.shape torch.Size([80, 6]) data.y.shape= torch.Size([80])印刷結果に基づいてネットワーク構造を分析する.
    まずclass Netに基づいて変数全体のフローチャートを描きます.
    図中の数字はすべて次元を表します
    GCNネットワーク,Topk poolプール化方式,global mean/max poolが主に用いられていることがわかる
    GCNネットワークは論文【4】で述べた手法を用いており,現在のl層から得られるl+1層の主な公式は以下の通りである.
    前のDとAは度行列と隣接行列がセル行列Iと処理された行列であり,H(l)の前半全体を隣接行列Aと見なすことができ,これで学習が必要なのはWというパラメータ行列のみである.H(l)は入力層では特徴行列Xと見なすことができ,このように上の式はH(l+1)=AXWとなり,入力ノード数が10であると仮定し,各ノードに3つの特徴があり,Aの次元は1010,Xは103となる.GCNのパラメータが(3128)であると仮定すると、Wの次元は3128であり、最終的には1層のGCNを通過した後に特徴行列Hの次元は10128となる.したがって、GCNの2つの重要なパラメータin_channelとout_channelは、実際にはWの次元を表している、あるいはin_channelは特徴数であり、out_channelは出力される各ノードの特徴数である.
    従って、上図に入力された特徴行列X(24683)は、1層GCN(3128)を経て出力される次元は2468*128であり、その後のGCNを経た出力次元もこのように計算される.
    続いてTopk poolとなり、これは【5】で実現したプール化方式である.
    思想は主に次の図です.
    投影方式を採用し、学習可能な投影ベクトルpを利用し、入力された特徴行列Xlの次元はnum_である.nodenum_Feature,pの次元はnum_Feature 1は,まず青色のxと黄色のpを乗算し,||p||で割ってnum_の次元を得る.Node*1のyは、yにおいてpの投影方向において値が最も大きいnodeに対応するk個のインデックスを前にして、対応する特徴を得、その後、得られた投影値に点乗して、処理後の紫色の特徴行列を得る.次に,インデックスに基づいて新しい隣接行列を得る.
    Topk pool(128,0.8)のパラメータ128は、入力された各ノードの特徴数を意味し、0.8は80%保持されたノードを表す
    したがって、2468128はTopk pool(128,0.8)を通過した後、本batchの80%の2004ノードを保持し、各ノードの特徴数は128個であり、出力の次元は2004128であり、24680.8がなぜ2004に厳密に等しくないのかと聞かれる可能性がある.これは、このpoolが本batchの80枚の図で行われ、各図のノード数に0.8を乗じて結果の上限整数をとるためである.したがって、24680.8=1974よりもいくつかのノードがあります.
    最後にgmp/gap:global max pool/global mean poolであり、この2つの操作はbatch階層で行われ、出力の次元は80*128であり、80はこのロットに80枚の図があり、128は各ノードの特徴数を表す
    global mean poolを例にとると、ある入力特徴がXnの上にある操作は次のようになります.
    すなわち、各フィーチャー次元で、入力した次元がnum_である場合、すべてのノードのこのフィーチャー次元の最大値を取得します.node*num_feature
    では、この図の操作後は1*num_feature
    ではなぜTopk poolのほかにglobal poolを行うのかというと、やはりCNNのプール化の役割と同じように、疎接続を増やし、正規化を加えてモデルのフィットを防ぐと思います.このglobal poolの操作がなければ機能も実現できますが、モデルがフィットしすぎて性能が悪い可能性があります.つまり、ここのglobal poolは必須ではありません.他の機能はまだ見ていませんが、もっと理解しているお客様のご指導を歓迎します.
    コードワードは容易ではありません.もしあなたを助けることができたら、いいねをつけたり、賞をあげたりしてください.多少は制限しません.ありがとうございます.
    三.参考資料:
    【1】GCN-GAN: A Non-linear Temporal Link Prediction Model for Weighted Dynamic Networks
    【2】Stochastic Weight Completion for Road Networks using Graph Convolutional Networks
    【3】pytorch.geometricにおけるデータフォーマットの構築:(後述)
    【4】SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS
    【5】Graph U-Nets