1. 程式人生 > >TensorFlow——學習率衰減的使用方法

TensorFlow——學習率衰減的使用方法

在TensorFlow的優化器中, 都要設定學習率。學習率是在精度和速度之間找到一個平衡:

學習率太大,訓練的速度會有提升,但是結果的精度不夠,而且還可能導致不能收斂出現震盪的情況。

學習率太小,精度會有所提升,但是訓練的速度慢,耗費較多的時間。

因而我們可以使用退化學習率,又稱為衰減學習率。它的作用是在訓練的過程中,對學習率的值進行衰減,訓練到達一定程度後,使用小的學習率來提高精度。

在TensorFlow中的方法如下:tf.train.exponential_decay(),該方法的引數如下:

learning_rate, 初始的學習率的值

global_step, 迭代步數變數

decay_steps, 帶迭代多少次進行衰減

decay_rate, 迭代decay_steps次衰減的值

staircase=False, 預設為False,為True則不衰減

例如

tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=1000, decay_rate=0.9)表示沒經過1000次的迭代,學習率變為原來的0.9。

增大批次處理樣本的數量也可以起到退化學習率的作用。

下面我們寫了一個例子,每迭代10次,則較小為原來的0.5,程式碼如下:

import tensorflow as tf
import numpy as np

global_step = tf.Variable(0, trainable=False)
initial_learning_rate = 0.1

learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                           global_step=global_step,
                                           decay_steps=10,
                                           decay_rate=0.5)

opt = tf.train.GradientDescentOptimizer(learning_rate)
add_global = global_step.assign_add(1)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(learning_rate))

    for i in range(50):
        g, rate = sess.run([add_global, learning_rate])
        print(g, rate)

下面是程式的結果,我們發現沒10次就變為原來的一般:

隨後,又在MNIST上面進行了測試,發現使用學習率衰減使得準確率有較好的提升。程式碼如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

tf.reset_default_graph()

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)

cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))

global_step = tf.Variable(0, trainable=False)
initial_learning_rate = 0.1

learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                           global_step=global_step,
                                           decay_steps=1000,
                                           decay_rate=0.9)


opt = tf.train.GradientDescentOptimizer(learning_rate)
add_global = global_step.assign_add(1)

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 50
batch_size = 100

display_step = 1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c, add, rate = sess.run([optimizer, cost, add_global, learning_rate], feed_dict={x:batch_xs, y:batch_ys})
            avg_cost += c / total_batch

        if (epoch + 1) % display_step == 0:
            print('epoch= ', epoch+1, ' cost= ', avg_cost, 'add_global=', add, 'rate=', rate)
    print('finished')

    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))

在使用衰減學習率我們最後的精度達到0.8897,在使用固定的學習率時,精度只有0.8586。

&n