Tensorflowのtf.argmax()関数
2389 ワード
転載は出典を明記してください。http://www.jianshu.com/p/469789141af7
公式API定義
tf.argmax(input、axis=None、name=None、dimension=None)
Returns the index with the largest value acros axes of a tens.Args:input:A Tensor.Must be one of the follwing types:float 32,float 64,int 64,int 32,uint 8,uint 16,int 16,int 8,complex 64,complex 128,qint 8,quint 8,qint 32,hint 32,half.4591678 axis:A Tensor.Must be one of the follwing types:int 32,int 64.int 32,0<=axisname:A name for the operation. Returns:A Tensor of type int 64 axisについて
定義中のaxisはnumpyのaxisと一致しています。コードによって説明します。axis=0の場合は、各列の最大値の位置インデックス axis=1の場合は、行の最大値の位置インデックス axis=2、3、4…つまり多次元テンソルの場合、同理推定 参照Tensorflow公式API tf.argmax説明 Numpy公式AIP numpy.argmax説明
公式API定義
tf.argmax(input、axis=None、name=None、dimension=None)
Returns the index with the largest value acros axes of a tens.Args:
定義中のaxisはnumpyのaxisと一致しています。コードによって説明します。
import numpy as np
import tensorflow as tf
sess = tf.session()
m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
print type(m)
print m
-------------------------------------------------------------------------------
[[ 0.09957541 -0.0965599 0.06064715 -0.03011306 0.05533558 0.17263047
-0.02660419 0.08313394 -0.07225946 0.04916157]
[ 0.11304571 0.02099175 0.03591062 0.01287777 -0.11302195 0.04822164
-0.06853487 0.0800944 -0.1155676 -0.01168544]
[ 0.15760773 0.05613248 0.04839646 -0.0218203 0.02233066 0.00929849
-0.0942843 -0.05943 0.08726917 -0.059653 ]
[ 0.02553608 0.07298559 -0.06958302 0.02948747 0.00232073 0.11875584
-0.08325859 -0.06616175 0.15124641 0.09522969]
[-0.04616683 0.01816062 -0.10866459 -0.12478453 0.01195056 0.0580056
-0.08500613 0.00635608 -0.00108647 0.12054099]]
mは5行10列の行列で、タイプはnumpy.ndarayです。# tensorflow tf.argmax()
col_max = sess.run(tf.argmax(m, 0) ) # axis=0
print col_max
row_max = sess.run(tf.argmax(m, 1) ) # axis=1
print row_max
array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])
-------------------------------------------------------------------------------
# numpy numpy.argmax
row_max = m.argmax(0)
print row_max
col_max = m.argmax(1)
print col_max
array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
array([5, 0, 0, 8, 9])
tf.argmax()とnumpy.argmax()の使い方は同じです。