torchvisionの事前トレーニング済みモデルを使う 画像分類


torchvisionのトレーニング済みモデルをダウンロードして使う方法です

トレーニングするのは面倒

トレーニング済みの画像モデルが簡単に使える

いろんな種類のモデルが使えます。わりと新しいモデルもあります。

方法

事前トレーニング済みモデルをダウンロードしてインスタンス化。

import torchvision.models as models

regnet_y_400mf = models.regnet_y_400mf(pretrained=True)

これだけでモデルができます。

入力画像をモデルの入力に合わせて前処理します。

224*224にリサイズ、ImageNetデータセットに合わせて正規化します。

import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
 transforms.Resize(224),
 transforms.ToTensor(),
 transforms.Normalize(
 mean=[0.485, 0.456, 0.406],
 std=[0.229, 0.224, 0.225]
 )])

img = Image.open("tabby.jpg")
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)

画像をモデルのフォワードプロセスにかけます。

out = regnet_y_400mf(batch_t)

出力は(1,1000)の数値です。ImageNet1000クラスに対応する信頼度です。
ソフトマックスで%値にまとめます。
クラスラベルをダウンロードし、トップを表示します。

_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

import urllib
label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
class_labels = urllib.request.urlopen(label_url).read().splitlines()
class_labels = class_labels[1:] # remove the first class which is background

print(class_labels[index[0]], percentage[index[0]].item())

'Egyptian cat' 66.66095733642578

使えるモデルは以下で確認できます。

🐣


フリーランスエンジニアです。
お仕事のご相談こちらまで
[email protected]

Core MLやARKitを使ったアプリを作っています。
機械学習/AR関連の情報を発信しています。

Twitter
Medium