keras_to_tensorflow その2


概要

kerasのモデルをtensorflowに変換して、deeplearn.jsで使ってみた。
手順を記載する。

環境

windows 7 sp1 64bit
anaconda3
tensorflow 1.2

kerasのモデルをtensorflowに変換したpbファイルから、manifest.jsonと変数（vars）のファイルを作るコード。

import json
import os
import os.path
import string

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile

FILENAME_CHARS = string.ascii_letters + string.digits + '_'
def _var_name_to_filename(var_name):
    chars = []
    for c in var_name:
        if c in FILENAME_CHARS:
            chars.append(c)
        elif c == '/':
            chars.append('_')
    return ''.join(chars)

chk_fpath = "./"
output_dir = "./deep"
tf.gfile.MakeDirs(output_dir)
manifest = {}
var_filenames_strs = []

with gfile.FastGFile("keras13.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name = '')
    for n0 in graph_def.node:
        print (n0.op)
        if n0.op == 'Const':
            print (n0.name)
            tensor = n0.attr['value'].tensor
            size = len(tensor.tensor_content)
            print (size)
            size2 = len(tensor.tensor_shape.dim)
            #print(tensor.tensor_shape.dim[0].size)
            #print(tensor.tensor_shape.dim[1].size)
            if size > 0:
                name = n0.name
                var_filename = _var_name_to_filename(name)
                if size2 > 1:
                    manifest[name] = {
                        'filename': var_filename,
                        'shape': [tensor.tensor_shape.dim[0].size, tensor.tensor_shape.dim[1].size]
                    }
                else:
                    manifest[name] = {
                        'filename': var_filename,
                        'shape': [tensor.tensor_shape.dim[0].size]
                    }
                print ('Writing variable ' + name + '...')
                with open(os.path.join(output_dir, var_filename), 'wb') as f:
                    f.write(tensor.tensor_content)
                var_filenames_strs.append("\"" + var_filename + "\"")
    manifest_fpath = os.path.join(output_dir, 'manifest.json')
    print ('Writing manifest to ' + manifest_fpath)
    with open(manifest_fpath, 'w') as f:
        f.write(json.dumps(manifest, indent = 2, sort_keys = True))
    print ("ok")



jsdo.itにファイルをアップロード

以下をアップする。

dense_1_bias
dense_1_kernel
dense_2_bias
dense_2_kernel
dense_3_kernel
manifest.json

モデルを作る。

kerasのモデル

# Regression MLP: 1 input -> Dense(30)+sigmoid -> Dense(10)+sigmoid -> Dense(1).
# Layers are created in the same order as the summary below, so they get the
# names dense_1 .. dense_3 that the deeplearn.js code looks up.
inputs = Input(shape=(1,))
x = Dense(30)(inputs)
x = Activation('sigmoid')(x)
x = Dense(10)(x)
x = Activation('sigmoid')(x)
outputs = Dense(1)(x)
model = Model(inputs, outputs)
model.compile(loss='mean_squared_error', optimizer=SGD(lr=0.1))
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 1)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 30)                60        
_________________________________________________________________
activation_1 (Activation)    (None, 30)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                310       
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 11        
=================================================================
Total params: 381
Trainable params: 381
Non-trainable params: 0
_________________________________________________________________

deeplearn.jsのモデル

    var input = g.placeholder('input', [1]);
    var hidden1W = g.constant(vars['dense_1/kernel']);
    var hidden1B = g.constant(vars['dense_1/bias']);
    var hidden1 = g.sigmoid(g.add(g.matmul(input, hidden1W), hidden1B));
    var hidden2W = g.constant(vars['dense_2/kernel']);
    var hidden2B = g.constant(vars['dense_2/bias']);
    var hidden2 = g.sigmoid(g.add(g.matmul(hidden1, hidden2W), hidden2B));
    var hidden3W = g.constant(vars['dense_3/kernel']);
    var hidden3B = g.constant(vars['dense_3/bias']);
    var logits = g.matmul(hidden2, hidden3W);
    return [input, logits];

