TensorFlowユニットテストの簡単な例


TensorFlowのtf.test.TestCaseクラスはunittestを継承した.TestCaseクラスはtensorflowコードのユニットテストに使用されます.
tf.test.TestCaseはassertAlleEqualを提供し、2つのnumpy arrayが完全に同じ値を持っていることを判断し、sessionメソッドは計算図のノードを実行し、その他のメソッドを実行します.詳細はリンクを参照してください.
次の2つの関数があります.
# Python3
import tensorflow as tf


def dense_layer(x, W, bias, activation=None):
  y = x @ W + bias
  if activation:
    return activation(y)
  else:
    return y


def expand_reshape_tensor(x, high, width):
  return tf.reshape(x, (high, width, 1, 1))

第1の関数は全接続層であり,第2の関数はテンソルを拡張塑形操作するために用いられる.
次にUtilsTestsクラスを作成し、tfを継承します.test.TestCaseクラス、定義test_dense_Layerメソッドは、最初の関数をテストし、test_を定義します.expand_reshape_tensor法は2番目の関数をテストする.
import tensorflow as tf
import utils


class UtilsTests(tf.test.TestCase):

  def test_dense_layer(self):
    x = tf.reshape(tf.range(27), (9, 3)) - 13
    W = tf.reshape(tf.range(9), (3, 3)) - 4
    bias = tf.range(3) - 1
    y1 = utils.dense_layer(x, W, bias)
    y2 = utils.dense_layer(x, W, bias, tf.nn.relu)
    with self.session() as sess:
      y1, y2 = sess.run((y1, y2))

    self.assertAllEqual(
      [[-13, -12, -11],
       [-10,  -9,  -8],
       [ -7,  -6,  -5],
       [ -4,  -3,  -2],
       [ -1,   0,   1],
       [  2,   3,   4],
       [  5,   6,   7],
       [  8,   9,  10],
       [ 11,  12,  13]], x) #  x      

    self.assertAllEqual(
      [[-4, -3, -2],
       [-1,  0,  1],
       [ 2,  3,  4]], W) #  W      

    self.assertAllEqual([-1,  0,  1], bias) #   bias  

    self.assertAllEqual(
      [[ 41,   6, -29],
       [ 32,   6, -20],
       [ 23,   6, -11],
       [ 14,   6,  -2],
       [  5,   6,   7],
       [ -4,   6,  16],
       [-13,   6,  25],
       [-22,   6,  34],
       [-31,   6,  43]], y1) #           
    self.assertAllEqual(
      [[41,  6,  0],
       [32,  6,  0],
       [23,  6,  0],
       [14,  6,  0],
       [ 5,  6,  7],
       [ 0,  6, 16],
       [ 0,  6, 25],
       [ 0,  6, 34],
       [ 0,  6, 43]], y2) #           

  def test_expand_reshape_tensor(self):
    x = tf.range(9)
    y = utils.expand_reshape_tensor(x, 3, 3)
    shape = tf.shape(y)
    with self.session() as sess:
      shape = sess.run(shape)
    self.assertAllEqual(shape, (3, 3, 1, 1)) #         


if __name__ == "__main__":
  tf.test.main()  #       

次に,3つの関数をテストし,そのうちの1つはスキップし,すなわち2つの関数をテストしたという結果を得た.テストに成功しました.
----------------------------------------------------------------------
Ran 3 tests in 1.167s

OK (skipped=1)
[Finished in 4.6s]