PointNetの再現


一、コード運行


1.分類タスク


1.1トレーニング

python train.py --batch_size=8

グラフィックスカードのグラフィックスメモリが小さすぎて、batch sizeが8のを走るしかなく、250個のepochを訓練します.

1.2テスト

python  evaluate.py --visu

結果:
eval mean loss: 0.488872
eval accuracy: 0.878444
eval avg class acc: 0.851238
  airplane:     0.980
   bathtub:     0.860
       bed:     0.950
     bench:     0.700
 bookshelf:     0.900
    bottle:     0.950
      bowl:     0.950
       car:     1.000
     chair:     0.970
      cone:     0.900
       cup:     0.650
   curtain:     0.800
      desk:     0.791
      door:     0.850
   dresser:     0.651
flower_pot:     0.200
 glass_box:     0.950
    guitar:     1.000
  keyboard:     1.000
      lamp:     0.950
    laptop:     1.000
    mantel:     0.950
   monitor:     0.940
night_stand:    0.698
    person:     0.950
     piano:     0.860
     plant:     0.780
     radio:     0.800
range_hood:     0.910
      sink:     0.750
      sofa:     0.960
    stairs:     0.850
     stool:     0.800
     table:     0.810
      tent:     0.950
    toilet:     0.960
  tv_stand:     0.820
      vase:     0.810
  wardrobe:     0.650
      xbox:     0.800

論文の結果は86.2/89.2であり,論文に比べて若干劣り,batch sizeの問題である可能性がある.

2. part segmentation


コードはpart_にありますsegでは、sh download_data.shに入ってデータセットをダウンロードし、trian.pyトレーニング、test.pyテストを実行します.分割タスクのグラフィック占有率はもっと高く、私はbatchsize=4を走るしかありません....最終モデルの結果は論文よりも2点近く低かった(0.8197/0.837).
Accuracy: 0.923867
IoU: 0.819725
         02691156 Total Number: 341
         02691156 Accuracy: 0.9096132876935942
         02691156 IoU: 0.8228572722404234
         02773838 Total Number: 14
         02773838 Accuracy: 0.9517263003758022
         02773838 IoU: 0.7696178300040108
         02954340 Total Number: 11
         02954340 Accuracy: 0.8923338109796698
         02954340 IoU: 0.7994510477239435
         02958343 Total Number: 158
         02958343 Accuracy: 0.9047723118262955
         02958343 IoU: 0.734290400637856
         03001627 Total Number: 704
         03001627 Accuracy: 0.9394946531815962
         03001627 IoU: 0.8910455703735352
         03261776 Total Number: 14
         03261776 Accuracy: 0.9178562164306641
         03261776 IoU: 0.7298440933227539
         03467517 Total Number: 159
         03467517 Accuracy: 0.9630299694133255
         03467517 IoU: 0.9014740949906643
         03624134 Total Number: 80
         03624134 Accuracy: 0.8905808448791503
         03624134 IoU: 0.8044050216674805
         03636649 Total Number: 286
         03636649 Accuracy: 0.8275325214946186
         03636649 IoU: 0.7223616379957932
         03642806 Total Number: 83
         03642806 Accuracy: 0.9772439060440983
         03642806 IoU: 0.9505088530391096
         03790512 Total Number: 51
         03790512 Accuracy: 0.8500167996275658
         03790512 IoU: 0.633358824486826
         03797390 Total Number: 38
         03797390 Accuracy: 0.9923725128173828
         03797390 IoU: 0.921830227500514
         03948459 Total Number: 44
         03948459 Accuracy: 0.9490793401544745
         03948459 IoU: 0.7901863618330522
         04099429 Total Number: 12
         04099429 Accuracy: 0.7942575613657633
         04099429 IoU: 0.5516049861907959
         04225987 Total Number: 31
         04225987 Accuracy: 0.9462758341143208
         04225987 IoU: 0.7356803032659716
         04379243 Total Number: 848
         04379243 Accuracy: 0.9443713494066922
         04379243 IoU: 0.7989320575066332

二、コード解読


1.train.py


訓練の過程を定義した.

1.1 parser & FLAGS


必要な各種パラメータとデフォルト値を定義し、FLAGSにパッケージし、コマンドラインでインタラクティブにすることができます.

1.2 get_learning_rate()/get_bn_decay()

tf.train.exponential_decay()関数を使用してlrおよびbn_momentumは減衰します.
tf.train.exponential_decay(
                      BASE_LEARNING_RATE,  # Base learning rate.the value will be decayed.
                      batch * BATCH_SIZE,  # Current index into the dataset.
                      DECAY_STEP,          # Decay step.
                      DECAY_RATE,          # Decay rate.
                      staircase=True)      # True:decay every decay step.
					                       # False:decay every real setp.

