Caffe: BN層をBatchNorm+Scale層に変換してprototxtを修正する
10269 ワード
#coding=UTF-8
import sys
sys.path.insert(0,'/home/cdli/ECO2/caffe_3d/python')
import copy
import os
from caffe.proto import caffe_pb2
from google.protobuf import text_format
import google
def create_layer(base_name_, type_, bottom_, top_, use_global_stats=False):
    """Build one layer of the BatchNorm + Scale pair that replaces a ``BN`` layer.

    Args:
        base_name_: name of the layer being replaced; the new layer is named
            ``base_name_ + '/' + type_``.
        type_: ``'batchnorm'`` or ``'scale'`` (case-insensitive).
        bottom_: input blob name of the layer being replaced.
        top_: output blob name of the layer being replaced.
        use_global_stats: value written to
            ``batch_norm_param.use_global_stats`` on the BatchNorm layer.
            NOTE(review): for a deploy/inference net this is normally True
            (or left unset so Caffe derives it from the phase) -- the default
            of False is kept only for backward compatibility; confirm intent.

    Returns:
        A new ``caffe_pb2.LayerParameter``.

    Raises:
        ValueError: if ``type_`` is neither ``'batchnorm'`` nor ``'scale'``.
    """
    kind = type_.lower()
    layer = caffe_pb2.LayerParameter()
    layer.name = base_name_ + '/' + type_
    if kind == 'batchnorm':
        # Caffe's layer registry is case-sensitive: the type must be 'BatchNorm'.
        layer.type = 'BatchNorm'
        layer.bottom.append(bottom_)
        layer.top.append(top_)
        # BatchNorm keeps 3 statistic blobs (mean, variance, moving-average
        # factor); they are not learned, so freeze all of them.
        for _ in range(3):
            layer.param.add(lr_mult=0, decay_mult=0)
        layer.batch_norm_param.use_global_stats = use_global_stats
        layer.batch_norm_param.eps = 0.00001
    elif kind == 'scale':
        layer.type = 'Scale'
        # Runs in place on the BatchNorm output blob.
        layer.bottom.append(top_)
        layer.top.append(top_)
        # One ParamSpec for the scale (gamma) blob, one for the bias (beta).
        for _ in range(2):
            layer.param.add(lr_mult=0.2, decay_mult=0.2)
        # Without bias_term the Scale layer has a single blob, and the second
        # ParamSpec above would make Caffe reject the layer.
        layer.scale_param.bias_term = True
        layer.scale_param.filler.value = 1
        layer.scale_param.bias_filler.value = 0
    else:
        raise ValueError("unsupported layer type for BN conversion: %r" % (type_,))
    return layer
# Replace every custom 'BN' layer in the deploy prototxt with a standard Caffe
# BatchNorm layer followed by an in-place Scale layer, then write the result
# to '<input stem>-batchnorm.prototxt'.
net1 = caffe_pb2.NetParameter()
deploy = 'deploy-pool3.prototxt'
with open(deploy) as f:
    text_format.Merge(f.read(), net1)

# Collect the new layer sequence separately rather than popping/inserting
# while iterating: mutating the container under the iterator is fragile, and
# inserting both layers at the same index put Scale *before* BatchNorm.
new_layers = []
for l in net1.layer:
    if str(l.type) != 'BN':
        kept = caffe_pb2.LayerParameter()
        kept.CopyFrom(l)
        new_layers.append(kept)
        continue
    name = str(l.name)
    print(name)
    bottom = str(l.bottom[0])
    top = str(l.top[0])
    # BatchNorm must come first; Scale then applies gamma/beta in place.
    new_layers.append(create_layer(name, 'batchnorm', bottom, top))
    new_layers.append(create_layer(name, 'scale', bottom, top))

del net1.layer[:]
net1.layer.extend(new_layers)

# splitext only strips the final extension, so dotted directory names survive.
out_path = os.path.splitext(deploy)[0] + '-batchnorm.prototxt'
with open(out_path, 'w') as f:
    f.write(str(net1))