Tensorflow共通関数

7429 ワード

文書ディレクトリ
  • Tensorflow共通関数
  • 1.基礎
  • .データ処理
  • 3.ネットワーク構築
  • 4.Keras
  • を使用
  • 5.自家製データ(前処理)、パッケージ
  • Tensorflow共通関数
    1.基礎
    tf.int
    tf.float32
    tf.float64
    
    tf.constant(    ,dtype=    ) #     
    tf.convert_to_tensor(   ,dtype=?) # numpy tensor
    tf.zeros(  )
    tf.ones(  )
    tf.fill(  ,   )  #       
    
    tf.random.normal(  ,mean=  ,stddev=   )  #         
    tf.random.truncated_normal(  ,mean=  ,stddev=   )  #           
    
    tf.cast(   ,dtype)  #     
    tf.reduce_min(   )
    tf.reduce_max(   )
    te.reduce_mean(   ,aixs=?)
    te.reduce_sum(   ,aixs=?)
    
    axis :  
    	           ,       axis   0 1       。  axis=0    (  ,down), axis=1    (  ,across)      axis,         。
    
  • tf.Variable():変数を「訓練可能」とマークし、マークされた変数は逆伝搬に勾配情報を記録します.ニューラルネットワーク訓練では,この関数が訓練対象パラメータをマークするのによく用いられる.w = tf.Variable(tf.random.normal([2, 2], mean=0, stddev=1))
  • 四則演算:tf.add(テンソル1,テンソル2),tf.subtract,tf.multiply,tf.divide
  • 平方、次方と開方:tf.square,tf.pow,tf.sqrt
  • 行列乗:tf.matmul
  • 2.データ処理
  • data=tf.data.Dataset.from_tensor_slices((入力フィーチャー、ラベル))
  • 勾配解:
    #   with          
    with tf.GradientTape( ) as tape:  
    	w = tf.Variable(tf.constant(3.0)) 
    	loss = tf.pow(w,2)
    grad = tape.gradient(loss,w)   #     
    
  • enumerate:要素
    for i, element in enumerate(seq): 
    	print(i, element)
    
    を巡回するための
  • np.random.seed(116) #      seed,     /       
    np.random.shuffle(x_data) 
    np.random.seed(116) 
    np.random.shuffle(y_data) 
    tf.random.set_seed(116)
    

  • 3.ネットワーク構築
  • tf.one_hot(変換対象データ、depth=数分類):ユニヒート符号化(one-hot encoding):分類問題では、一般的にユニヒート符号をラベルとして使用し、タグカテゴリ:1は、0は非を表す.(例えば、十分な問題は最終的に出力されるが、確率が最大の出力は1であり、その他の出力は0.
  • である.
  • tf.nn.softmax(x)出力を確率分布
  • に適合する.
  • w.assign_sub (x)  
    tf.argmax (   ,axis=   ) #         
    
  • tf.where(    ,   A,   B)  
    c=tf.where(tf.greater(a,b), a, b) #  a>b,  a       ,     b       
    
  • np.vstack(配列1,配列2):2つの配列を垂直方向に重ねる
  • np.mgrid[     :     :   ,    :     :    , … ]  #        
    x, y = np.mgrid [1:3:1, 2:4:0.5]
    x.ravel( )  x      
    np.c_[ ]            
    
  • 損失関数
        :loss=tf.reduce_mean(tf.square(y_-y))
       :tf.losses.categorical_crossentropy(y_,y)
    
  • 正規化はオーバーフィットを解決する:正規化は損失関数にモデル複雑度指標を導入し、W重み付け値を利用して訓練データのノイズ**(一般的に非正規化b)**
  • を弱めた.
  • オプティマイザ:SGD、SGDM、Adagrad、RMSProp、Adam
  • 4.Kerasの使用
  •  Tensorflow API:tf.keras      
       
    import
    train,test
    model=tf.keras.models.Sequential
    model.compile
    model.fit
    model.summary
    
  • model=tf.keras.models.Sequential([ネットワーク構造])#各レイヤネットワーク
  • を記述する
  • ストレート層:tf.keras.layers.Flatten()で、データを1次元配列
  • に押し出します.
  • 全接続層:tf.keras.layers.dense(ニューロン個数、activation="活性化関数",kernel_regularizer=どの正規化)
  • model.compile(optimizer =    , loss =      metrics = [“   ”] )
    loss  :'mse'、'sparse_categorical_crossentropy'
    Metrics  :‘accuracy’、‘categorical_accuracy’ 、‘sparse_categorical_accuracy’(  )
    
  • model.fit (        ,       , 			batch_size= ,	
    	epochs= ,
    	validation_data=(        ,      ), 	   validation_split=              ,       validation_freq =    epoch    )	
    
  • model.summary():パラメータ
  • を返します.
    5.自家製データ(前処理)、パッケージング
  • データ拡張
    image_gen_train=ImageDataGenerator(
    	rescale=1./1.#    ,   255 ,   0~1
    	rotation_range=45#  45   
    	width_shift_range=.15#    
    	height_shift_range=.15#    
    	horizontal_flip=False#    
    	zoom_range=0.5#         50%)
    image_gen_train.fit(x_train)
    
  • 断点続訓
    #     :      callbacks
    tf.keras.callbacks.ModelCheckpoint( 
        filepath=     , 
        save_weights_only=True,  #          
        save_best_only=True,   #         
    history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback])
    #     
    model.load_weights(     )
    
  • パラメータ抽出:
  • model.trainable_variablesモデルで訓練可能なパラメータ
  • np.set_printoptions(threshold=np.inf)

  • acc/loss可視化
    history=model.fit(     ,      , batch_size=, epochs=, validation_split=         ,validation_data=   , validation_freq=    )
    #     
    history: 
    loss:   loss 
    val_loss:    loss 
    sparse_categorical_accuracy:       
    val_sparse_categorical_accuracy:      
    #     
    acc=history.history['sparse_categorical_accuracy']  
    valacc=history.history['val_sparse_categorical accuracy']
    loss=history.history['loss']
    val loss=history. history['val_loss']
    
  • predict(入力フィーチャー、batch_size=整数):フォワード伝播計算結果
  • を返します.