[PyTorch]マルチインスタンス学習(MIL)のためのマルチプロセスによるGPU並列化


0.この記事の対象者

  1. PyTorchを使って画像のマルチインスタンス学習(MIL)を実装する方
  2. GPU並列化によって実行時間を短縮したい方
  3. 容量の大きな画像(例:Whole Slide Image)を扱っており, patchの読み込みがボトルネックになっている方
  4. 特に自前のデータセット (torchvision.datasetsにないデータ)を使用する方

1 概要

画像全体を一度にCNNに入力できない程に画像が大きく, マルチインスタンス学習(MIL)を用いる場合を想定

  • PyTorchのマルチプロセス化機能によってMILを複数のGPUで並列で処理
  • Bag単位での学習用に自作のDatasetクラスを作成
    • それぞれの大きい画像から小領域画像(patch)を切り出して複数のbagを作成
    • Bag単位でデータセットを作成
  • マルチプロセス対応した学習用ソースコードを変更
    • 主にMILモデル, Datasetのマルチプロセス化
  • Bag単位での学習の実装
    • PyTorchではbatch単位での学習を行うため, bagとbatchのギャップを吸収する実装

サンプルコード
https://github.com/switch23/multi-gpu-MIL

2 問題点

PyTorchではマルチプロセス化する際に, Datasetクラスを継承したクラスを用いる必要があるので, MILでの学習に特有な以下の設定を上手く組み込んだデータセットクラスを自作しなければならない

  • MILではbag単位でlabelが割り当てられる
    • 通常, batch学習では画像1枚に対してlabelが1つ割り当てられるが, MILでは同じ画像から切り出された複数の小領域画像patchを一塊のbagにしてlabelが1つ割り当てられる
  • 1枚の大きな画像から複数のbagを作成する
    • 例えば, 1枚の画像から100patchから成るbagを50個作成
  • 画像の読み込みをbag単位で行う
    • 例えば, 1bag読み込む際には100枚の画像を一度に読み込む
    • したがって, 画像の読み込みは大きなボトルネックとなるため画像を読み込む処理を記述するタイミングに注意が必要

3 準備

容量の大きい画像データセットを扱うに当たって本稿では事前に以下のような準備を実施

  • 重要: 切り出すpatchの座標を書き出したCSVファイルを用意
    • SVSなどの圧縮済みの大きい画像ファイルからOpenSlideを用いて直接patch読み込みを行うことを想定
    • あらかじめSVSなどのファイルを展開してpatchを切り出してそのまま保存すると容量が大変大きくなるため
    • それぞれの大きな画像について以下のようなpatchの座標情報が書かれたCSVを用意
sample.csv
25760,31136
13664,12320
37184,26656
13440,15456
37184,24416
21056,28672
24192,21728
10752,11648
32256,22176
34272,19264
29120,20832
...
  • 個人的な都合 : Cross validation用のデータセット分割プログラムを用意
    • あくまでディレクトリ構造の都合上, 個人的にはこうしました
    • 各々のデータセットのディレクトリ構造などの都合に合わせてよしなに変更するなり無視するなりして頂いて大丈夫です
    • 以下のようにあらかじめ3クラス(A,B,C)の画像を等分割(ここでは5分割)しておく
sample_dataset.py
A_1 = [
    'SLIDE_000',
        ...    ,
    'SLIDE_009',
]
       ...
A_5 = [
    'SLIDE_040',
        ...    ,
    'SLIDE_049',
]

B_1 = [
    'SLIDE_100',
        ...    ,
    'SLIDE_109',
]
        ...
C_5 = [
    'SLIDE_240',
        ...    ,
    'SLIDE_249',
]

# slideを訓練用とテスト(valid)用に分割
def slide_split(train, test):
    # ex) train = '123', test_or_valid = '4'

    data_map = {}
    data_map['data1'] = [A_1, B_1, C_1]
    data_map['data2'] = [A_2, B_2, C_2]
    data_map['data3'] = [A_3, B_3, C_3]
    data_map['data4'] = [A_4, B_4, C_4]
    data_map['data5'] = [A_5, B_5, C_5]

    train_list = [i for i in train]
    train_A = []
    train_B = []
    train_C = []
    for num in train_list:
        train_A = train_A + data_map[f'data{num}'][0]
        train_B = train_B + data_map[f'data{num}'][1]
        train_C = train_C + data_map[f'data{num}'][2]

    test_list = [i for i in test]
    test_A = []
    test_B = []
    test_C = []
    for num in test_list:
        test_A = test_A + data_map[f'data{num}'][0]
        test_B = test_B + data_map[f'data{num}'][1]
        test_C = test_C + data_map[f'data{num}'][2]

    return train_A, train_B, train_C, test_A, test_B, test_C

