在TensorFlow中自定義梯度的兩種方法
前言
在深度學習中,有時候我們需要對某些節點的梯度進行一些定製,特別是該節點操作不可導(比如階梯除法如
),如果實在需要對這個節點進行操作,而且希望其可以反向傳播,那麼就需要對其進行自定義反向傳播時的梯度。在有些場景,如[2]中介紹到的梯度反轉(gradient inverse)中,就必須在某層節點對反向傳播的梯度進行反轉,也就是需要更改正常的梯度傳播過程,如下圖的
所示。
在tensorflow中有若干可以實現定製梯度的方法,這裡介紹兩種。
1. 重寫梯度法
重寫梯度法指的是通過tensorflow自帶的機制,將某個節點的梯度重寫(override),這種方法的適用性最廣。我們這裡舉個例子[3].
符號函式的前向傳播採用的是階躍函式
,如下圖所示,我們知道階躍函式不是連續可導的,因此我們在反向傳播時,將其替代為一個可以連續求導的函式
,於是梯度就是大於1和小於-1時為0,在-1和1之間時是1。
使用重寫梯度的方法如下,主要是涉及到tf.RegisterGradient()
和tf.get_default_graph().gradient_override_map()
,前者註冊新的梯度,後者重寫圖中具有名字name='Sign'
的操作節點的梯度,用在新註冊的QuantizeGrad
替代。
#使用修飾器,建立梯度反向傳播函式。其中op.input包含輸入值、輸出值,grad包含上層傳來的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
input = op.inputs[0] # 取出當前的輸入
cond = (input>=-1)&(input<=1) # 大於1或者小於-1的值的位置
zeros = tf.zeros_like(grad) # 定義出0矩陣用於掩膜
return tf.where(cond, grad, zeros)
# 將大於1或者小於-1的上一層的梯度置為0
#使用with上下文管理器覆蓋原始的sign梯度函式
def binary(input):
x = input
with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
#重寫梯度
x = tf.sign(x)
return x
#使用
x = binary(x)
其中的def sign_grad(op, grad):
是註冊新的梯度的套路,其中的op
是當前操作的輸入值/張量等,而grad
指的是從反向而言的上一層的梯度。
通常來說,在tensorflow中自定義梯度,函式tf.identity()
是很重要的,其API手冊如下:
tf.identity(
input,
name=None
)
其會返回一個形狀和內容都和輸入完全一樣的輸出,但是你可以自定義其反向傳播時的梯度,因此在梯度反轉等操作中特別有用。
這裡再舉個反向梯度[2]的例子,也就是梯度為
而不是
。
import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
@tf.RegisterGradient('CustomGrad')
def CustomGrad(op, grad):
# tf.Print(grad)
return -grad
g = tf.get_default_graph()
oo = x1+x2
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(oo)
grad_1 = tf.gradients(output, oo)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(grad_1))
因為-grad
,所以這裡的梯度輸出是[-1]而不是[1]。有一個我們需要注意的是,在自定義函式def CustomGrad()
中,返回的值得是一個張量,而不能返回一個引數,比如return 0
,這樣會報錯,如:
AttributeError: 'int' object has no attribute 'name'
顯然,這是因為tensorflow的內部操作需要取返回值的名字而int
型別沒有名字。
PS:def CustomGrad()
這個函式簽名是隨便你取的。
2. stop_gradient法
對於自定義梯度,還有一種比較簡潔的操作,就是利用tf.stop_gradient()
函式,我們看下例子[1]:
t = g(x)
y = t + tf.stop_gradient(f(x) - t)
這裡,我們本來的前向傳遞函式是f(x),但是想要在反向時傳遞的函式是g(x),因為在前向過程中,tf.stop_gradient()
不起作用,因此+t
和-t
抵消掉了,只剩下f(x)前向傳遞;而在反向過程中,因為tf.stop_gradient()
的作用,使得f(x)-t的梯度變為了0,從而只剩下g(x)在反向傳遞。
我們看下完整的例子:
import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
f = x1+x2*x3
t = -f
y1 = t + tf.stop_gradient(f-t)
y2 = f
grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x1)
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(grad_1))
print(sess.run(grad_2))
第一個輸出為[-1],第二個輸出為[1],顯然也實現了梯度的反轉。
Reference
[1]. How Can I Define Only the Gradient for a Tensorflow Subgraph?
[2]. Ganin Y, Ustinova E, Ajakan H, et al. Domain-adversarial training of neural networks[J]. Journal of Machine Learning Research, 2017, 17(1):2096-2030.
[3]. tensorflow 實現自定義梯度反向傳播
[4]. Custom Gradients in TensorFlow