エッジ推論の信頼度が低いときに画像データを自動でS3にアップするLambdaの作成


はじめに

ラズパイカメラで撮った画像を画像分類して、結果の信頼度が微妙だったときにS3バケットに自動でアップロードして再学習用の画像を簡単に増やす。今回はそんな仕組みの雛形をGreenGrassとLambdaで作ります。

なおこの記事は以下の記事の続きです。

今回作成するLambdaの全体像

Lambdaの構成図

以下がLambdaの構成図。

camera.pyとinferene.pyにコードを足していく

camera.pytest.jpgとしてカメラの画像を保存する行camera.capture('/home/pi/test.jpg')
だけ足します。

camera.py
#
# Copyright 2010-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# This class is a camera that uses picamera to take a photo and DLC compiled
# Resnet-50 model to perform image classification, identifying the objects
# shown in the photo.
#

from io import BytesIO
import picamera
import time
import datetime
import boto3

class Camera(object):
    r"""
    Camera that captures an image for performing inference
    with DLC compiled model.
    """

    def capture_image(self):
        r"""
        Capture image with PiCamera.
        """
        camera = picamera.PiCamera()
        imageData = BytesIO()

        try:
            camera.resolution = (224, 224)
            print("Taking a photo from your camera...")
            camera.start_preview()
            time.sleep(2)
            camera.capture(imageData, format = "jpeg", resize = (224, 224))
            camera.stop_preview()

            imageData.seek(0)

            # とりあえず/home/piに'test.jpg'という形で画像を保存
            camera.capture('/home/pi/test.jpg')

            return imageData
        finally:
            camera.close()

        raise RuntimeError("There is a problem with your camera.")

inference.pyは複数のモジュールのインポートの行と条件付け、S3アップロードの行を足してあります。

inference.py
#
# Copyright 2010-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Greengrass lambda function to perform Image Classification with example model
# Resnet-50 that was compiled by DLC.
#
#
import logging
import os

from dlr import DLRModel
from PIL import Image
import numpy as np

import greengrasssdk
import camera
import utils

# 追加したコードに必要なモジュール
import time
import datetime
import boto3


# Create MQTT client
mqtt_client = greengrasssdk.client('iot-data')

# Initialize logger
customer_logger = logging.getLogger(__name__)

# LambdaをGGコンテナ内で動かすとき
# model_resource_path = os.environ.get('MODEL_PATH', '/trained_models')

# LambdaをGGコンテナ無しで動かすとき
model_resource_path = os.getenv("AWS_GG_RESOURCE_PREFIX") + "/trained_models"

dlr_model = DLRModel(model_resource_path, 'cpu')


# Read synset file
synset_path = os.path.join(model_resource_path, 'synset.txt')
with open(synset_path, 'r') as f:
    synset = eval(f.read())


def predict(image_data):
    r"""
    Predict image with DLR. The result will be published
    to MQTT topic '/resnet-50/predictions'.

    :param image: numpy array of the Image inference with.
    """
    flattened_data = image_data.astype(np.float32).flatten()

    prediction_scores = dlr_model.run({'data' : flattened_data})
    max_score_id = np.argmax(prediction_scores)
    max_score = np.max(prediction_scores)

    # Prepare result
    predicted_class = synset[max_score_id]
    result = 'Inference result: "{}" with probability {}.'.format(predicted_class, max_score)

    # Send result
    send_mqtt_message(
        'Prediction Result: {}'.format(result))

    # ここで推論結果に対して条件付けをして、条件に引っ掛かったらS3バケットにアップロードする。ここでは信頼度が80%以下の時アップロード
    if max_score < 0.8:

        s3 = boto3.resource('s3')
        # あげる対象となるS3バケット
        bucketName = 'minagawabucket'

        send_mqtt_message("The probability isn't high enough, sending data to {}".format(bucketName))

        # アップロードする時のタイムスタンプをファイルの末尾につけたい        
        date = datetime.datetime.now() # date to be like datetime.datetime(2021, 6, 11, 17, 7, 8, 805672)
        timestamp = date.strftime("%Y-%m-%d %H:%M:%S") # timestamp to be like '2021-06-11 17:07:08'
        # S3バケットのPicturesForRetraining配下に推論結果のクラス(例えば猫や車)のフォルダを作ってそこに格納する
        s3Key = 'PicturesForRetraining/' + predicted_class + '/' + timestamp + '.jpg'
        # camera.pyで保存したtest.jpgをアップロードする
        data = open('/home/pi/test.jpg', mode='rb')
        s3.Bucket(bucketName).put_object(Key = s3Key, Body = data)

def predict_from_cam():
    r"""
    Predict with the photo taken from your pi camera.
    """
    send_mqtt_message("Taking a photo...")
    my_camera = camera.Camera()
    image = Image.open(my_camera.capture_image())
    image_data = utils.transform_image(image)
    send_mqtt_message("Start predicting...")
    predict(image_data)


def send_mqtt_message(message):
    r"""
    Publish message to the MQTT topic:
    '/resnet-50/predictions'.

    :param message: message to publish
    """
    mqtt_client.publish(topic='/resnet-50/predictions',
                        payload=message)


# The lambda to be invoked in Greengrass
def handler(event, context):
    try:
        predict_from_cam()
    except Exception as e:
        customer_logger.exception(e)
        send_mqtt_message(
            'Exception occurred during prediction. Please check logs for troubleshooting: /greengrass/ggc/var/log.')

テスト

テストしてみます。#にサブスクライブしてtestになんでも良いのでメッセージをパブリッシュ。

7%の確率でマッチ棒。自動的に指定したバケットの再学習用の画像フォルダに送ります。

無事送れていますね。

これでSageMaker GroundTruth等でのアノテーション作業や再学習もやりやすくなりました。

あとがき

・一度SDカードにファイルを書く処理を挟んでいるのはあまりスマートではない気がします。カメラで撮影した画像データをそのまま推論を実行する関数に渡してそのままS3にアップロードするようにしたい。SDカードの寿命のためになるのとそれだけ無駄な処理が減るので。

・会社から自宅のラズパイにMQTTでLambdaを起動してカメラ撮影をしたのですが、部屋が暗くて画像が真っ黒でした。現在使っているMLモデルが真っ黒だとマッチ棒と推論するらしく、結果は全てマッチ棒というちょっと微妙な結果になりました。あ、そうだSwitchbotで電気つければいいじゃんと思ったのですが、物理スイッチの方がOFFになっていたらしく反応なし。。Switchbotあるあるですね。