1. 程式人生 > >關於tensorflow中的softmax_cross_entropy_with_logits_v2函式的區別

關於tensorflow中的softmax_cross_entropy_with_logits_v2函式的區別

tf.nn.softmax_cross_entropy_with_logits(記為f1) 和
tf.nn.sparse_softmax_cross_entropy_with_logits(記為f3),以及
tf.nn.softmax_cross_entropy_with_logits_v2(記為f2)
之間的區別。

f1和f3對於引數logits的要求都是一樣的,即未經處理的,直接由神經網路輸出的數值, 比如 [3.5,2.1,7.89,4.4]。兩個函式不一樣的地方在於labels格式的要求,f1的要求labels的格式和logits類似,比如[0,0,1,0]。而f3的要求labels是一個數值,這個數值記錄著ground truth所在的索引。以[0,0,1,0]

為例,這裡真值1的索引為2。所以f3要求labels的輸入為數字2(tensor)。一般可以用tf.argmax()來從[0,0,1,0]中取得真值的索引。

f1和f2之間很像,實際上官方文件已經標記出f1已經是deprecated 狀態,推薦使用f2。兩者唯一的區別在於f1在進行反向傳播的時候,只對logits進行反向傳播,labels保持不變。而f2在進行反向傳播的時候,同時對logits和labels都進行反向傳播,如果將labels傳入的tensor設定為stop_gradients,就和f1一樣了。
那麼問題來了,一般我們在進行監督學習的時候,labels都是標記好的真值,什麼時候會需要改變label?f2存在的意義是什麼?實際上在應用中labels並不一定都是人工手動標註的,有的時候還可能是神經網路生成的,一個實際的例子就是對抗生成網路(GAN)。

測試用程式碼:

import tensorflow as tf
import numpy as np

Truth = np.array([0,0,1,0])
Pred_logits = np.array([3.5,2.1,7.89,4.4])

loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)

with
tf.Session() as sess: print(sess.run(loss)) print(sess.run(loss2)) print(sess.run(loss3))

參考: