PyTorchモデルをONNXに変換し、さらにCaffeへ再変換したうえで、CaffeとPyTorchのモデル出力が一致するかどうかをテストする

2217 ワード


def pytorch_out(input):
    """Run the PyTorch reference model on ``input`` and return its output.

    The model is created by ``model_res()`` (defined elsewhere in this file),
    switched to evaluation mode, and called once with gradient tracking
    disabled. The wall-clock duration of that call is printed.

    Args:
        input: Value handed directly to the model call. The caller in this
            file passes a CPU ``torch.Tensor``.

    Returns:
        Whatever the model call returns.
    """
    model = model_res()
    # The original only hinted at eval mode in a comment. Calling eval() is a
    # no-op if model_res() already did it, and otherwise makes layers such as
    # BatchNorm/Dropout use inference behaviour, matching the Caffe TEST net.
    # NOTE(review): assumes model_res() returns an nn.Module -- confirm.
    model.eval()

    # torch.no_grad() only has an effect when used as a context manager (or
    # decorator); the original bare call created and discarded the object.
    with torch.no_grad():
        t1 = time.time()
        output = model(input)
        print(" torch  ", time.time() - t1)
    return output

def caffe_output(input, input_layer_name="data", out_layer_name="fc1",
                 deploy="./transform_model/succed_res50.prototxt",
                 weight="./transform_model/succed_res50.caffemodel"):
    """Run the converted Caffe model on ``input`` and return one named result.

    Loads the network described by ``deploy``/``weight`` in ``caffe.TEST``
    phase, reshapes the input blob to the shape of ``input``, runs
    ``forward_all`` and prints the wall-clock time of that forward call.

    Args:
        input: ``torch.Tensor`` to feed to the network; it is converted to a
            float32 NumPy array.
        input_layer_name: Name of the input blob in the prototxt.
        out_layer_name: Key looked up in the result of ``forward_all``.
        deploy: Path to the network definition (prototxt).
        weight: Path to the trained weights (caffemodel).

    Returns:
        The entry of the ``forward_all`` result stored under
        ``out_layer_name``.

    Raises:
        KeyError: If ``input_layer_name`` is not a blob of the network or
            ``out_layer_name`` is not present in the forward result.
    """
    import sys

    # Alternative build used with python3.6:
    #   /nfs-data/xingwg/deep_learning/NVCaffe/python
    caffe_python_dir = "/home/shiyy/nas/NVCaffe/python"  # python2.7 build
    # Avoid piling up duplicate entries when this function is called again.
    if caffe_python_dir not in sys.path:
        sys.path.insert(0, caffe_python_dir)
    import caffe

    # NOTE(review): only the device id is selected here; caffe.set_mode_gpu()
    # is never called in this block -- confirm which mode is intended.
    caffe.set_device(0)

    caffe_model = caffe.Net(deploy, weight, caffe.TEST)

    # detach().cpu() replaces the legacy ``.data`` accessor and also works if
    # the tensor tracks gradients or lives on another device.
    input_array = input.detach().cpu().numpy().astype(np.float32, copy=False)

    # Reshape the network input to match the incoming array.
    caffe_model.blobs[input_layer_name].reshape(*input_array.shape)

    # Time the forward pass itself (the original timed only the reshape), so
    # the printed figure is comparable with the one from pytorch_out().
    forward_kwargs = {input_layer_name: input_array}
    t1 = time.time()
    output_blobs = caffe_model.forward_all(**forward_kwargs)
    print("caffe  :", time.time() - t1)

    return output_blobs[out_layer_name]

def pytorch_caffe_test():
    """Check that the Caffe and PyTorch models agree on one random input.

    A seeded random tensor of shape (1, 3, 112, 112) is passed to
    ``caffe_output`` and then to ``pytorch_out``. Both results and their
    shapes are printed, flattened to float32 vectors, and compared with
    ``np.testing.assert_almost_equal`` at 5 decimals.

    Raises:
        AssertionError: If the two outputs differ beyond 5 decimals.
    """
    # Fixed seed so the random input is reproducible between runs.
    torch.manual_seed(66)
    sample = torch.randn(1, 3, 112, 112, device='cpu')

    separator = "================>"

    # Caffe side first, same order as the printed report.
    print(separator)
    caffe_result = caffe_output(sample)
    print(caffe_result)
    print(caffe_result.shape)

    # PyTorch side; detach before converting to NumPy.
    print(separator)
    torch_result = pytorch_out(sample).detach().numpy()
    print(torch_result)
    print(torch_result.shape)

    print("===================================>")
    print(" , np")

    # Both sides are compared as flat float32 vectors.
    torch_flat = np.array(torch_result.flatten(), dtype="float32")
    caffe_flat = np.array(caffe_result.flatten(), dtype="float32")

    np.testing.assert_almost_equal(torch_flat, caffe_flat, decimal=5)
    print("  ^^,caffe   pytorch , Exported model has been executed decimal=5 and the result looks good!")