caffeのpythonインタフェース学習(8):caffemodelにおけるパラメータおよび特徴の抽出

4387 ワード

式y=f(wx+b)を用いると
全体の演算過程を表すと、wとbは私たちが訓練しなければならないものであり、wは重み値と呼ばれ、cnnではボリュームコア(filter)と呼ばれ、bはバイアス項である.fはアクティブ化関数でsigmoid,reluなどがある.xは入力されたデータです.
データトレーニングが完了すると,保存されたcaffemodelの中には,実際には各層のwとb値がある.
コードを実行します.
deploy=root + 'mnist/deploy.prototxt'    #deploy  
caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #     caffemodel
net = caffe.Net(net_file,caffe_model,caffe.TEST)   #  model network

すべてのパラメータとデータを1つのnet変数にロードしましたが、netは複雑なobjectで、直接表示して見るのはだめです.次のようになります.
net.params:各層のパラメータ値(wとb)を保存する
net.blobs:各層のデータ値を保存する
使用可能なコマンド:
[(k,v[0].data) for k,v in net.params.items()]

各レイヤのパラメータ値を表示し、kはレイヤの名前、v[0]を表す.dataは各層のW値である、v[1]である.dataは各層のb値である.注:すべてのレイヤにパラメータがあるわけではありません.ボリュームレイヤと全接続レイヤのみです.
具体的な値を表示せずにshapeだけを見たい場合は、コマンドを使用できます.
[(k,v[0].data.shape) for k,v in net.params.items()]

最初のボリューム層の名前が「Convolution 1」であることを知っていれば、この層のパラメータを抽出することができます.
w1=net.params['Convolution1'][0].data
b1=net.params['Convolution1'][1].data

これらのコードを入力して、実際に見てみると、networkを理解するのに役立ちます.
同様に、パラメータを表示するだけでなく、データを表示することもできますが、netには最初はデータがなく、実行する必要があります.
net.forward()

後でデータがあります.コードを使用できます.
[(k,v.data.shape) for k,v in net.blobs.items()]

または
[(k,v.data) for k,v in net.blobs.items()]

に表示されます.上の表示パラメータとの違いに注意してください.一つはnetです.params、一つはnet.blobs.
実際にデータが入力されたばかりの頃、私たちは画像データと呼ばれ、ボリュームが蓄積された後、私たちは特徴と呼ばれました.
最初のフル接続レイヤのフィーチャーを抽出する場合は、次のコマンドを使用できます.
fea=net.blobs['InnerProduct1'].data

ある層の名前さえ分かれば,その層の特徴を抽出することができる.
spyderでは、上のすべてのコードを実行して、モデルの各層を深く理解することをお勧めします.
最後に、コードをまとめます.
import caffe
import numpy as np
root='/home/xxx/'   #   
deploy=root + 'mnist/deploy.prototxt'    #deploy  
caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #     caffemodel
net = caffe.Net(deploy,caffe_model,caffe.TEST)   #  model network
[(k,v[0].data.shape) for k,v in net.params.items()]  #        
w1=net.params['Convolution1'][0].data  #    w
b1=net.params['Convolution1'][1].data  #    b
net.forward()   #    

[(k,v.data.shape) for k,v in net.blobs.items()]  #        
fea=net.blobs['InnerProduct1'].data   #      (  )