posenetで野球のストライクゾーンを算出する


以前、OpenPoseは使ってみたことがあるのですが、GPUが必要な関係で自分の環境ではgoogle colaboratoryの中でしか動かせなかった記憶があります。
最近どうなってるかな...と思ってのぞいてみたらPoseNetが出たということでちょっと触ってみました。

参考記事

こちらの記事に詳しい解説が載っています。

Pythonで実装する

環境構築の注意点

  • デモコードがtensorflow 2に非対応
  • tensorflowはpython3.7以降でインストール不可
conda install tensorflow-gpu scipy pyyaml python=3.6
pip install opencv-python==3.4.5.20

実装

PoseNet-Pythonのwebcamのデモにストライクゾーンの描画処理と描画したものを保存する処理、エラーハンドリングを加えています。

python webcam_demo.py --file (動画ファイル名)
で動画を解析できます。

GitHub

import tensorflow as tf
import cv2
import time
import argparse

import posenet

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=int, default=101)
parser.add_argument('--cam_id', type=int, default=0)
parser.add_argument('--cam_width', type=int, default=1280)
parser.add_argument('--cam_height', type=int, default=720)
parser.add_argument('--scale_factor', type=float, default=0.7125)
parser.add_argument('--output_stride', type=float, default=8)
parser.add_argument('--file', type=str, default=None, help="Optionally use a video file instead of a live camera")
args = parser.parse_args()

outputFile="output.mp4"
video = cv2.VideoCapture(args.file)
outFourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
W = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
H = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
out = cv2.VideoWriter(outputFile, outFourcc, 30.0,
                          (W, H))  # 出力先のファイルを開く


def main():
    with tf.Session() as sess:
        model_cfg, model_outputs = posenet.load_model(args.model, sess)
        output_stride = model_cfg['output_stride']
        #print(model_cfg)

        if args.file is not None:
            cap = cv2.VideoCapture(args.file)
        else:
            cap = cv2.VideoCapture(args.cam_id)
        cap.set(3, args.cam_width)
        cap.set(4, args.cam_height)

        start = time.time()
        frame_count = 0
        while True:
            try:
                input_image, display_image, output_scale = posenet.read_cap(
                cap, scale_factor=args.scale_factor, output_stride=output_stride)

                heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = sess.run(
                    model_outputs,
                    feed_dict={'image:0': input_image}
                )

                pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multi.decode_multiple_poses(
                    heatmaps_result.squeeze(axis=0),
                    offsets_result.squeeze(axis=0),
                    displacement_fwd_result.squeeze(axis=0),
                    displacement_bwd_result.squeeze(axis=0),
                    output_stride=output_stride,
                    max_pose_detections=10,
                    min_pose_score=0.15)

                keypoint_coords *= output_scale
                

                # TODO this isn't particularly fast, use GL for drawing and display someday...
                overlay_image = posenet.draw_skel_and_kp(
                    display_image, pose_scores, keypoint_scores, keypoint_coords,
                    min_pose_score=0.15, min_part_score=0.1)
                overlay_image = cv2.rectangle(
                    overlay_image,
                    (
                        775,
                        int((keypoint_coords[0,5,0]/2 + keypoint_coords[0,11,0]/2))
                    ),
                    (
                        825,
                        int(keypoint_coords[0,14,0])
                    ),
                    (255,255,255),5
                )

                cv2.imshow('posenet', overlay_image)
                frame_count += 1
                out.write(overlay_image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            except IOError as e:
                #終了処理
                out.write(overlay_image)
                out.release()
                video.release()
                break

            
        print('Average FPS: ', frame_count / (time.time() - start))
        out.release()
        video.release()

if __name__ == "__main__":
    main()

フロントで実装を試してみる(React)

お断りしておくと失敗例です。

  • 再描画が多く処理が追い付かない
  • 検出した点が上手くマッピングできない

フロントで処理するのは無謀でした。

import { useCallback, useEffect, useRef, useState } from "react";
import { Player } from "video-react";
import * as tfjs from "@tensorflow/tfjs";
import "@tensorflow/tfjs-backend-cpu";
const posenet = require("@tensorflow-models/posenet");

type Person={
  score:number,
  keypoints:KeyPoint[]
}

type KeyPoint={
  score:number,
  part:string,
  position:{
    x:number,
    y:number
  }
}

const CalcPose = () => {
  const [fileURL, setFileURL] = useState(null);
  const playerRef = useRef(null);
  const canvasRef = useRef(null);
  const [poses, setPoses] = useState<Person[]>(null);
  const [taskId, setTaskId] = useState(null)

  useEffect(() => {
    const f = async () => {
      if (playerRef.current) {
        const net = await posenet.load();
        // estimate poses
        let poses = [];
        if(fileURL){
          const task_id = setInterval(async()=>{
            const tmpPoses = await net.estimatePoses(playerRef.current, {
              flipHorizontal: false,
              maxDetections: 5,
              scoreThreshold: 0.6,
              inputResolution: { width: 1280, height: 720 },
              //nmsRadius: 20,
            });
          console.log(tmpPoses)
          setPoses(tmpPoses)
          },1000)
          setTaskId(task_id)
        }
        //return poses;
      }
    };
    f();
  }, [playerRef, fileURL]);

  return (
    <div
      style={{
        position:'relative'
      }}
    >
      <video
        id="myvideo"
        ref={playerRef}
        autoPlay
        muted
        src={fileURL}
        width={"95%"}
        onEnded={()=>{
          clearInterval(taskId)
        }}
      ></video>
      <div>
        {poses? 
          poses.map((person)=>(
          person.keypoints.map((keypoint)=>(
            <div
              style={{
                position:'absolute',
                left:keypoint.position.x,
                top:keypoint.position.y,
                color:'white',
                fontSize:14
              }}
            ></div>
            
          ))
        ))
        :<></>}
      </div>
      <input
        onChange={(e) => {
          if (e.target.files.length > 0) {
            const url = URL.createObjectURL(e.target.files[0]);
            setFileURL(url);
          }
        }}
        type="file"
        multiple={false}
        accept=".mp4"
      />
    </div>
  );
};
export default CalcPose;

出力