4 実装

ここでは, 2でも述べたMIL特有の問題を解決しながらマルチプロセスによるGPU並列化を実現する実装方法を説明します

4.1 MIL用自作データセットクラス

本節ではbag単位でデータセットを作成するプログラムを掲載します

特に注意が必要なのは__init__()関数内部では, 直接画像を扱わない点です.
この時点では並列化の恩恵は受けられないので, 普通に画像の読み込みに時間がかかるばかりか, メモリが足りなくなると思います(多分)
したがって, __init__()関数の段階ではpatchの座標リストのみを扱うことで, bagリスト作成のコストを小さくしてやりましょう

逆に画像を読み出すタイミングは__getitem__()関数の内部です
__getitem__()はDataLoaderが実際にデータセットを読み出す際に呼び出されるので, ここでbagのサイズ分だけ画像を読み出すことになります
後述するdeta_samplerによって各々のGPUでDataLoaderが機能するようになり, ここで画像読み出しという最大のボトルネックが解消されます

OriginalDataset.py
import torch
import random
import numpy as np
import os
import openslide

# map vips formats to np dtypes
format_to_dtype = {
    'uchar': np.uint8,
    'char': np.int8,
    'ushort': np.uint16,
    'short': np.int16,
    'uint': np.uint32,
    'int': np.int32,
    'float': np.float32,
    'double': np.float64,
    'complex': np.complex64,
    'dpcomplex': np.complex128,
}

DATA_PATH = f'data_directry'

class OriginalDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform = None, bag_num=50, bag_size=100, train=True):
        self.transform = transform

        # Bagリスト作成
        # ここでは座標情報のみを扱い, 画像は読み込まない
        self.bag_list = []
        for slide_data in dataset:
            slideID = slide_data[0]
            label = slide_data[1]
            # CSVファイルからpatchの座標リスト取得
            pos = np.loadtxt(f'{DATA_PATH}/csv/{slideID}.csv', delimiter=',', dtype='int')
            if not train:
                np.random.seed(seed=int(slideID.replace('SLIDE_','')))
            np.random.shuffle(pos)
            if pos.shape[0] > bag_num*bag_size:
                pos = pos[0:(bag_num*bag_size),:]
                for i in range(bag_num):
                    patches = pos[i*bag_size:(i+1)*bag_size,:].tolist()
                    self.bag_list.append([patches, slideID, label])
            else:
                for i in range(pos.shape[0]//bag_size):
                    patches = pos[i*bag_size:(i+1)*bag_size,:].tolist()
                    self.bag_list.append([patches, slideID, label])

        random.shuffle(self.bag_list)
        self.data_num = len(self.bag_list)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        pos_list = self.bag_list[idx][0]
        patch_len = len(pos_list)
        b_size = 224
        svs_list = os.listdir(f'{DATA_PATH}/svs')
        svs_fn = [s for s in svs_list if self.bag_list[idx][1] in s]
        svs = openslide.OpenSlide(f'{DATA_PATH}/svs/{svs_fn[0]}')
        bag = torch.empty(patch_len, 3, 224, 224, dtype=torch.float)

        # 画像読み込み
        i = 0
        for pos in pos_list:
            if self.transform:
                img = svs.read_region((pos[0],pos[1]),0,(b_size,b_size)).convert('RGB')
                img = self.transform(img)
                bag[i] = img
            i += 1

        label = self.bag_list[idx][2]
        label = torch.LongTensor([label])

        # バッグ, ラベルを返す
        return bag, label

4.2 マルチプロセス対応した学習用ソースコード

本節では, MILモデルやDatasetなどのマルチプロセス化するソースコードを掲載します

まず, モデルのマルチプロセス化について説明します
今回使用するMILモデルをmodelとすると, 以下のようにしてマルチプロセス化します

from torch.nn.parallel import DistributedDataParallel as DDP

model = model.to(rank)
process_group = torch.distributed.new_group([i for i in range(world_size)])
# modelのBatchNormをSyncBatchNormに変更してくれる
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)
# modelをmulti GPU対応させる
ddp_model = DDP(model, device_ids=[rank])

ここで, rankというのがそれぞれのGPUのことを指し, world_sizeプロセス数を指します
重要なのはmulti GPU対応した新たなモデルddp_modelを生成する点にあります
以降, モデルはこのddp_modelを用いるので注意が必要です

