【MediaPipe】手話の数字を読み取ってみた(1~10まで)


mediapipeで日本手話の数字を読み取ってみた ※機械学習モデル使用

自作モデルをmediapipeに追加して日本手話の数字を読み取ってみました。ご参考になれば幸いです。
↓youtubeで動作をご紹介

環境と結果概要

  • python:3.8.5
  • mediapipe:0.8.2
  • mac:Catalina10.15.7、MacBook Pro
  • opencv-python:4.4.0
  • numpy:1.19.3
  • Pillow:8.0.1

参考にさせて頂いたページ・資料

機能概要

  • tfliteモデルを使用して日本手話を1−10で分類
  • 手型の分類結果を日本語で表示
  • 手の画像を切り出して保存(rで保存開始、sで保存ストップ)

結果

  • 日本語を表示するためにはpillowが必要。outputをpillowで出力すると画像windowが個別に立ち上がってきて、画面が埋め尽くされる。opencvのように同じwindowの表示を更新し続けるような方法を見つけられなかった。opencv→pillow→opencvという処理になっている。
  • 精度がいまいち。1−4の分類は5−10に比べて苦手な感じ。画像追加で改善するが、手法としてイマイチな感じ。
  • モデルで分類する前に、切り出した手の画像を回転させて、手の角度を固定化(mediapipeがtfliteモデルに画像を渡す前に行っている処理)して「手型+手型の向き(上下左右)」で判定する形に全体的に組み直したほうが精度が出そうな感じ(未検証)
  • 画像ではなくmediapipeのランドマークの数字から手型を分類したほうが精度が出るかもしれない(参考ページの方法、未検証)
  • 手話の指文字を認識させる際は同じ手型で向きが違うものがたくさん出てくるので、今回の結果は中途半端だが、拘らずに未検証の手法に手をつける予定、コードもやっつけではあるが備忘のため記載

コード

import cv2 as cv
import numpy as np
import mediapipe as mp
mp_drawing = mp.solutions.drawing_utils
mp_hands = mp.solutions.hands

from PIL import Image, ImageDraw, ImageFont
import datetime
import copy
import cvFpsCalc  #参考ページからFPS測定を拝借
import one_to_ten_number_recognition  #自作モデルに画像を渡すと分類結果が帰ってくる


base_dir = '適当なディレクトリ '
image_dir = base_dir + 'images/' #手の画像を保存する際のパス
tflite_model = base_dir + '適当なディレクトリ/jsl_one_to_ten.tflite'
interpriter = one_to_ten_number_recognition.load_model(tflite_model)
classes = ['1', '10', '2', '3', '4', '5', '6', '7', '8', '9']
im_rec_f = 0  #手の画像を保存するかどうかのフラグ
pred_str = ""


def calc_bounding_rect(image, landmarks):
  image_width, image_height = image.shape[1], image.shape[0]
  landmark_array = np.empty((0, 2), int)

  for _, landmark in enumerate(landmarks.landmark):
    landmark_x = min(int(landmark.x * image_width), image_width - 1)
    landmark_y = min(int(landmark.y * image_height), image_height - 1)
    landmark_point = [np.array((landmark_x, landmark_y))]
    landmark_array = np.append(landmark_array, landmark_point, axis=0)

  x, y, w, h = cv.boundingRect(landmark_array)
  return [x - 50, y - 50, x + w + 100, y + h + 100]


def draw_info(Image, fps, pred):
  info_list = ["フレームレート: " + str(fps) \
              ,"予測: " + pred \
              ,]

  d = ImageDraw.Draw(Image)
  d.font = ImageFont.truetype('/System/Library/Fonts/ヒラギノ丸ゴ ProN W4.ttc', 30)
  bw = 1
  pos = [10,20]
  color_f = (255,255,255)
  color_b = (0,0,0)

  for info in info_list:
    d.text((pos[0]-bw,pos[1]-bw),info,color_b)
    d.text((pos[0]-bw,pos[1]+bw),info,color_b)
    d.text((pos[0]+bw,pos[1]-bw),info,color_b)
    d.text((pos[0]+bw,pos[1]+bw),info,color_b)
    d.text((pos[0],pos[1]),info,color_f)
    pos[1] = pos[1] + 30
  return(Image)


