TSNアルゴリズムのPyTorchコード解読(テスト部)


このブログではTSNアルゴリズムのPyTorchコードのテスト部分を紹介しています.まずトレーニング部分のコード解読を見ることをお勧めします.TSNアルゴリズムのPyTorchコード解読(トレーニング部分)、test_moels.pyはテストモデルのエントリです.
前のモジュールのインポートとコマンドラインパラメータの構成は、トレーニングコードと似ています.詳しくは説明しません.
import argparse
import time

import numpy as np
import torch.nn.parallel
import torch.optim
from sklearn.metrics import confusion_matrix

from dataset import TSNDataSet
from models import TSN
from transforms import *
from ops import ConsensusModule

# options
parser = argparse.ArgumentParser(
    description="Standard video-level testing")
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff'])
parser.add_argument('test_list', type=str)
parser.add_argument('weights', type=str)
parser.add_argument('--arch', type=str, default="resnet101")
parser.add_argument('--save_scores', type=str, default=None)
parser.add_argument('--test_segments', type=int, default=25)
parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--test_crops', type=int, default=10)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--crop_fusion_type', type=str, default='avg',
                    choices=['avg', 'max', 'topk'])
parser.add_argument('--k', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.7)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', type=str, default='')

args = parser.parse_args()

次に、データセットに基づいてカテゴリ数を決定します.そしてmodels.pyスクリプトのTSNクラスでネットワーク構造をインポートします.また、得られたインターネットの各階層情報を表示するには、net.state_dict()で表示します.checkpoint = torch.load(args.weights)は、プリトレーニングを導入するモデルであり、PyTorchでは、導入モデルはtorchを採用する.load()インタフェース実装、args入力.それは...pthファイル、すなわち予備訓練モデル.base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())}は、予め訓練されたモデルを読み取る層と特定のパラメータがbase_に併存することである.dictという辞書の中にあります.net.load_state_dict(base_dict)はtorchを呼び出す.nn.Moduleクラスのload_state_dict法は,事前訓練モデルを用いてnetネットワークを初期化する過程に達する.注意が必要なのはload_state_dictメソッドには、strictという入力もあります.このパラメータがTrueの場合、ネットワーク構造のレイヤ情報は、事前トレーニングモデルのレイヤ情報と厳密に等しくなければならないことを示します.逆に、パラメータのデフォルトはTrueです.ではFalseはいつ使われますか?つまり、ネットワークの一部のレイヤパラメータをプリトレーニングネットワークで初期化したい場合、またはプリトレーニングネットワークのレイヤ情報と初期化されるネットワークのレイヤ情報が完全に一致しない場合、レイヤ情報と同じレイヤしか初期化されません.
if args.dataset == 'ucf101':
    num_class = 101
elif args.dataset == 'hmdb51':
    num_class = 51
elif args.dataset == 'kinetics':
    num_class = 400
else:
    raise ValueError('Unknown dataset '+args.dataset)

net = TSN(num_class, 1, args.modality,
          base_model=args.arch,
          consensus_type=args.crop_fusion_type,
          dropout=args.dropout)

checkpoint = torch.load(args.weights)

base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())}
net.load_state_dict(base_dict)

次にargsについてtest_cropsの条件文は、データに対して異なるcrop操作を行うために使用されます.単純なcrop操作とサンプリングを繰り返すcrop操作です.もしargs.test_cropsが1に等しい場合は、まずresizeから指定サイズ(例えば400 resizeから256)まで、その後center crop操作を行い、最後にnetを得る.input_sizeのサイズ(例えば224)は、1枚の画像がcrop操作を終えた後に出力されるか、それとも1枚の画像であるかに注意してください.もしargs.test_cropsが10に等しい場合、プロジェクトのtransformsを呼び出す.pyスクリプトのGroupOverSampleクラスは、繰り返しサンプリングされたcrop操作を行い、最終的に1枚の画像で10枚のcropの結果が得られ、後でGroupOverSampleというクラスについて詳しく説明します.次のデータの読み取り部分は訓練時と似ていますが、注意しなければならないのは:1、num_segmentsのパラメータのデフォルトは25で、訓練時よりずっと多いです.2、test_mode=Trueなので、TSNDataSetクラスの__getitem__メソッドを呼び出すときとトレーニングするときでは少し違います.
if args.test_crops == 1:
    cropping = torchvision.transforms.Compose([
        GroupScale(net.scale_size),
        GroupCenterCrop(net.input_size),
    ])
