PyTorchのモデルをPruneしてProfileする - 推論の効率化の検証 -


なにこれ

  • PyTorchの枝刈り(Pruning)と分析(Profile)を紹介したい
  • Pruneしたモデルの効率化具合をProfileする

Prune

  • PRUNING TUTORIAL
    • 重みがsparseになって推論処理が軽くなることが期待できる
import torch
import torch.nn.utils.prune as prune
from torchvision import models
resnet18のconv1を使って確認する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=True).to(device)
module = model.conv1
実際のパラメータの変化の一部を見てみる
print(list(module.named_parameters())[0][1][0][0])
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.remove(module, "weight")
print(list(module.named_parameters())[0][1][0][0])
"""
tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       grad_fn=<SelectBackward>)
tensor([[-0.0000, -0.0000, -0.0000,  0.0748,  0.0566,  0.0171, -0.0000],
        [ 0.0000,  0.0000, -0.1099, -0.2805, -0.2712, -0.1291,  0.0000],
        [-0.0000,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0000,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0000, -0.0000, -0.0241, -0.0659, -0.1507, -0.0822, -0.0000]],
       grad_fn=<SelectBackward>)
"""
torch.nn.Conv2dとtorch.nn.Linearをpruningする
model = models.resnet18(pretrained=True).to(device) # ↑でconv1だけpruneしてあるのでリロード

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
        prune.remove(module, "weight")
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        prune.remove(module, "weight")
  • 注意:prune.removeしないと、forwardの際にpruneの結果を計算するhookがオーバーヘッドになってむしろ遅くなる場合も。(下記参照)

Profile

  • PYTORCH PROFILER(基本的な使用方法)
  • PROFILING YOUR PYTORCH MODULE(改善例)
    1. torch.floatが必要な処理に対してtorch.doubleからの変換をかませるとメモリ使用量が大きくなってしまう
    2. CUDAからCPUへのコピーやCUDA上でもできる処理をCPU上でわざわざ行うと処理時間が伸びる
  • PyTorch moduleがどれくらいのスピードで処理されるのかを確認できる
import torch.autograd.profiler as profiler
結果をexport_chrome_traceするとchrome
model = models.resnet18(pretrained=True).to(device) # ↑でpruneしたのでリロード
inputs = torch.randn(5, 3, 28, 28).to(device)

model(inputs) # warming up

use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
    with profiler.record_function("model_inference"):
            model(inputs)
prof.export_chrome_trace("before.json")
  • tracing(chrome://tracing)でGUI付きで分析できます

Profilingの結果をtracingで表示した例

PruneしてProfileするなら

  • 複数のexportをまとめる場合、pidを変えれば並べて比較できる
import json
pruneしてexport
model = models.resnet18(pretrained=True).to(device)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
        prune.remove(module, "weight")
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        prune.remove(module, "weight")
use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
    with profiler.record_function("model_inference"):
            model(inputs)
prof.export_chrome_trace("after.json")
  • prune.removeを忘れると遅くなることを確認したい
removeしなかった場合をexport
model = models.resnet18(pretrained=True).to(device)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
        # prune.remove(module, "weight")
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        # prune.remove(module, "weight")
use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
    with profiler.record_function("model_inference"):
            model(inputs)
prof.export_chrome_trace("with_hooks.json")
exportしたjsonを一つにまとめる
with open("before.json", "r") as f:
    before = json.load(f)
with open("with_hooks.json", "r") as f:
    with_hooks = json.load(f)
with open("after.json", "r") as f:
    after = json.load(f)

concat = []
for prof in before:
    if prof["pid"] == "CPU functions":
        prof["pid"] = "CPU (before)"
    elif prof["pid"] == "CUDA functions":
        prof["pid"] = "CUDA (before)"
    concat.append(prof)
for prof in with_hooks:
    if prof["pid"] == "CPU functions":
        prof["pid"] = "CPU (with_hooks)"
    elif prof["pid"] == "CUDA functions":
        prof["pid"] = "CUDA (with_hooks)"
    concat.append(prof)
for prof in after:
    if prof["pid"] == "CPU functions":
        prof["pid"] = "CPU (after)"
    elif prof["pid"] == "CUDA functions":
        prof["pid"] = "CUDA (after)"
    concat.append(prof)

with open("trace_all.json", "w") as f:
    json.dump(concat, f)
  • tracingで表示してみると
    trace_all

  • pruning前に49.601 msかかっていた処理は、

    • pruning後に49.485 msにできる。
    • ただし、prune.removeを忘れると55.591 msかかる
  • ちなみに、同じpruningを行った場合、resnet152であれば10%近い処理スピードの改善が見られた

    • 前:399.683 ms
    • 後:367.759 ms

まとめ

  • スカスカにすると精度が落ちることは忘れてはならない。早くなったことを喜んでいる場合じゃないかもしれない。
  • resnet18やresnet152のような小さなモデルであればこの程度だが、より巨大なモデルでのpruningの効果は大きい。精度を許容範囲内に維持しつつ疎にできるパラメータをどのように発見するかがキモでしょう。