Kerasのbackendを直接使う その2


概要

Kerasのbackend関数だけを使って、XOR（排他的論理和）を学習させ、分類してみた。

写真

環境

Raspberry Pi 3 model B v1.2 element14
2017-09-07-raspbian-stretch
tensorflow-1.3

サンプルコード

from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.optimizers import SGD
import numpy as np
import matplotlib.pyplot as plt

# XOR truth table: the input pair dx[i] has the target label dy[i].
dx = [
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
]
dy = [0.0, 1.0, 1.0, 0.0]
# Network layout: input_dim inputs -> hidden_dim tanh units -> output_dim tanh output.
input_dim = 2
output_dim = 1
hidden_dim = 8
# Symbolic inputs: a batch of feature rows and the matching target rows.
x = K.placeholder(shape=(None, input_dim), name="x")
ytrue = K.placeholder(shape=(None, output_dim), name="y")
# Trainable weights and biases, initialised uniformly between 0 and 1.
W1 = K.random_uniform_variable((input_dim, hidden_dim), 0, 1, name="W1")
W2 = K.random_uniform_variable((hidden_dim, output_dim), 0, 1, name="W2")
b1 = K.random_uniform_variable((hidden_dim, ), 0, 1, name="b1")
b2 = K.random_uniform_variable((output_dim, ), 0, 1, name="b2")
params = [W1, b1, W2, b2]
hidden = K.tanh(K.dot(x, W1) + b1)
ypred = K.tanh(K.dot(hidden, W2) + b2)
# Mean squared error reduced over *all* axes to a single scalar. Reducing only
# over axis=-1 would leave one loss per sample, and differentiating that
# vector implicitly sums over the batch; a scalar mean keeps the gradient a
# batch average regardless of how many rows are fed.
loss = K.mean(K.square(ypred - ytrue))
opt = SGD()
# NOTE(review): argument order (params, constraints, loss) matches the
# Keras 2.0-era API bundled with TF 1.3; later Keras uses get_updates(loss, params).
updates = opt.get_updates(params, [], loss)
# Each call evaluates the loss for the fed batch and applies one SGD update.
train = K.function(inputs=[x, ytrue], outputs=[loss], updates=updates)
# Online training: one SGD step per sample, 1000 passes over the dataset.
for ep in range(1000):
    epoch_loss = 0.0
    for xi, yi in zip(dx, dy):
        # Feed a batch of one: a (1, input_dim) row and a (1, 1) target.
        outs = train([[xi], [[yi]]])
        # np.mean works whether the loss comes back as a scalar or a 1-element array.
        epoch_loss += float(np.mean(outs[0]))
    if ep % 100 == 0:
        # Report the average over the whole epoch, not just the last sample's loss.
        print(ep, epoch_loss / len(dx))
# Inference-only function: maps a batch of inputs to network outputs.
pred = K.function(inputs=[x], outputs=[ypred])

# Split the training points by true label: below 0.5 (class 0) vs the rest (class 1).
xps = [pt[0] for pt, lab in zip(dx, dy) if lab < 0.5]
yps = [pt[1] for pt, lab in zip(dx, dy) if lab < 0.5]
xns = [pt[0] for pt, lab in zip(dx, dy) if lab >= 0.5]
yns = [pt[1] for pt, lab in zip(dx, dy) if lab >= 0.5]

# Evaluation grid: ticks x ticks points starting at -0.5 in steps of 2/ticks
# along each axis (x-major order, same as a nested ix/iy loop).
ticks = 40
p = np.array([[ix * 2.0 / ticks - 0.5, iy * 2.0 / ticks - 0.5]
              for ix in range(ticks) for iy in range(ticks)])
# One batched forward pass instead of ticks**2 separate calls; flatten the
# (N, 1) output to one prediction per grid point.
r = np.asarray(pred([p])[0]).reshape(-1)

thresh = 0.5
# Two explicit masks (rather than one mask and its negation) so that NaN
# predictions fall in neither region.
below = r < thresh
above = r >= thresh
plt.scatter(p[below, 0], p[below, 1], c="cyan", marker=".")
plt.scatter(p[above, 0], p[above, 1], c="magenta", marker=".")
# Overlay the training points: class 0 in blue, class 1 in larger red markers.
plt.scatter(xps, yps, c="blue", marker="o", s=100)
plt.scatter(xns, yns, c="red", marker="o", s=200)
plt.savefig("backend11.png")
plt.show()



以上。