elif args.test_crops == 10:
    cropping = torchvision.transforms.Compose([
        GroupOverSample(net.input_size, net.scale_size)
    ])
else:
    raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops))

data_loader = torch.utils.data.DataLoader(
        TSNDataSet("", args.test_list, num_segments=args.test_segments,
                   new_length=1 if args.modality == "RGB" else 5,
                   modality=args.modality,
                   image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg",
                   test_mode=True,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=args.arch == 'BNInception'),
                       ToTorchFormatTensor(div=args.arch != 'BNInception'),
                       GroupNormalize(net.input_mean, net.input_std),
                   ])),
        batch_size=1, shuffle=False,
        num_workers=args.workers * 2, pin_memory=True)

transforms.pyスクリプトのGroupOverSampleクラス.まず__init__のGroupScaleクラスもtransforms.pyで定義するのは、入力したn枚の画像に対してtorchvisionを行うことです.transforms.Scale操作、すなわちresizeから指定サイズまで.GroupMultiScaleCrop.fill_fix_offsetが返すoffsetsは長さ5のリストで、各値はtupleで、最初の4つは4つの点座標で、最後の1つは中心点座標で、この5つの点を左上角座標とする場合、原図の4つの角と中心部分cropで指定寸法の図を出すことができることを目的としています.後述する例があります.crop = img.crop((o_w,o_h,o_w+crop_w,o_h+crop_h))はcrop_w*crop_hのサイズはcrop原画像から除去され、ここでは224*224を採用している.flip_crop = crop.copy().Transpose(Image.FLIP_LEFT_RIGHT)は、cropで得られた画像を左右に反転させるものです.最後に反転していないものと反転したリストをマージすると、1枚の入力画像で10枚の出力が得られます(crop 5枚、crop 5枚、反転5枚).例えばimage_w=340,image_h=256,crop_w=224,crop_h=224であればoffsetsは[(0,0),(116,0),(0,32),(116,32),(58,16)]であるため、最初のcropの結果は原図上左上角座標(0,0),右下角座標(224224)の図であり、これが原図の左上角部分図である.2番目のcropの結果、原図の左上隅座標は(116,0)、右下隅座標は(340224)の図であり、これが原図の右上隅部分図であり、その他は原図の左下隅部分図と右下隅部分図の順に類推され、最後は原図の真ん中中央cropから出た224*224図である.これが論文で述べたcorner cropであり,4つのcornerと1つのcenterである.
class GroupOverSample(object):
    def __init__(self, crop_size, scale_size=None):
        self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)

        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                if img.mode == 'L' and i % 2 == 0:
                    flip_group.append(ImageOps.invert(flip_crop))
                else:
                    flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            oversample_group.extend(flip_group)
        return oversample_group

次にGPUモードの設定、モデルを検証モードの設定、データの初期化などを行います.
if args.gpus is not None:
    devices = [args.gpus[i] for i in range(args.workers)]
else:
    devices = list(range(args.workers))

net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
net.eval()

data_gen = enumerate(data_loader)

total_num = len(data_loader.dataset)
output = []

データのループ読み出しが開始され、ループが実行されるたびにvideoのデータが1つ読み込まれることを示す.ループで主にeval_を呼び出すビデオ関数でテストします.予測結果と実際のラベルの結果はoutputリストに保存されます.
proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)

