Tensorflowのtf.argmax()関数

2389 ワード

転載は出典を明記してください。http://www.jianshu.com/p/469789141af7
公式API定義
tf.argmax(input、axis=None、name=None、dimension=None)
Returns the index with the largest value acros axes of a tens.Args:
  • input:A Tensor.Must be one of the follwing types:float 32,float 64,int 64,int 32,uint 8,uint 16,int 16,int 8,complex 64,complex 128,qint 8,quint 8,qint 32,hint 32,half.4591678
  • axis:A Tensor.Must be one of the follwing types:int 32,int 64.int 32,0<=axisname:A name for the operation.
  • Returns:
  • A Tensor of type int 64
  • axisについて
    定義中のaxisはnumpyのaxisと一致しています。コードによって説明します。
    import numpy as np
    import tensorflow as tf
    
    sess = tf.session()
    m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
    print type(m)
    print m
    
    -------------------------------------------------------------------------------
    
    [[ 0.09957541 -0.0965599   0.06064715 -0.03011306  0.05533558  0.17263047
      -0.02660419  0.08313394 -0.07225946  0.04916157]
     [ 0.11304571  0.02099175  0.03591062  0.01287777 -0.11302195  0.04822164
      -0.06853487  0.0800944  -0.1155676  -0.01168544]
     [ 0.15760773  0.05613248  0.04839646 -0.0218203   0.02233066  0.00929849
      -0.0942843  -0.05943     0.08726917 -0.059653  ]
     [ 0.02553608  0.07298559 -0.06958302  0.02948747  0.00232073  0.11875584
      -0.08325859 -0.06616175  0.15124641  0.09522969]
     [-0.04616683  0.01816062 -0.10866459 -0.12478453  0.01195056  0.0580056
      -0.08500613  0.00635608 -0.00108647  0.12054099]]
    
    mは5行10列の行列で、タイプはnumpy.ndarayです。
    #  tensorflow  tf.argmax()
    col_max = sess.run(tf.argmax(m, 0) )  # axis=0               
    print col_max
    
    row_max = sess.run(tf.argmax(m, 1) )  # axis=1                
    print row_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    
    -------------------------------------------------------------------------------
    #  numpy  numpy.argmax
    row_max = m.argmax(0)
    print row_max
    
    col_max = m.argmax(1)
    print col_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    
    tf.argmax()とnumpy.argmax()の使い方は同じです。
  • axis=0の場合は、各列の最大値の位置インデックス
  • axis=1の場合は、行の最大値の位置インデックス
  • axis=2、3、4…つまり多次元テンソルの場合、同理推定
  • 参照
  • Tensorflow公式API tf.argmax説明
  • Numpy公式AIP numpy.argmax説明