def get_hand_image(image, brect):
  hand_image = image[brect[1]:brect[3],brect[0]:brect[2]]
  try:
    hand_image = cv.cvtColor(hand_image,cv.COLOR_BGR2RGB)
    hand_image = Image.fromarray(hand_image)
    hand_image = hand_image.convert('RGB')
    hand_image = hand_image.resize((224,224))
  except:
    pass
  return hand_image


def seve_hand_image(hand_image):  #保存する画像量を減らしている
  if im_rec_f == 1 and int(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) % 6 == 0:
    image_path = image_dir  + datetime.datetime.now().strftime('%Y%m%d%H%M%S%f') + '.png'
    hand_image.save(image_path)


def get_pred(hand_image):
  pred = one_to_ten_number_recognition.image_classify(interpriter,classes,hand_image)
  return pred


#### main ####
# For webcam input:
cvFpsCalc = cvFpsCalc.CvFpsCalc(buffer_len=10)
hands = mp_hands.Hands(min_detection_confidence=0.5, min_tracking_confidence=0.5)
cap = cv.VideoCapture(0)

while cap.isOpened():
  try:  #手の画像がうまく取れなかった時に墜落する、乱暴に回避した
    fps = cvFpsCalc.get()
    success, image = cap.read()
    if not success:
      print("Ignoring empty camera frame.")
      # If loading a video, use 'break' instead of 'continue'.
      continue

    #input_image = image
    input_image = cv.flip(image, 1)
    output_image = copy.deepcopy(input_image)

    input_image = cv.cvtColor(input_image,cv.COLOR_BGR2RGB)
    image.flags.writeable = False

    #手型の判定頻度を減らしている
    if int(datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')) % 3 == 0:
      results = hands.process(input_image)

      # Draw the hand annotations on the image.
      if results.multi_hand_landmarks is not None:
        for hand_landmarks, handedness in zip(results.multi_hand_landmarks,results.multi_handedness):
          brect = calc_bounding_rect(output_image, hand_landmarks)
          hand_image = get_hand_image(output_image, brect)
          pred_str = get_pred(hand_image)
          seve_hand_image(hand_image)
    output_image = Image.fromarray(output_image)
    output_image = draw_info(output_image, fps, pred_str)
    output_image = np.array(output_image)
    cv.imshow('Hand Gesture Recognition', output_image)
  except:
    pass

  key = cv.waitKey(5) & 0xff
  if key == 27: break
  elif key == ord('r'): im_rec_f = 1
  elif key == ord('c'): im_rec_f = 0

hands.close()
cap.release()

以下、tfliteの処理は別ファイルにしました

import os
import sys
import numpy as np
import tensorflow as tf
import torch
import torchvision


def load_model(tflite_model):
  interpreter = tf.lite.Interpreter(model_path = tflite_model)
  interpreter.allocate_tensors()
  return(interpreter)

def image_classify(interpreter,classes,image):
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                          torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])

    input_tensor = trans(image)
    input_tensor = input_tensor[np.newaxis,:,:,:]
    interpreter.set_tensor(input_details[0]['index'], input_tensor)

    # 実行
    interpreter.invoke()
    output_convs = interpreter.get_tensor(output_details[0]['index'])
    output_convs = [f'{output_convs:.2f}' for output_convs in output_convs[0]]
    if float(output_convs[np.argmax(output_convs)]) >= 7.0:
      return('数字:' + classes[np.argmax(output_convs)] + ', 確信度: ' \
           + output_convs[np.argmax(output_convs)])
    else:
      return('確信度低:' + ', 推定 : ' + classes[np.argmax(output_convs)] + ' ,確信度: ' + output_convs[np.argmax(output_convs)])