1.3 train()


loss、modelなどのノードを設定し、スーパーパラメータ、place holderとともにopsに配置します.本当にrunではなく、epochごとの訓練を制御します.

1.4 train_one_epoch()


訓練データをlen(TRAIN_FILES)部に分けた.各データでは、batch_sizeに従ってbatchを1つ取って訓練を行う.1つのbatchのデータは、rotateおよびjitterを含むデータ強化を行い、provider.pyによって実現される.ops里のphに関するパラメータに基づいてfeed_dictを構築し、そこから構造のトレーニング、予測、lossなどのノードを取り出し、sess.run()のパラメータとしてトレーニングを行う.

1.5 eval_one_epoch()


trainの部分と同じように、テストファイルを何分に分けて、各batchを入れて、全体の正確率を計算して、訓練中のepochのテストとします.

2. models


modelsフォルダの下には,訓練に用いる3種類のネットワーク構造,cls,seg,T−NETが格納されている.使用されるボリュームなどのネットワーク構造は、utils/tf_uitl.pyで実現される.ネットワーク入力はB*N*3です.Bはbatch_size,Nはサンプル中の点の個数であり,3は点群次元(3次元座標)である.ここでは,1つのサンプルをN*3のマトリクスと見なし,2次元ピクチャに類似し,その後その上でボリューム化などの操作を行った.説明に値するのは、ネットワークが抽出した特徴は3という次元ではなく、この次元はネットワークの最初から1にボリューム化され、特徴次元はexpandを入力した新しい次元である.論文で述べたMLPは,ここでも畳み込みで実現されている.

2.1 pointnet_cls.py


このファイルはネットワークの分類構造を実現した.出力はB*40であり、各サンプルの各カテゴリに対する確率である.ネットワーク構造はget_model()で定義され、lossはget_lossで定義される.
ネットワーク構造
ネットワークプロセスに従って、ネットワーク全体を以下の段階に分けます.ネットワーク入力:[B,N,3,1].
  • フィーチャー抽出このセクションでは、2 d-convが使用されます.まず,[1,3]のボリュームコアを用いて電気雲の幅を1にボリューム化し,64個のボリュームコアを用いて出力次元[B,N,1,64]を得た.さらに[1,1]のボリュームコアを接続し、特徴を再抽出します.
  • STNという部分はサブネットワークとみなされ,model/transform_net.pyを単独で用いて実現される.実装の過程で、特徴をいくつかの[1,1]のボリュームで次元化し、[B,N,1,1024]を得た.次いで、max poolingをN次元で使用し、reshapeは[B,1024]を得た.特徴的次元ダウンをFC経由で行い、[B,256]を得た.さらに[256,64*64]のT−NETを生成し、乗算すると[B,4096]、reshapeは[B,64,64]となる.T-NETと元の入力tensorをテンソル乗算し、expand 2次元、STNの出力[B,N,1,64]を得る.
  • フィーチャー統合はまた、一連の[1,1]ボリュームであり、フィーチャーを[B,N,1,1024]に次元化する.
  • 対称動作はN次元でmax poolingを用い,reshapeを1の次元に落として[B,1024]を得た.
  • 予測結果は、いくつかのFCおよびdrop outを接続し、40に次元を下げ、結果[B,40]を得た.

  • loss設計
    lossは3つの部分に分かれています.
  • classify_loss:分類されたクロスエントロピー損失.
  • mat_diff_loss:これはT-NETに自分で学んだことを直交させるために使われています.
  • reg_Weight:正規項.

  • 3.utils


    このフォルダにはいくつかのツール関数があります.

    3.1 data_prep_util.py


    データの読み取りを実現する、単一のファイルパスが存在する.txtファイルでは,ファイル名,join親パスを読み出し,データを読み出し,データセットを作成する.

    3.2 pc_util.py


    点群とvolumeの変換,点群の可視化,点群の読み書きを含む3次元データ処理の関数を提供した.

    3.3 tf.util.py


    自分でカプセル化したいくつかの層は,種々のボリューム,pooling,fcを実現した.ウェイト初期化の関数も実現した.しかし、ネットは2 Dボリュームしか使われていないようで、max pooing?

    4. provider.py


    データ読み取り、データ処理、データ強化の関数を実現しました.次の内容が含まれます.
  • shuffle.
  • rotate:回転マトリクスを構築し、点群に乗算します.
  • jitter:ジッタを構築し、clipで制限して元のデータに追加します.