TensorFlowユニットテストの簡単な例
2547 ワード
TensorFlowのtf.test.TestCaseクラスはunittestを継承した.TestCaseクラスはtensorflowコードのユニットテストに使用されます.
tf.test.TestCaseはassertAlleEqualを提供し、2つのnumpy arrayが完全に同じ値を持っていることを判断し、sessionメソッドは計算図のノードを実行し、その他のメソッドを実行します.詳細はリンクを参照してください.
次の2つの関数があります.
第1の関数は全接続層であり,第2の関数はテンソルを拡張塑形操作するために用いられる.
次にUtilsTestsクラスを作成し、tfを継承します.test.TestCaseクラス、定義test_dense_Layerメソッドは、最初の関数をテストし、test_を定義します.expand_reshape_tensor法は2番目の関数をテストする.
次に,3つの関数をテストし,そのうちの1つはスキップし,すなわち2つの関数をテストしたという結果を得た.テストに成功しました.
tf.test.TestCaseはassertAlleEqualを提供し、2つのnumpy arrayが完全に同じ値を持っていることを判断し、sessionメソッドは計算図のノードを実行し、その他のメソッドを実行します.詳細はリンクを参照してください.
次の2つの関数があります.
# Python3
import tensorflow as tf
def dense_layer(x, W, bias, activation=None):
y = x @ W + bias
if activation:
return activation(y)
else:
return y
def expand_reshape_tensor(x, high, width):
return tf.reshape(x, (high, width, 1, 1))
第1の関数は全接続層であり,第2の関数はテンソルを拡張塑形操作するために用いられる.
次にUtilsTestsクラスを作成し、tfを継承します.test.TestCaseクラス、定義test_dense_Layerメソッドは、最初の関数をテストし、test_を定義します.expand_reshape_tensor法は2番目の関数をテストする.
import tensorflow as tf
import utils
class UtilsTests(tf.test.TestCase):
def test_dense_layer(self):
x = tf.reshape(tf.range(27), (9, 3)) - 13
W = tf.reshape(tf.range(9), (3, 3)) - 4
bias = tf.range(3) - 1
y1 = utils.dense_layer(x, W, bias)
y2 = utils.dense_layer(x, W, bias, tf.nn.relu)
with self.session() as sess:
y1, y2 = sess.run((y1, y2))
self.assertAllEqual(
[[-13, -12, -11],
[-10, -9, -8],
[ -7, -6, -5],
[ -4, -3, -2],
[ -1, 0, 1],
[ 2, 3, 4],
[ 5, 6, 7],
[ 8, 9, 10],
[ 11, 12, 13]], x) # x
self.assertAllEqual(
[[-4, -3, -2],
[-1, 0, 1],
[ 2, 3, 4]], W) # W
self.assertAllEqual([-1, 0, 1], bias) # bias
self.assertAllEqual(
[[ 41, 6, -29],
[ 32, 6, -20],
[ 23, 6, -11],
[ 14, 6, -2],
[ 5, 6, 7],
[ -4, 6, 16],
[-13, 6, 25],
[-22, 6, 34],
[-31, 6, 43]], y1) #
self.assertAllEqual(
[[41, 6, 0],
[32, 6, 0],
[23, 6, 0],
[14, 6, 0],
[ 5, 6, 7],
[ 0, 6, 16],
[ 0, 6, 25],
[ 0, 6, 34],
[ 0, 6, 43]], y2) #
def test_expand_reshape_tensor(self):
x = tf.range(9)
y = utils.expand_reshape_tensor(x, 3, 3)
shape = tf.shape(y)
with self.session() as sess:
shape = sess.run(shape)
self.assertAllEqual(shape, (3, 3, 1, 1)) #
if __name__ == "__main__":
tf.test.main() #
次に,3つの関数をテストし,そのうちの1つはスキップし,すなわち2つの関数をテストしたという結果を得た.テストに成功しました.
----------------------------------------------------------------------
Ran 3 tests in 1.167s
OK (skipped=1)
[Finished in 4.6s]