写真

サンプルコード

deeplearn.js

/**
 * Loads a manifest and the raw float32 variable files it lists.
 * @param {string} urlPath  Base URL; a trailing '/' is appended when missing.
 */
function CheckpointLoader(urlPath) {
    var base = urlPath;
    if (base.slice(-1) !== '/') {
        base = base + '/';
    }
    this.urlPath = base;
}
/**
 * Fetch urlPath + MANIFEST_FILE, parse it as JSON and cache it on
 * this.checkpointManifest. Resolves with no value. Rejects (after the alert)
 * on network errors, HTTP error statuses or invalid JSON, so callers do not
 * wait forever on a failed download.
 */
CheckpointLoader.prototype.loadManifest = function() {
    var _this = this;
    return new Promise(function(resolve, reject) {
        var xhr = new XMLHttpRequest();
        xhr.open('GET', _this.urlPath + MANIFEST_FILE);
        xhr.onload = function() {
            // onload also fires for 404/500; treat those as failures.
            if (xhr.status >= 400) {
                var msg = MANIFEST_FILE + " not found at " + _this.urlPath + " (HTTP " + xhr.status + ").";
                alert(msg);
                reject(new Error(msg));
                return;
            }
            try {
                _this.checkpointManifest = JSON.parse(xhr.responseText);
            } catch (e) {
                reject(e);
                return;
            }
            resolve();
        };
        xhr.onerror = function(error) {
            alert(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + error);
            reject(error);
        };
        xhr.send();
    });
};
/**
 * Resolve with the parsed manifest, downloading it on first use.
 * Chaining on loadManifest() (rather than wrapping it in a new Promise)
 * lets a failed download reject the returned promise instead of leaving
 * it pending forever.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
    var _this = this;
    if (this.checkpointManifest == null) {
        return this.loadManifest().then(function() {
            return _this.checkpointManifest;
        });
    }
    return Promise.resolve(this.checkpointManifest);
};
/**
 * Resolve with an object mapping every manifest variable name to its NDArray.
 * The result is cached on this.variables. Any manifest or variable download
 * failure now rejects the returned promise.
 */
CheckpointLoader.prototype.getAllVariables = function() {
    var _this = this;
    if (this.variables != null) {
        return Promise.resolve(this.variables);
    }
    return this.getCheckpointManifest().then(function(manifest) {
        var variableNames = Object.keys(manifest);
        var requests = variableNames.map(function(name) {
            return _this.getVariable(name);
        });
        return Promise.all(requests).then(function(values) {
            var variables = {};
            for (var i = 0; i < values.length; i++) {
                variables[variableNames[i]] = values[i];
            }
            _this.variables = variables;
            return variables;
        });
    });
};
/**
 * Download one variable file and resolve with it as a deeplearn.js NDArray
 * (Float32Array values, shape from the manifest). The manifest is loaded
 * first if needed.
 *
 * The existence check runs only once the manifest is available. Previously
 * it evaluated `varName in null` and threw before the lazy-load branch could
 * run. Unknown names, HTTP errors and network errors now reject instead of
 * crashing later or hanging.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
    var _this = this;
    var variableRequestPromiseMethod = function(resolve, reject) {
        if (!(varName in _this.checkpointManifest)) {
            var missing = 'Cannot load non-existant variable ' + varName;
            alert(missing);
            reject(new Error(missing));
            return;
        }
        var entry = _this.checkpointManifest[varName];
        var xhr = new XMLHttpRequest();
        xhr.responseType = 'arraybuffer';
        xhr.open('GET', _this.urlPath + entry.filename);
        xhr.onload = function() {
            if (xhr.status >= 400) {
                var failed = 'Could not fetch variable ' + varName + ': HTTP ' + xhr.status;
                alert(failed);
                reject(new Error(failed));
                return;
            }
            try {
                var values = new Float32Array(xhr.response);
                resolve(dl.NDArray.make(entry.shape, {
                    values: values
                }));
            } catch (e) {
                // e.g. byte length does not match the manifest shape
                reject(e);
            }
        };
        xhr.onerror = function(error) {
            alert('Could not fetch variable ' + varName + ': ' + error);
            reject(error);
        };
        xhr.send();
    };
    if (this.checkpointManifest == null) {
        return this.loadManifest().then(function() {
            return new Promise(variableRequestPromiseMethod);
        });
    }
    return new Promise(variableRequestPromiseMethod);
};


// Demo globals, in dependency order: manifest path (appended to the loader's
// base URL), deeplearn.js namespace alias, the graph the model is built into,
// the CPU math backend, and a slot that later receives the loaded variables.
var MANIFEST_FILE = '/assets/g/m/b/V/gmbVz',
    dl = deeplearn,
    g = new dl.Graph(),
    math = new dl.NDArrayMathCPU(),
    vars2;
/**
 * Rebuild the Keras MLP in the global deeplearn.js graph `g`:
 * sigmoid(x W1 + b1) -> sigmoid(h1 W2 + b2) -> h2 W3 + b3.
 * @param vars  map from TF variable name (e.g. 'dense_1/kernel') to NDArray
 * @return [input placeholder, output tensor]
 */
