1. 程式人生 > >[Tensorflow]L2正則化和collection【tf.GraphKeys】

[Tensorflow]L2正則化和collection【tf.GraphKeys】

L2-Regularization 實現的話,需要把所有的引數放在一個集合內,最後計算loss時,再減去加權值。

相比自己亂搞,程式碼一團糟,Tensorflow 提供了更優美的實現方法。

一、tf.GraphKeys : 多個包含Variables(Tensor)集合

 (1)GLOBAL_VARIABLES:使用tf.get_variable()時,預設會將vairable放入這個集合。

   我們熟悉的tf.global_variables_initializer()就是初始化這個集合內的Variables。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())
b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())
#collections=None等價於 collection=[tf.GraphKeys.GLOBAL_VARIABLES]

gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)          #tf.get_collection(collection_name)返回某個collection的列表
for var in gv: 
  print(var is a)
  print(var.get_shape())
   Tips: tf.GraphKeys.GLOBAL_VARIABLES == "variable"。即其儲存的是一個字串。

(2)自定義集合

   想個集合的名字,然後在tf.get_variable時,把集合名字傳給 collection 就好了。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",shape=[10],collections=["mycollection"])  #不把GLOBAL_VARIABLES加進去,那麼就不在那個集合裡了。
keys=tf.get_collection("mycollection")
for key in keys:
  print(key.name)

二、L2正則化

先看看tf.contrib.layers.l2_regularizer(weight_decay)都執行了什麼:
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
"""
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) 
"""
#**上面程式碼的等價程式碼
a=tf.get_variable("I_am_a",initializer=tmp)
a2=tf.reduce_sum(a*a)*weight_decay/2;
a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
#**
sess.run(tf.global_variables_initializer())
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
  print("%s : %s" %(key.name,sess.run(key)))

我們很容易可以模擬出tf.contrib.layers.l2_regularizer都做了什麼,不過會讓程式碼變醜。 以下比較完整實現L2 正則化。
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1                                                #(1)定義weight_decay
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)           #(2)定義l2_regularizer()
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)  #(3)建立variable,l2_regularizer複製給regularizer引數。
                                                                #目測REXXX_LOSSES集合
#regularizer定義會將a加入REGULARIZATION_LOSSES集合
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:
  print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
  print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)   #(4)則REGULARIAZTION_LOSSES集合會包含所有被weight_decay後的引數和,將其相加
l2_loss=tf.add_n(reg_set)
print("loss=%s" %(sess.run(l2_loss)))
"""
此處輸出0.7,即:
   weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
其實程式碼自己寫也很方便,用API看著比較正規。
在網路模型中,直接將l2_loss加入loss就好了。(loss變大,執行train自然會decay)
"""