Neural Network Optimization Algorithms, Part 2 (Regularization, the Moving Average Model)
1. Further optimizing neural networks -- overfitting and regularization
Overfitting occurs when a model becomes so complex that it "memorizes" the random noise in each individual training example instead of "learning" the general trend in the training data. As an extreme example, if a model has more parameters than there are training examples, then as long as the training data contain no contradictions, the model can simply memorize the result for every training example and drive the loss function to 0.
To avoid overfitting, a very common technique is regularization.
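This extreme case can be reproduced in a few lines (a minimal NumPy sketch, not from the original text: a degree-(n-1) polynomial has n coefficients, so fitting it to n points drives the training error to essentially 0 even though the data are pure noise):

```python
import numpy as np

rng = np.random.RandomState(0)
n = 6
x = np.linspace(0.0, 1.0, n)
y = rng.randn(n)  # pure noise: there is no trend to learn

# A polynomial with as many coefficients as data points "memorizes" the noise
coef = np.polyfit(x, y, deg=n - 1)
y_hat = np.polyval(coef, x)
mse = np.mean((y - y_hat) ** 2)
print(mse)  # essentially 0: the model fit every noisy point exactly
```

The fit is perfect on the training points, yet the model has learned nothing that would generalize to new data.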
Regularization adds a penalty on the parameters to the loss function, introducing a measure of model complexity that suppresses fitting of noise and thereby reduces overfitting. With regularization, the loss becomes the sum of two terms: if J(θ) is the loss function that measures the model's performance on the training data, then instead of optimizing J(θ) directly,
we optimize

J(θ) + λR(w)

where R(w) measures the complexity of the model and λ is the weight of the complexity loss within the total loss. In general, model complexity depends only on the weights w. Two functions R(w) are commonly used to measure model complexity:
1. L1 regularization, computed as R(w) = Σ_i |w_i|
2. L2 regularization, computed as R(w) = Σ_i w_i²
Whichever form is used, the basic idea is to limit the size of the weights so that the model cannot fit the random noise in the training data arbitrarily well. The two forms differ in important ways:
- L1 regularization makes the parameters sparser, while L2 regularization does not. "Sparser" means that more of the parameters become exactly 0.
- L1 regularization is not differentiable (at 0), while L2 regularization is. Optimizing a loss with L2 regularization is therefore simpler, and optimizing one with L1 regularization is more involved.
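The two penalties can be checked with a minimal NumPy sketch (independent of TensorFlow; the halving of the L2 term follows the same convention as `tf.contrib.layers.l2_regularizer`, used below, which halves the sum of squares to simplify the gradient):

```python
import numpy as np

def l1_penalty(w, lam):
    # lam * sum of absolute values
    return lam * np.sum(np.abs(w))

def l2_penalty(w, lam):
    # lam * (sum of squares) / 2 -- the halving convention used by
    # tf.contrib.layers.l2_regularizer
    return lam * np.sum(np.square(w)) / 2.0

w = np.array([[1.0, -2.0], [-3.0, 4.0]])
print(l1_penalty(w, 0.5))  # (1 + 2 + 3 + 4) * 0.5 = 5.0
print(l2_penalty(w, 0.5))  # (1 + 4 + 9 + 16) / 2 * 0.5 = 7.5
```

These are exactly the values the TensorFlow session below prints.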
w = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w)
# 'lambda' is a Python keyword, so the regularization rate is named lambda_ here
loss = tf.reduce_mean(tf.square(y_ - y)) + tf.contrib.layers.l2_regularizer(lambda_)(w)
In the code above, loss is the defined loss function, and it consists of two parts. The first part is the mean squared error introduced earlier, which measures the model's performance on the training data. The second part is the L2 regularization term.
weights = tf.constant([[1.0, -2.0], [-3.0, 4.0]])
with tf.Session() as sess:
    # Output: (|1| + |-2| + |-3| + |4|) * 0.5 = 5, where 0.5 is the regularization weight
    print sess.run(tf.contrib.layers.l1_regularizer(0.5)(weights))
    # Output: (1^2 + (-2)^2 + (-3)^2 + 4^2) / 2 * 0.5 = 7.5
    print sess.run(tf.contrib.layers.l2_regularizer(0.5)(weights))
The code above shows the difference between the L1 and L2 regularization computations. But as the number of parameters in a neural network grows, this style first makes the definition of the loss function very long, hard to read, and error-prone. More importantly, once the network structure becomes complex, the part that defines the network and the part that computes the loss may no longer live in the same function, so passing the loss around through variables becomes inconvenient. To solve this, TensorFlow provides collections. The following code uses a collection to compute the loss function of a 5-layer neural network with L2 regularization.
import tensorflow as tf

# Get the weights of one layer of the network and add their L2
# regularization loss to the collection named 'losses'
# ('lambda' is a Python keyword, so the rate is named lambda_)
def get_weight(shape, lambda_):
    # Create a variable
    var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    # add_to_collection adds the L2 regularization loss of this new variable
    # to a collection. The first argument 'losses' is the name of the
    # collection; the second argument is the value to add to it
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambda_)(var))
    return var

x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
batch_size = 8
# Number of nodes in each layer of the network
layer_dimension = [2, 10, 10, 10, 1]
# Number of layers in the network
n_layers = len(layer_dimension)
# This variable holds the deepest layer reached during forward propagation;
# initially it is the input layer
cur_layer = x
# Number of nodes in the current layer
in_dimension = layer_dimension[0]
# Build the 5-layer fully connected network in a for loop
for i in range(1, n_layers):
    out_dimension = layer_dimension[i]  # number of nodes in the next layer
    # Create the weights of the current layer and add their L2
    # regularization loss to the collection on the graph
    weight = get_weight([in_dimension, out_dimension], 0.001)
    bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))
    # Use the ReLU activation function
    cur_layer = tf.nn.relu(tf.matmul(cur_layer, weight) + bias)
    # Before moving on, update in_dimension to the node count of the layer just built
    in_dimension = layer_dimension[i]

# All L2 regularization losses were added to the graph's collection while
# defining the forward pass; here we only need the loss function that
# measures the model's performance on the data
mse_loss = tf.reduce_mean(tf.square(y_ - cur_layer))
# Add the mean squared error loss to the collection
tf.add_to_collection('losses', mse_loss)
# get_collection returns a list of all elements in the collection.
# In this example these elements are the different parts of the loss
# function; summing them gives the final loss
loss = tf.add_n(tf.get_collection('losses'))
Step 1:

#coding:utf-8
# Import modules and generate the data set
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE = 30  # 30 samples are fed to the network at a time
seed = 2
# Generate random numbers based on seed
rdm = np.random.RandomState(seed)
# Return a 300x2 matrix of random numbers, i.e. 300 coordinate points
# (x0, x1), as the input data set
X = rdm.randn(300, 2)
# Take each row of the 300x2 matrix X; if the sum of squares of its two
# coordinates is less than 2, assign 1 to Y, otherwise 0.
# This serves as the correct answers (labels) for the data set
Y_ = [int(x0*x0 + x1*x1 < 2) for (x0, x1) in X]
# For each element of Y_, map 1 to 'red' and 0 to 'blue', so the classes
# can be told apart at a glance when visualized
Y_c = [['red' if y else 'blue'] for y in Y_]
# Reshape X and the labels Y: the first element -1 means that dimension is
# inferred from the second one; the second element is the number of columns.
# X becomes n rows x 2 columns, Y becomes n rows x 1 column
X = np.vstack(X).reshape(-1, 2)
Y = np.vstack(Y_).reshape(-1, 1)
print X
print Y
print Y_c
# Use plt.scatter to plot the points (x0, x1) from column 0 and column 1 of
# each row of X, colored by the corresponding value in Y_c (c is short for color)
plt.scatter(X[:, 0], X[:, 1], c=np.squeeze(Y_c))
plt.show()
Run result:
(tf1.5) [email protected]:~/tf/tf4$ python opt4.py
[[ -4.16757847e-01  -5.62668272e-02]
 [ -2.13619610e+00   1.64027081e+00]
 [ -1.79343559e+00  -8.41747366e-01]
 ...
 [ -5.24891667e-01  -5.89808683e-01]
 [ -8.62531086e-01  -1.74287290e+00]]
[[1]
 [0]
 [0]
 ...
 [1]
 [0]]
[['red'], ['blue'], ['blue'], ..., ['red'], ['blue']]
(full 300-row output truncated)
Plot: scatter plot of the 300 points, red inside the circle x0² + x1² < 2 and blue outside.
Explanation:
plt.scatter(): visualizes the points (x, y) in the specified colors.
plt.scatter(x coordinates, y coordinates, c="color"), where c is short for color.
Step 2: complete code

#coding:utf-8
# 0. Import modules and generate the simulated data set
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
BATCH_SIZE = 30
seed = 2
# Generate random numbers based on seed
rdm = np.random.RandomState(seed)
# Return a 300x2 matrix of random numbers, i.e. 300 coordinate points
# (x0, x1), as the input data set
X = rdm.randn(300, 2)
# Take each row of the 300x2 matrix X; if the sum of squares of its two
# coordinates is less than 2, assign 1 to Y, otherwise 0.
# This serves as the labels (correct answers) of the input data set
Y_ = [int(x0*x0 + x1*x1 < 2) for (x0, x1) in X]
# For each element of Y_, map 1 to 'red' and 0 to 'blue' for easy visual distinction
Y_c = [['red' if y else 'blue'] for y in Y_]
# Reshape X and the labels: -1 means that dimension is inferred from the
# second argument; X becomes n rows x 2 columns, Y becomes n rows x 1 column
X = np.vstack(X).reshape(-1, 2)
Y_ = np.vstack(Y_).reshape(-1, 1)
print X
print Y_
print Y_c
# Plot the points (x0, x1) from columns 0 and 1 of X, colored by Y_c
# (c is short for color)
plt.scatter(X[:, 0], X[:, 1], c=np.squeeze(Y_c))
plt.show()

# Define the network's inputs, parameters and outputs, and the forward pass
def get_weight(shape, regularizer):
    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b = tf.Variable(tf.constant(0.01, shape=shape))
    return b

x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))

w1 = get_weight([2, 11], 0.01)
b1 = get_bias([11])
y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

w2 = get_weight([11, 1], 0.01)
b2 = get_bias([1])
y = tf.matmul(y1, w2) + b2

# Define the loss function
loss_mse = tf.reduce_mean(tf.square(y - y_))
loss_total = loss_mse + tf.add_n(tf.get_collection('losses'))

# Define the backpropagation method: without regularization
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_mse)
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    STEPS = 40000
    for i in range(STEPS):
        start = (i*BATCH_SIZE) % 300
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y_[start:end]})
        if i % 2000 == 0:
            loss_mse_v = sess.run(loss_mse, feed_dict={x: X, y_: Y_})
            print("After %d steps, loss is: %f" % (i, loss_mse_v))
    # Generate a 2-D grid of coordinate points, with xx and yy each ranging
    # from -3 to 3 in steps of 0.01
    xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
    # Flatten xx and yy and join them into a 2-column matrix: the set of grid points
    grid = np.c_[xx.ravel(), yy.ravel()]
    # Feed the grid points to the network; probs is its output
    probs = sess.run(y, feed_dict={x: grid})
    # Reshape probs to the shape of xx
    probs = probs.reshape(xx.shape)
    print "w1:\n", sess.run(w1)
    print "b1:\n", sess.run(b1)
    print "w2:\n", sess.run(w2)
    print "b2:\n", sess.run(b2)

plt.scatter(X[:, 0], X[:, 1], c=np.squeeze(Y_c))
plt.contour(xx, yy, probs, levels=[.5])
plt.show()

# Define the backpropagation method: with regularization
train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_total)
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    STEPS = 40000
    for i in range(STEPS):
        start = (i*BATCH_SIZE) % 300
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y_[start:end]})
        if i % 2000 == 0:
            loss_v = sess.run(loss_total, feed_dict={x: X, y_: Y_})
            print("After %d steps, loss is: %f" % (i, loss_v))
    xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
    grid = np.c_[xx.ravel(), yy.ravel()]
    probs = sess.run(y, feed_dict={x: grid})
    probs = probs.reshape(xx.shape)
    print "w1:\n", sess.run(w1)
    print "b1:\n", sess.run(b1)
    print "w2:\n", sess.run(w2)
    print "b2:\n", sess.run(b2)

plt.scatter(X[:, 0], X[:, 1], c=np.squeeze(Y_c))
plt.contour(xx, yy, probs, levels=[.5])
plt.show()
Run result 1: without regularization
Run result 2: with regularization
Note:
tf.add_to_collection(name, value) adds a value to the collection named 'name', forming a list;
tf.get_collection(key, scope=None) retrieves all elements of the collection named 'key' and
returns them as a list, ordered by when the values were added to the collection. The optional
scope argument specifies a name scope: if given, only the values added under 'key' within that
name scope are returned; if not, all of them are returned.
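The collection mechanism can be mimicked in plain Python to make the bookkeeping concrete (a hypothetical sketch with made-up penalty values, not the TensorFlow implementation):

```python
# A graph collection is essentially a named list of values
_collections = {}

def add_to_collection(name, value):
    # Append value to the list stored under name, creating the list if needed
    _collections.setdefault(name, []).append(value)

def get_collection(key):
    # Return everything added under key, in insertion order
    return _collections.get(key, [])

# Each layer contributes its regularization penalty as it is defined...
add_to_collection('losses', 0.7)  # L2 penalty of layer 1 (made-up number)
add_to_collection('losses', 0.3)  # L2 penalty of layer 2 (made-up number)
add_to_collection('losses', 2.0)  # data loss, e.g. the mean squared error

# ...and the total loss is just the sum over the collection,
# which is what tf.add_n(tf.get_collection('losses')) computes
total_loss = sum(get_collection('losses'))
print(total_loss)  # 3.0
```

This is why the network-definition code and the loss-computation code no longer need to share variables: the collection carries the penalties across.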
2. Further optimizing neural networks -- the moving average model
Another technique that can make a model more robust on test data is the moving average model. When a neural network is trained with stochastic gradient descent, a moving average model can, in many applications, improve the final model's performance on test data to some degree.
TensorFlow provides tf.train.ExponentialMovingAverage to implement the moving average model. At initialization it requires a decay rate (decay), which controls how fast the model is updated. The moving average maintains a shadow variable for each variable; the shadow variable is initialized to the variable's initial value, and every time the variable is updated, the shadow variable is updated to

shadow_variable = decay * shadow_variable + (1 - decay) * variable

(where decay is the decay rate).
As the formula shows, decay determines how fast the model updates: the larger decay is, the more stable the model. In practice, decay is usually set very close to 1, e.g. 0.99 or 0.999. To let the model update faster in the early stages of training, ExponentialMovingAverage also accepts a num_updates argument to dynamically adjust decay. If num_updates is provided at initialization, the decay rate actually used at each update is

min{decay, (1 + num_updates) / (10 + num_updates)}
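The update rule, including the num_updates adjustment, can be sketched in plain Python (a minimal sketch mirroring the numbers in the TensorFlow example below):

```python
def ema_update(shadow, value, decay, num_updates=None):
    # With num_updates given, the effective decay is capped early in training
    if num_updates is not None:
        decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
    return decay * shadow + (1.0 - decay) * value

shadow = 0.0
# step = 0: effective decay is min(0.99, 1/10) = 0.1, so the shadow moves quickly
shadow = ema_update(shadow, 5.0, 0.99, num_updates=0)
print(shadow)  # 0.1 * 0 + 0.9 * 5 = 4.5

# step = 10000: effective decay is min(0.99, 10001/10010) = 0.99, slow and stable
shadow = ema_update(shadow, 10.0, 0.99, num_updates=10000)
print(shadow)  # 0.99 * 4.5 + 0.01 * 10 = 4.555
shadow = ema_update(shadow, 10.0, 0.99, num_updates=10000)
print(shadow)  # 0.99 * 4.555 + 0.01 * 10 = 4.60945
```

These are the same values (4.5, 4.555, 4.60945) that the TensorFlow session below prints.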
import tensorflow as tf

# Define a variable whose moving average will be computed; its initial value is 0
v1 = tf.Variable(0, dtype=tf.float32)
# The step variable simulates the iteration count of the network and is used
# to dynamically control the decay rate
step = tf.Variable(0, trainable=False)

# Define a moving average class, with decay rate 0.99 and the variable step
# controlling the decay rate (step is passed as the num_updates argument)
ema = tf.train.ExponentialMovingAverage(0.99, step)
# Define an operation that updates the moving averages. It takes a list of
# variables; every time the operation runs, the averages of all variables in
# the list are updated
maintain_averages_op = ema.apply([v1])

with tf.Session() as sess:
    # Initialize all variables
    tf.global_variables_initializer().run()
    # ema.average(v1) returns the moving average of the variable. After
    # initialization, both v1 and its moving average are 0
    print sess.run([v1, ema.average(v1)])

    # Update the value of v1 to 5
    sess.run(tf.assign(v1, 5))
    # Update the moving average of v1. The decay rate is
    # min{0.99, (1 + step)/(10 + step)} = min{0.99, 0.1} = 0.1,
    # so the average becomes 0.1 * 0 + 0.9 * 5 = 4.5
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])  # prints [5.0, 4.5]

    # Update step to 10000
    sess.run(tf.assign(step, 10000))
    # Update the value of v1 to 10
    sess.run(tf.assign(v1, 10))
    # Update the moving average of v1. The decay rate is now
    # min{0.99, (1 + step)/(10 + step)} = 0.99,
    # so the average becomes 0.99 * 4.5 + 0.01 * 10 = 4.555
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])  # prints [10.0, 4.5549998]

    # Update the moving average again: 0.99 * 4.555 + 0.01 * 10 = 4.60945
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])  # prints [10.0, 4.6094499]
As the code above shows, the moving average model keeps training stable in its later stages: once step is large, the shadow values change only slowly.
Partly based on: https://blog.csdn.net/qq_32023541/article/details/79607000