pytorch-Ministチュートリアル

4552 ワード

pytorch-Ministチュートリアル
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/#blas-and-lapack-operations
https://github.com/pytorch/examples/blob/master/mnist/main.py
リソースの2つのリンク:1つ目は中国語のドキュメントで、2つ目はgithubの学習チュートリアルで、とてもいいです.
一、まずライブラリをロードする
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torchvision import datasets,transforms

二、訓練パラメータの定義
parser = argparse.ArgumentParser()
    	parser.add_argument('--batch_size',type=int,default=64)
    	parser.add_argument('--test-batch_size',type=int,default=1000)
    	parser.add_argument('--epochs',type=int,default=10)
    	parser.add_argument('--lr',type=float,default=0.01)
    	parser.add_argument('--momentum',type=float,default=0.5)
    	parser.add_argument('--no_cuda',action='store_true',default=False)
    	parser.add_argument('--seed',type=int,default=1)
    	parser.add_argument('--log_interval',type=int,default=10)

    	args = parser.parse_args()

ここで定義しますbatch_size,test_batch_size,epochs,lr,momentum,no_cuda,seed,log_interval
args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device('cuda' if use_cuda else 'cpu')

kwargs = {'num_workers':1,'pin_memory':True} if use_cuda else{}

三、データのロード
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

四、ネットワーク定義
def __init__(self):

		super(Net,self).__init__()
		self.conv1 = nn.Conv2d(1,10,kernel_size=5)
		self.conv2 = nn.Conv2d(10,20,kernel_size=5)
		self.conv2_drop = nn.Dropout2d()
		self.fc1 = nn.Linear(320,50)
		self.fc2 = nn.Linear(50,10)


	def forward(self,x):

		x = F.relu(F.max_pool2d(self.conv1(x),2))
		x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x),2)))
		x = x.view(-1,320)
		x = F.relu(self.fc1(x))
		x = F.dropout(x,training=self.training)
		x = self.fc2(x)

		return F.log_softmax(x,dim=1)

ここでは、次のように定義します.
conv 1(1,10,5),入力チャネル数1,出力チャネル数10,ボリュームコアsize 5
max_pool 2 d(2)、pooling核のsizeは2
F.relu(F.max_pool2d(self.conv1(x),2))
次に、
conv2(10,20,5)
conv 2_も定義されていますdrop = Dropout2d()
最後に2つのフル接続レイヤで、出力次元は10です.
4.1トレーニング詳細定義
def train(args,model,device,train_loader,optimizer,epoch):

		model.train()
		for batch_idx,(data,target) in enumerate(train_loader):

			data,target = data.to(device),target.to(device)
			optimizer.zero_grad()
			output = model(data)
			loss = F.nll_loss(output,target)
			loss.backward()
			optimizer.step()

			if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

大体データをロードした後、データを出力すべき形式に変換してzero_grad()を行い、出力を行い、lossを定義し、backwardを定義し、勾配降下を定義します.
五テスト
 def test(args,model,device,train_loader,optimizer,epoch):

     	model.eval()
     	test_loss = 0
     	correct = 0
     	with torch.no_grad():

     		for data,target in test_loader:

     			data,target = data.to(device),target.to(device)
     			output = model(data)
     			test_loss += F.nll_loss(output,target).item()
     			pred = output.max(1,keepdim=True)[1]
     			correct += pred.eq(target.view_as(pred)).sum().item()

     	test_loss /=len(test_loader.dataset)
     	print('
Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))