機械学習用にプロ野球の球団マスコット画像を集めるスクリプトを作ってみた


機械学習用(画像判定のデータ収集用)のデータ集め用に、google画像検索から画像ファイルを集めるスクリプトを書いてみたので、ここに残しておこうかと思います。

ちなみにseleniumを使用するための環境構築は以下のリンクなんかを参考にしています。
https://tanuhack.com/python/selenium/

from selenium import webdriver
from selenium.webdriver.common.keys import Keys
import os
import json
import urllib
import sys
import time
import io
from PIL import Image

text = [ "ドアラ","スラィリー","トラッキー","ジャビット", "つば九郎","スターマン" ]
download_path = ["doara/","sllighly/","toracky/","jabitto/","tubakuro/","starman/"]
headers = {}
headers['User-Agent'] = "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36"
extensions = {"jpg", "jpeg", "png", "gif"}
img_count = 0
downloaded_img_count = 0
img_skip = 0
dPath_index = 0

#画像を取得
for index in text:
    # ディレクトリを作成
    try:
        os.mkdir(download_path[dPath_index])
    except OSError:
        print("file exist")

    # Firefoxを起動
    browser = webdriver.Firefox()
    # Webブラウザを表示
    url = "https://www.google.co.in/search?q={}&source=lnms&tbm=isch".format(index)
    browser.get(url)
    #  画像を全て表示させる
    for __ in range(10):
        # スクロール
        browser.execute_script("window.scrollBy(0, 1000000)")
        time.sleep(0.2)
        try:
            # 「結果をもっと表示」があればクリック
            browser.find_element_by_xpath("//input[@value='結果をもっと表示']").click()
            time.sleep(2.5)
        except Exception as e:
            print("not found:"+ str(e))

    imges = browser.find_elements_by_xpath('//div[contains(@class,"rg_meta")]')

    print("Total images:"+ str(len(imges)) + "\n")
    for img in imges:
            # Get image
            img_count += 1
            img_url = json.loads(img.get_attribute('innerHTML'))["ou"]
            img_type = json.loads(img.get_attribute('innerHTML'))["ity"]

            print("Downloading image "+ str(img_count) + ": "+ img_url)
            try:
                if img_type not in extensions:
                    img_type = "jpg"
                # Download image and save it
                raw_img = io.BytesIO(urllib.urlopen(img_url).read())

                img = Image.open(raw_img)
                img.save( download_path[dPath_index]+'img{}.jpg'.format(img_count))
                time.sleep(0.2)
                downloaded_img_count += 1
            except Exception as e:
                print("Download failed:"+ str(e))
            finally:
                print("")

    # Webブラウザを一旦閉じる
    img_count = 0
    browser.close()
    dPath_index+= 1

参考文献

  • 以下のgitHubと記事を参考に作らせていただきました。