自作のTransformの作成方法


参考ページ:Pytorch – torchvision で使える Transform まとめ

画像データの前処理に利用するtransformsでは、Lambda関数を渡すことでユーザ定義のTransformが作れる。

from torchvision import transforms
import  cv2
import matplotlib.pyplot as plt

img = cv2.imread("sample.jpeg")

plt.imshow(img)

def gray(img):
    """
    RGBに変換してグレースケール化
    """

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    return img

transform = transforms.Lambda(gray)

img_transformed = transform(img)
plt.imshow(img_transformed)

この処理をComposeでつなげれば、pytorchのtransformのpipelineに組み込むことができる。