PytorchはCIFAR-10分類を実現

39400 ワード

Pytorch CIFAR-10分類を実現する5つのステップ

  • 1. 準備データ:CIFAR-10をダウンロードし、規格化
  • torch.utils.data.DataLoader()クラス
  • 部分画像
  • を表示する.
  • 2. CNN
  • の定義
  • `super()`コンストラクタ
  • 3. 損失関数の定義
  • 4. training setでCNN
  • を訓練する
  • `enumerate()関数
  • `iter()`と`enumerate()`
  • 5. Test setでCNN
  • をテスト
  • `iter()関数
  • `next()メソッド
  • `next()関数
  • 問題:`iter(trainloader)`と`enumerate(trainloader)`
  • 参考資料
  • tensorvisionパッケージには、CIFAR-10を含む一般的なビジュアルデータセットが付属しています.Tutorialでは、ネットワークのトレーニングを5つのステップに分けます.
  • 準備データ:CIFAR-10をダウンロードして正規化
  • CNN
  • の定義
  • 損失関数
  • を定義する
  • training setでCNN
  • を訓練
  • test setでCNN
  • をテスト

    1.データの準備:CIFAR-10をダウンロードして正規化する

    import torch
    import torchvision
    import torchvision.transforms as transforms
    
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    

    データ処理にはtorch.ultis.dataクラスとtorchvision.transformsクラスが用いられる.ここで、transforms.Compose()は、各種データ変換を組み合わせて用いられ、各種変換からなるリストである.torchvision.datasets.CIFAR10()は自分でダウンロードしたものを使って、解凍して./data内に置いて、download=Falseを設定することができます.

    torch.utils.data.DataLoader()クラス

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
    
    torch.utils.data.DataLoaderクラスは、データセットとサンプラを組み合わせ、データセットに単一プロセスまたはマルチプロセス反復器を提供するデータローダです.

    画像の一部を表示


    次のコードは、画像の一部を表示するために使用されます.
    import matplotlib.pyplot as plt
    import numpy as np
    
    # functions to show an image
    
    def imshow(img):
        img = img / 2 + 0.5     # unnormalize
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    
    
    # get some random training images
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    
    # show images
    imshow(torchvision.utils.make_grid(images))
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
    

    2.CNNの定義

    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    
    net = Net()
    

    super()コンストラクタ


    書き込み待ち

    3.損失関数の定義


    クロスエントロピーcriterionとSGD最適化を用いた.
    import torch.optim as optim
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    

    4.training setでCNNを訓練する

    for epoch in range(2):  # loop over the dataset multiple times
    
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
    
            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    
    print('Finished Training')
    

    Enumerate()関数

    enumerate()関数は、リスト、メタグループ、文字列などの遍歴可能なデータオブジェクトをインデックスシーケンスに結合するとともに、forループで一般的に使用されるデータとデータの下付きラベルをリストするために使用されます.次のようにenumerateの使い方を示します.
    enumerate(sequence, [start=0])
    
  • 例:
  • >>>seq = ['one', 'two', 'three']
    >>> for i, element in enumerate(seq):
    ...     print i, element
    ... 
    0 one
    1 two
    2 three
    

    iter()とenumerate()

    
    lst = [3,60,9]
    strlst = ['a','box','c']
    
    lst_enum = enumerate(lst)
    lst_iter = iter(lst)
    strlst_enum = enumerate(strlst)    
    strlst_iter = iter(strlst)
    
    for a in lst_enum:
        print(a)
    >>>	(0, 3)
    	(1, 60)
    	(2, 9)
    for b in lst_iter:
        print(b)
    >>>	3
    	60
    	9
    for c in strlst_enum:
        print(c)
    >>>	(0, 'a')
    	(1, 'box')
    	(2, 'c')
    for d in strlst_iter:
        print(d)
    >>>  a
    	box
    	c 
    for e,element in lst_enum:
        print(e,element)
    >>>	0 a
    	1 box
    	2 c
    

    5.test setでCNNをテストする

    dataiter = iter(testloader)
    images, labels = dataiter.next()
    
    # print images
    #  4 
    imshow(torchvision.utils.make_grid(images))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    
    outputs = net(images)
    
    _, predicted = torch.max(outputs, 1)
    
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                                  for j in range(4)))
    #  
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
    

    iter()関数

    iter()関数は反復器を生成するために使用される.listtupleなどは反復可能なオブジェクトであり、iter()関数によってこれらの反復可能なオブジェクトの反復器を取得することができる.次に、取得した反復器に対してnext()関数を絶えず適用して次のデータを取得することができる.iter()関数は、実際には反復可能なオブジェクトを適用した__iter__方式である.iter()の使い方は以下の通りです.
    iter(object[, sentinel])
    
  • パラメータ:反復をサポートする集合オブジェクト.
  • は、値を返します.反復オブジェクトです.
  • >>>lst = [1, 2, 3]
    >>> for i in iter(lst):
    ...     print(i)
    ... 
    1
    2
    3
    

    next()メソッド

    next()メソッドは、ファイルが反復器を使用する場合に使用され、ループではnext()メソッドが各ループで呼び出され、このメソッドはファイルの次の行を返し、エンドポイント(EOF)に達するとStopIterationがトリガーされます.使用方法:
    fileObject.next(); 
    
  • パラメータ:なし.
  • は、値を返します.ファイルの次の行です.

  • next()関数


    next()は反復器の次の項目を返します.使用法は次のとおりです.
    next(iterator[, default])
    
  • パラメータ:
  • iterator–反復可能オブジェクト
  • default–オプションで、次の要素がない場合にデフォルト値を返すように設定します.設定しない場合、次の要素がない場合、StopIteration例外がトリガーされます.

  • 戻り値:オブジェクトヘルプ情報を返します.
  • 例:
  • #  Iterator :
    it = iter([1, 2, 3, 4, 5])
    #  :
    while True:
        try:
            #  :
            x = next(it)
            print(x)
        except StopIteration:
            #  StopIteration 
            break1
    2
    3
    4
    5
    

    質問:iter(trainloader)とenumerate(trainloader)


    画像を可視化する際に使用するのは、次のとおりです.
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    

    これでimages、labelsのsizeはすべて4で、trainloaderの1つのbatch_ですsizeの大きさ.でもネットを訓練するときに使うのは
    for i, data in enumerate(trainloader, 0):
    

    このように巡るiは0-12499です

    参考資料


    [1] https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar 10-tutorial-py[2]PyTorchドキュメント中国語版:https://pytorch-cn.readthedocs.io/zh/latest/[3] http://www.runoob.com/python/