for i, (data, label) in data_gen:
    if i >= max_num:
        break
    rst = eval_video((i, data, label))
    output.append(rst[1:])
    cnt_time = time.time() - proc_start_time
    print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1,
                                                                    total_num,
                                                                    float(cnt_time) / (i+1)))

eval_ビデオ関数はテストの主体であり,テストデータとモデルを準備した後,この関数により予測する.ビデオを入力dataはtuple:(i,data,label)です.data.view(-1, length, data.size(2), data.size(3))は、本来入力(1,3*args.test_crops*args.test_segments,224224)を(args.test_crops*args.test_segments,3224224)に変換するものであり、batch sizeがargsである.test_crops*args.test_segments.そしてtorch.autograd.Variableインタフェースは、Variableタイプのデータとしてカプセル化され、モデルの入力として使用されます.Net(input_var)で得られた結果はVariableであり、Tensorコンテンツを読み込むにはdata変数を読み取る必要があり、cpu()はcpuに格納され、numpy()はTensorがnumpy arrayに移行し、copy()はコピーを表す.rst.reshape((num_crop,args.test_segments,num_class))は、入力次元(2次元)を指定次元(3次元)に変更することを示し、mean(axis=0)はnum_crop次元は平均値、すなわち、あるフレーム画像の10枚のcropまたはclip画像を予測し、最後にこの10枚の予測結果の平均値をそのフレーム画像の結果とする.最後にreshape操作をもう1つ実行します.最後に返されるのは,videoのindex,予測結果,videoの真実ラベルの3つの値である.
def eval_video(video_data):
    i, data, label = video_data
    num_crop = args.test_crops

    if args.modality == 'RGB':
        length = 3
    elif args.modality == 'Flow':
        length = 10
    elif args.modality == 'RGBDiff':
        length = 18
    else:
        raise ValueError("Unknown modality "+args.modality)

    input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)),
                                        volatile=True)
    rst = net(input_var).data.cpu().numpy().copy()
    return i, rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape(
        (args.test_segments, 1, num_class)
    ), label[0]

次にvideo-levelの予測結果を計算します.ここではnp.mean(x[0],axis=0)からargsがわかる.test_segmentsフレーム画像の結果も平均的な方法でvideo-levelの予測結果を計算し、np.argmaxは確率が最も大きいカテゴリをこのビデオの予測カテゴリとする.video_Labelsは実際のカテゴリです.cf = confusion_matrix(video_labels, video_pred).astype(float)は、混同マトリクス生成結果(numpy array)を呼び出した例であり、y_true=[2,0,2,2,0,1],y_pred=[0,0,2,2,0,2]ではconfusion_matrix(y_true,y_pred)の結果はarray([2,0,0],[0,0,1],[1,0,2]])であり,各行は真のカテゴリを表し,各列は予測カテゴリを表す.だからcls_cnt=cf.sum(axis=1)は、実際のカテゴリごとに何個のvideoがあるかを示し、cls_hit = np.diag(cf)は、cfの対角線データを取り出し、各カテゴリのvideoで各予測がどれだけ合っているかを示すのでcls_acc = cls_hit/cls_cntは各カテゴリのvideo予測精度である.np.mean(cls_acc)は各カテゴリの平均精度である.最後のif args.save_scores is not None:文は予測結果をファイルに保存するために使用されます.
video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output]

video_labels = [x[1] for x in output]

cf = confusion_matrix(video_labels, video_pred).astype(float)

cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)

cls_acc = cls_hit / cls_cnt

print(cls_acc)

print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))

if args.save_scores is not None:

    # reorder before saving
    name_list = [x.strip().split()[0] for x in open(args.test_list)]

    order_dict = {e:i for i, e in enumerate(sorted(name_list))}

    reorder_output = [None] * len(output)
    reorder_label = [None] * len(output)

    for i in range(len(output)):
        idx = order_dict[name_list[i]]
        reorder_output[idx] = output[i]
        reorder_label[idx] = video_labels[i]

    np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)