# PyTorch script that loads a trained CNN model (spatial XuNet) and runs the test routine on it.



from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import argparse
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
from xu_spatial_net import xu_net, myDatasets, test

import numpy as np
import cv2

import os
# Set the CUDA_VISIBLE_DEVICES environment variable for this process.
# Presumably this restricts the run to the host's GPUs 6 and 7.
# NOTE(review): confirm that the value (including the space after the comma)
# is parsed as intended, and that setting it after `import torch` still takes
# effect before CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '6, 7'

# Command-line configuration for the script. The parsed values end up in the
# module-level `args` namespace.
# NOTE(review): the description string still says "MNIST"; it is kept as-is
# because it is user-visible `--help` output.
parser = argparse.ArgumentParser(description='Pytorch MNIST Example')

# --- batch sizes / schedule / optimizer hyper-parameters ---
parser.add_argument('--batch_size', type=int, default=36, metavar='N',
                    help='input batch size for training (default: 36)')
parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                    help='epochs for training (default: 1000)')
parser.add_argument('--learning_rate', type=float, default=0.01, metavar='LR',
                    help='learning rate for training (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum for training (default: 0.5)')

# --- runtime switches ---
parser.add_argument('--no_cuda', action='store_true', default=False,
                    help='disables CUDA training (default: False)')
parser.add_argument('--seed', type=int, default=1234, metavar='S',
                    help='random seed (default: 1234)')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='show log message interval (default: 10)')
# NOTE: `store_true` combined with `default=True` means this option is True
# whether or not the flag is passed; it cannot be switched off from the CLI.
parser.add_argument('--save_model', action='store_true', default=True,
                    help='saving model ')

# --- dataset locations (cover = unmodified images, stego = embedded images) ---
# The three defaults below were missing their opening quote, which made the
# whole file a SyntaxError; they are now proper string literals.
parser.add_argument('--train_cover_path', type=str, default='/program/data/used/cover/train', metavar='S',
                    help='training image path')
parser.add_argument('--test_cover_path', type=str, default='/program/data/used/cover/test', metavar='S',
                    help='testing image path')
parser.add_argument('--train_stego_path', type=str, default='/Boss_256_40/stc_reset8_cost_10000_4w/train', metavar='S',
                    help='training image path')
parser.add_argument('--test_stego_path', type=str, default='/Boss_256_40/stc_reset8_cost_10000_4w/test', metavar='S',
                    help='testing image path')

# Parses sys.argv at import time; unknown options abort with a usage message.
args = parser.parse_args()

# Pick the compute device. The original always requested "cuda", which ignored
# the --no_cuda switch and crashed on hosts without a usable GPU.
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Apply --seed so the shuffled batch order below is reproducible between runs.
torch.manual_seed(args.seed)

# Pinned host memory only helps host-to-GPU copies, so enable it only when a
# CUDA device is actually used.
kwargs = {'num_workers': 1, 'pin_memory': use_cuda}

# Per-image preprocessing: convert to a tensor, then normalise with the given
# mean / std (single-channel statistics).
transform_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Dataset built from the cover and stego test directories; `myDatasets` comes
# from xu_spatial_net and its sample format is not visible in this file.
test_data = myDatasets(args.test_cover_path, args.test_stego_path, transform=transform_img)
test_loader = data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=True, **kwargs)

# Restore the trained weights. `map_location` remaps tensors saved on a GPU
# onto the selected device, so the checkpoint also loads on CPU-only hosts.
model = xu_net()
state_dict = torch.load('xuNet_cnn.pt', map_location=device)
model.load_state_dict(state_dict)
model.to(device)

# `test` (from xu_spatial_net) takes an optimizer argument, so one is built
# here even though this script never calls optimizer.step() itself.
optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum)

# Run the imported test routine once per "epoch".
# NOTE(review): the weights are not updated in this file, so repeating the
# evaluation args.epochs times looks redundant — confirm against `test`.
for epoch in range(args.epochs):
    test(test_loader, optimizer, model, device, epoch)