function buildModelGraphAPI(vars) {
    var input = g.placeholder('input', [1]);
    var hidden1W = g.constant(vars['dense_1/kernel']);
    var hidden1B = g.constant(vars['dense_1/bias']);
    var hidden1 = g.sigmoid(g.add(g.matmul(input, hidden1W), hidden1B));
    var hidden2W = g.constant(vars['dense_2/kernel']);
    var hidden2B = g.constant(vars['dense_2/bias']);
    var hidden2 = g.sigmoid(g.add(g.matmul(hidden1, hidden2W), hidden2B));
    var hidden3W = g.constant(vars['dense_3/kernel']);
    var logits = g.matmul(hidden2, hidden3W);
    // Keras Dense(1) has a bias (dense_3: 11 params = 10 weights + 1 bias).
    // Add it whenever the manifest provides it; an exporter that only reads
    // tensor_content leaves this single-element tensor out.
    if (vars['dense_3/bias'] != null) {
        logits = g.add(logits, g.constant(vars['dense_3/bias']));
    }
    return [input, logits];
}
// Load all weights, build the graph, evaluate the network at x = t / 30 for
// t = 0..199, and plot the predictions.
var reader = new CheckpointLoader('http://jsrun.it');
reader.getAllVariables().then(function(vars) {
    vars2 = vars;  // exposed globally (presumably for console debugging)
    var model = buildModelGraphAPI(vars);
    var input = model[0];
    var output = model[1];
    var sess = new dl.Session(input.node.graph, math);
    math.scope(function() {
        var steps = 200;
        var predictions = new Float32Array(steps);
        for (var t = 0; t < steps; t++) {
            var inputData = dl.Array1D.new([t / 30]);
            var outputVal = sess.eval(output, [{
                tensor: input,
                data: inputData
            }]);
            // The output has one element; take it explicitly instead of
            // relying on Float32Array -> string -> number coercion.
            predictions[t] = outputVal.getValues()[0];
        }
        draw(predictions, 0);
    });
}).catch(function(err) {
    console.error(err);
});

var canvas = document.getElementById("canvas");
var ctx = canvas.getContext("2d");
function draw(data, n) {
    var hc = n * 100 + 150;
    ctx.strokeStyle = "#f00";
    ctx.lineWidth = 1;
    ctx.moveTo(0, hc);
    for (var i = 1; i < canvas.width; i++) 
    {
        ctx.lineTo(i, hc - data[i] * 30);
    }
    ctx.stroke();
}

成果物

以上。