
A Roundup of Ways to Add Regularization in TensorFlow

1. Basic regularization functions

tf.contrib.layers.l1_regularizer(scale, scope=None)

Returns a function that performs L1 regularization; the returned function has the signature func(weights).
Arguments:

  • scale: coefficient of the regularization term.
  • scope: optional scope name.
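
Below is a minimal usage sketch of calling the returned function directly on a variable; the variable name w and the scale value 0.05 are illustrative assumptions, not from the original post:

import tensorflow as tf

sess = tf.Session()
l1_reg = tf.contrib.layers.l1_regularizer(scale=0.05)
w = tf.get_variable("w", initializer=tf.constant([0., -1., 2., 3.]))
penalty = l1_reg(w)  # scale * sum(|w|) = 0.05 * (0 + 1 + 2 + 3) = 0.3
sess.run(tf.global_variables_initializer())
print(sess.run(penalty))  # 0.3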

tf.contrib.layers.l2_regularizer(scale, scope=None)

First, let's see what tf.contrib.layers.l2_regularizer(weight_decay) actually does:


import tensorflow as tf

sess = tf.Session()
weight_decay = 0.1
tmp = tf.constant([0, 1, 2, 3], dtype=tf.float32)
"""
l2_reg = tf.contrib.layers.l2_regularizer(weight_decay)
a = tf.get_variable("I_am_a", regularizer=l2_reg, initializer=tmp)
"""
# ** Equivalent to the commented-out code above:
a = tf.get_variable("I_am_a", initializer=tmp)
a2 = tf.reduce_sum(a * a) * weight_decay / 2
# mimic the name the real regularizer would give the penalty op
a3 = tf.get_variable(a.name.split(":")[0] + "/Regularizer/l2_regularizer", initializer=a2)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, a2)
# **
sess.run(tf.global_variables_initializer())
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
    print("%s : %s" % (key.name, sess.run(key)))

As the code shows, it is easy to emulate by hand what tf.contrib.layers.l2_regularizer does, although it makes the code uglier.

The following is a more complete L2 regularization example.


import tensorflow as tf

sess = tf.Session()
weight_decay = 0.1                                                  # (1) define weight_decay
l2_reg = tf.contrib.layers.l2_regularizer(weight_decay)             # (2) create the l2_regularizer()
tmp = tf.constant([0, 1, 2, 3], dtype=tf.float32)
a = tf.get_variable("I_am_a", regularizer=l2_reg, initializer=tmp)  # (3) create the variable, passing l2_reg
                                                                    # as the regularizer argument
# Supplying a regularizer adds a's penalty to the REGULARIZATION_LOSSES collection
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:
    print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
    print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)     # (4) REGULARIZATION_LOSSES now holds the
                                                                    # weight-decayed penalty of every
                                                                    # regularized variable; sum them up
l2_loss = tf.add_n(reg_set)
print("loss=%s" % (sess.run(l2_loss)))
"""
This prints 0.7, i.e.:
   weight_decay * sum(w**2) / 2 = 0.1 * (0*0 + 1*1 + 2*2 + 3*3) / 2 = 0.7
Writing the code yourself is easy enough; using the API just looks more standard.
In a network model, simply add l2_loss to the loss. (The loss grows, and running train naturally decays the weights.)
"""


2. Ways to add regularization

a. The original approach

Regularization usually works through collections. Below is the most basic way to add regularization: right after declaring a variable, add its regularization term to the 'losses' collection (tf.GraphKeys.LOSSES works as well):


import tensorflow as tf


def get_weights(shape, lambd):
    # Create a weight variable and add its L2 penalty to the 'losses' collection.
    var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambd)(var))
    return var


x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
batch_size = 8
layer_dimension = [2, 10, 10, 10, 1]
n_layers = len(layer_dimension)
cur_lay = x
in_dimension = layer_dimension[0]

for i in range(1, n_layers):
    out_dimension = layer_dimension[i]
    weights = get_weights([in_dimension, out_dimension], 0.001)
    bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))
    cur_lay = tf.nn.relu(tf.matmul(cur_lay, weights) + bias)
    in_dimension = layer_dimension[i]

mse_loss = tf.reduce_mean(tf.square(y_ - cur_lay))
tf.add_to_collection('losses', mse_loss)      # the data-fitting term joins the same collection
loss = tf.add_n(tf.get_collection('losses'))  # total loss = MSE + all L2 penalties

b. tf.contrib.layers.apply_regularization(regularizer, weights_list=None)

First, the arguments:

  • regularizer: the regularization method we created in the previous step.
  • weights_list: the list of parameters to apply the regularization to; if None, the weights in GraphKeys.WEIGHTS are used.

The function returns a scalar Tensor, and this scalar Tensor is also stored in GraphKeys.REGULARIZATION_LOSSES. The Tensor encodes how to compute the regularization loss.

A Tensor in TensorFlow stores the computation path that produces its value; when we run it, the TensorFlow backend follows that path to compute the Tensor's value.

Now we only need to add this regularization loss to our loss function.

If you define weights by hand, you have to store them in GraphKeys.WEIGHTS yourself; if you use layers, this is already taken care of for you. (It is still best to verify that tf.GraphKeys.WEIGHTS contains all of your weights, to avoid getting burned.)
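
A minimal sketch of this flow with hand-defined weights; the names w1, w2 and the weight_decay value are illustrative assumptions, not from the original post:

import tensorflow as tf

weight_decay = 0.1
l2_reg = tf.contrib.layers.l2_regularizer(weight_decay)

# Hand-defined weights have to be registered in GraphKeys.WEIGHTS ourselves.
w1 = tf.get_variable("w1", initializer=tf.constant([1., 2.]))
w2 = tf.get_variable("w2", initializer=tf.constant([3., 4.]))
for w in (w1, w2):
    tf.add_to_collection(tf.GraphKeys.WEIGHTS, w)

# weights_list=None, so GraphKeys.WEIGHTS is used; the returned scalar
# is also stored in GraphKeys.REGULARIZATION_LOSSES.
reg_loss = tf.contrib.layers.apply_regularization(l2_reg)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(reg_loss))  # 0.1 * (1 + 4 + 9 + 16) / 2 = 1.5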

c. Using slim

Using slim makes this much simpler:


with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_regularizer=slim.l2_regularizer(weight_decay)):
    pass

In this case the penalties are added to the tf.GraphKeys.REGULARIZATION_LOSSES collection, so they can be collected and summed as shown below.
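
A minimal sketch of collecting those losses, assuming a toy slim network; the layer shapes and the cross_entropy placeholder are illustrative assumptions, not from the original post:

import tensorflow as tf
import tensorflow.contrib.slim as slim

weight_decay = 0.001
images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))

with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_regularizer=slim.l2_regularizer(weight_decay)):
    net = slim.conv2d(images, 16, [3, 3])
    net = slim.flatten(net)
    logits = slim.fully_connected(net, 10, activation_fn=None)

# Every layer built under the arg_scope added its penalty to this collection.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
reg_loss = tf.add_n(reg_losses)
# total_loss = cross_entropy + reg_loss  # cross_entropy: your data-fitting term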