pytorchはモデルを保存し、プリトレーニングモデルの問題をロードします.

1862 ワード

pytorchコードを書くとき、事前トレーニングモデルをロードしたときに検証セットでテストしたpsnr結果と、訓練時に検証セットのpsnrとの差が特に大きいという問題が発生しました.
ソース:
pretrained_dict = torch.load('epochG_515.pth')
net.load_state_dict(pretrained_dict)
net = prepare(net)
valdata = Data(root=os.path.join(args.dir_data, args.data_val), args=args, train=False)
valset = DataLoader(valdata, batch_size=1, shuffle=False, num_workers=1)

val_psnr = 0
val_ssim = 0
with torch.no_grad():

    timer_test = util.timer()
    for batch, (lr, hr, filename) in enumerate(valset):
        lr, hr = prepare(lr), prepare(hr)
        sr = net(lr)
        print(sr.shape, hr.shape)
        val_psnr = val_psnr + cal_psnr(hr[0].data.cpu(), sr[0].data.cpu())
        val_ssim = val_ssim + cal_ssim(hr[0].data.cpu(), sr[0].data.cpu())
    print("Test psnr: {:.3f}".format(val_psnr / (len(valset))))
    print('Forward: {:.2f}s
'.format(timer_test.toc())) print(val_ssim / (len(valset)))

後の修正はテスト前にnetを加える.eval()
pretrained_dict = torch.load('epochG_515.pth')
net.load_state_dict(pretrained_dict)
net = prepare(net)
valdata = Data(root=os.path.join(args.dir_data, args.data_val), args=args, train=False)
valset = DataLoader(valdata, batch_size=1, shuffle=False, num_workers=1)

val_psnr = 0
val_ssim = 0
with torch.no_grad():
    net.eval()
    timer_test = util.timer()
    for batch, (lr, hr, filename) in enumerate(valset):
        lr, hr = prepare(lr), prepare(hr)

        sr = net(lr)
        print(sr.shape, hr.shape)
        val_psnr = val_psnr + cal_psnr(hr[0].data.cpu(), sr[0].data.cpu())
        val_ssim = val_ssim + cal_ssim(hr[0].data.cpu(), sr[0].data.cpu())
    print("Test psnr: {:.3f}".format(val_psnr / (len(valset))))
    print('Forward: {:.2f}s
'.format(timer_test.toc())) print(val_ssim / (len(valset)))