Pytorchプリトレーニングモデルのダウンロードとロード(VGGを例に)カスタムパス

34136 ワード

簡単に述べる
一般的に、Pytorchがtorchvisionでvggなどのモデルを呼び出すと、コンピュータがcache(Pytorchハードコーディングのアドレス)に(環境変数にTORCH_HOMETORCH_MODEL_ZOOが追加されている場合、この2つの位置の連合経路の下で、例えばTORCH_MODEL_ZOO\modelである場合)、そうでなければTORCH_HOME\modelsまたは~/.torch/modelsである
例えば、私のはC:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pthです.
これは、私たちが望んでいるダウンロードモデルのアドレスではなく、このようなダウンロード方法が遅いなどの可能性があります.
また、このアドレスは簡単に直接呼び出すことができず、非常に不便です.
この点、私が今pytorchバージョンを使っているのかgithubの最新バージョンを使っているのか、似たような改善はしていません.
しかし、このデザインは(私のような強迫症にとって)、需要があるかもしれません.
解決策
まず、ダウンロードの問題を処理します.
ソースコードを読んで、import torch.utils.model_zoo as model_zooの関数を使ってデータをロードします.ソースコードに含まれるこの部分を整理しました
from urllib.parse import urlparse
import torch.utils.model_zoo as model_zoo
import re
import os
def download_model(url, dst_path):
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    
    HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
    hash_prefix = HASH_REGEX.search(filename).group(1)
    
    model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
    return filename

呼び出しインスタンス
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

import os

path = 'D:/Software/DataSet/models/vgg'
if not (os.path.exists(path)):
    os.makedirs(path)
for url in model_urls.values():
    download_model(url, path)

しゅつりょく
100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14<00:00, 7114218.15it/s]
100%|███████████████████████████████████████████████████████████████| 532194478/532194478 [02:46<00:00, 3193007.52it/s]
100%|███████████████████████████████████████████████████████████████| 553433881/553433881 [01:13<00:00, 7536750.60it/s]
100%|██████████████████████████████████████████████████████████████| 574673361/574673361 [00:54<00:00, 10587712.79it/s]
100%|███████████████████████████████████████████████████████████████| 531503671/531503671 [01:10<00:00, 7548305.64it/s]
100%|███████████████████████████████████████████████████████████████| 532246301/532246301 [01:35<00:00, 5598996.73it/s]
100%|██████████████████████████████████████████████████████████████| 553507836/553507836 [00:50<00:00, 10900603.60it/s]
100%|███████████████████████████████████████████████████████████████| 574769405/574769405 [01:11<00:00, 8023263.07it/s]

他のモデルアドレスは、githubの中の対応するモデルのコードを開くことができ、開くと見えます.
https://github.com/pytorch/vision/tree/master/torchvision/models

ロードを見てみましょう
import glob
import os
def load_model(model_name, model_dir):
    model  = eval('models.%s(init_weights=False)' % model_name)
    path_format = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)
    
    model_path = glob.glob(path_format)[0]
    
    model.load_state_dict(torch.load(model_path))
    return model

使用例:
model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)

出力:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace)
    (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)