次に, Datasetクラスのマルチプロセス化です
特に重要となるのはdeta_sampler = torch.utils.data.distributed.DistributedSampler(original_dataset, rank=rank)
によって, 各GPUにデータセットを割り当てるdeta_samplerを作成している点です
このdeta_samplerDataLoaderの引数samplerに渡してやれば各GPUに対して割り当てられたデータが読み込まれるようになります
なお, DataLoaderの引数shufflesamplerと併用不可なので必ずFalseとしてください
また, 引数num_workersを2以上に設定すると読み込み時に使用できるCPU数が増え, 高速化が期待できます(3以上に設定すると多少高速になるが2の時と大差ない)

# Train bag作成(epochごとにbag再構築)
original_dataset = OriginalDataset(
    transform=transform,
    dataset=train_dataset,
    train=True
)

# Datasetをmulti GPU対応させる
deta_sampler = torch.utils.data.distributed.DistributedSampler(original_dataset, rank=rank)

# batch_sizeで設定した個数だけbagを各GPUに分配
train_loader = torch.utils.data.DataLoader(
    original_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=False,
    num_workers=2,
    sampler=deta_sampler
)

以上の処理をまとめたプログラムを掲載します

MIL_train.py
    # model読み込み
    from MIL import feature_extractor, class_predictor, MIL

    # 特徴抽出器の作成 (今回はResNet50を使用)
    encoder = models.resnet50(pretrained=True)
    encoder.fc = nn.Identity()
    for p in encoder.parameters():
        p.required_grad = True

    # MILモデルの構築
    feature_ex = feature_extractor(encoder)
    class_pred = class_predictor()
    model = MIL(feature_ex, class_pred)
    model = model.to(rank)
    process_group = torch.distributed.new_group([i for i in range(world_size)])
    # modelのBatchNormをSyncBatchNormに変更してくれる
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)
    # modelをmulti GPU対応させる
    ddp_model = DDP(model, device_ids=[rank])

    # クロスエントロピー損失関数使用
    loss_fn = nn.CrossEntropyLoss()
    # SGDmomentum法使用
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)

    # 前処理
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor()
    ])

    # Train
    for epoch in range(EPOCHS):
        # Train bag作成(epochごとにbag再構築)
        original_dataset = OriginalDataset(
            transform=transform,
            dataset=train_dataset,
            train=True
        )

        # Datasetをmulti GPU対応させる
        deta_sampler = torch.utils.data.distributed.DistributedSampler(original_dataset, rank=rank)

        # batch_sizeで設定した個数だけbagを各GPUに分配
        train_loader = torch.utils.data.DataLoader(
            original_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=False,
            num_workers=2,
            sampler=deta_sampler
        )

        # 学習
        class_loss, acc = train(ddp_model, rank, loss_fn, optimizer, train_loader)

4.3 Bag単位での学習の実装

4.1で作成した自作データセットクラスを用いて学習を行うわけですが, DataLoaderによる読み込み後にbagとbatchの概念ギャップを吸収するために実装でひと工夫必要となります
本節ではその部分の実装を掲載します

以下のプログラムにおいて, train_loaderからは取り出されるinput_tensor[bag, patch, channel, width, hight]のような5次元構造となっています
通常のbatch学習であれば, input_tensor[batch, channel, width, hight]のような4次元構造のはずです.
なので, ここの次元数のズレにbagとbatchの概念ギャップが顕在化します.
解決は至って簡単で, input_tensorから一つずつbagを取り出してやれば良いのです

MIL_train.py
def train(model, rank, loss_fn, optimizer, train_loader):
    model.train()
    train_class_loss = 0.0
    correct_num = 0

    for (input_tensor, class_label) in train_loader:
        # Bagとbatchのギャップを吸収して学習
        for bag_num in range(input_tensor.shape[0]):
            bag = input_tensor[bag_num].to(rank, non_blocking=True)
            bag_class_label = class_label[bag_num].to(rank, non_blocking=True)
            class_prob, class_hat = model(bag)
            class_loss = loss_fn(class_prob, bag_class_label)
            train_class_loss += class_loss.item()
            optimizer.zero_grad()
            class_loss.backward()
            optimizer.step() 
            correct_num += eval_ans(class_hat, bag_class_label)

    train_class_loss = train_class_loss / len(train_loader)
    train_acc = correct_num / len(train_loader)

    return train_class_loss, train_acc

4.4 その他

MILをPyTorchでマルチプロセス化する際に重要な部分は大体上述したが, 必要なことはまだあるので, 具体的にはここを参照してください

5 まとめ

  • PyTorchのマルチプロセス化を使ってMILをGPU並列化
  • 画像読み込みがボトルネックなため, 大体GPUの枚数倍高速化可能(8GPUまで確認済)
  • 出力は各GPUで吐かれるので, あとでそれらをシングルプロセスで集計する工夫は必要