Pytorchプリトレーニングモデルのダウンロードとロード(VGGを例に)カスタムパス
簡単に述べる
一般的に、Pytorchがtorchvisionでvggなどのモデルを呼び出すと、コンピュータがcache(Pytorchハードコーディングのアドレス)に(環境変数に
例えば、私のは
これは、私たちが望んでいるダウンロードモデルのアドレスではなく、このようなダウンロード方法が遅いなどの可能性があります.
また、このアドレスは簡単に直接呼び出すことができず、非常に不便です.
この点、私が今pytorchバージョンを使っているのかgithubの最新バージョンを使っているのか、似たような改善はしていません.
しかし、このデザインは(私のような強迫症にとって)、需要があるかもしれません.
解決策
まず、ダウンロードの問題を処理します.
ソースコードを読んで、
呼び出しインスタンス
しゅつりょく
他のモデルアドレスは、githubの中の対応するモデルのコードを開くことができ、開くと見えます.
ロードを見てみましょう
使用例:
出力:
一般的に、Pytorchがtorchvisionでvggなどのモデルを呼び出すと、コンピュータがcache(Pytorchハードコーディングのアドレス)に(環境変数に
TORCH_HOME
とTORCH_MODEL_ZOO
が追加されている場合、この2つの位置の連合経路の下で、例えばTORCH_MODEL_ZOO\model
である場合)、そうでなければTORCH_HOME\models
または~/.torch/models
である例えば、私のは
C:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pth
です.これは、私たちが望んでいるダウンロードモデルのアドレスではなく、このようなダウンロード方法が遅いなどの可能性があります.
また、このアドレスは簡単に直接呼び出すことができず、非常に不便です.
この点、私が今pytorchバージョンを使っているのかgithubの最新バージョンを使っているのか、似たような改善はしていません.
しかし、このデザインは(私のような強迫症にとって)、需要があるかもしれません.
解決策
まず、ダウンロードの問題を処理します.
ソースコードを読んで、
import torch.utils.model_zoo as model_zoo
の関数を使ってデータをロードします.ソースコードに含まれるこの部分を整理しましたfrom urllib.parse import urlparse
import torch.utils.model_zoo as model_zoo
import re
import os
def download_model(url, dst_path):
parts = urlparse(url)
filename = os.path.basename(parts.path)
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
hash_prefix = HASH_REGEX.search(filename).group(1)
model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
return filename
呼び出しインスタンス
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
import os
path = 'D:/Software/DataSet/models/vgg'
if not (os.path.exists(path)):
os.makedirs(path)
for url in model_urls.values():
download_model(url, path)
しゅつりょく
100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14<00:00, 7114218.15it/s]
100%|███████████████████████████████████████████████████████████████| 532194478/532194478 [02:46<00:00, 3193007.52it/s]
100%|███████████████████████████████████████████████████████████████| 553433881/553433881 [01:13<00:00, 7536750.60it/s]
100%|██████████████████████████████████████████████████████████████| 574673361/574673361 [00:54<00:00, 10587712.79it/s]
100%|███████████████████████████████████████████████████████████████| 531503671/531503671 [01:10<00:00, 7548305.64it/s]
100%|███████████████████████████████████████████████████████████████| 532246301/532246301 [01:35<00:00, 5598996.73it/s]
100%|██████████████████████████████████████████████████████████████| 553507836/553507836 [00:50<00:00, 10900603.60it/s]
100%|███████████████████████████████████████████████████████████████| 574769405/574769405 [01:11<00:00, 8023263.07it/s]
他のモデルアドレスは、githubの中の対応するモデルのコードを開くことができ、開くと見えます.
https://github.com/pytorch/vision/tree/master/torchvision/models
ロードを見てみましょう
import glob
import os
def load_model(model_name, model_dir):
model = eval('models.%s(init_weights=False)' % model_name)
path_format = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)
model_path = glob.glob(path_format)[0]
model.load_state_dict(torch.load(model_path))
return model
使用例:
model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)
出力:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU(inplace)
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace)
(8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace)
(10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(12): ReLU(inplace)
(13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(14): ReLU(inplace)
(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace)
(18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(19): ReLU(inplace)
(20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)