1. 程式人生 > >壓縮神經網路 實驗記錄(剪枝 + rebirth + mobilenet)

壓縮神經網路 實驗記錄(剪枝 + rebirth + mobilenet)

本文轉自:https://blog.csdn.net/jason19966/article/details/79137666

實驗內容

  1. 設計一個原始的神經網路。記錄其執行時間和模型大小和準確率。
  2. 對該網路進行剪枝操作,記錄其執行時間和模型大小和準確率。
  3. 對剪枝後的幾個網路分別分部做rebirth,記錄其執行時間和模型大小和準確率。
  4. 對所有之前涉及的網路做mobile net 。記錄其執行時間和模型大小和準確率。

實驗基礎

  1. 原始神經網路的架構 (被捨棄)如下:
    神經網路結構圖
    經過測試,以上的神經網路的架構訓練網路很難,原始模型在訓練10epoch後準確率才超過50%,後面根據該模型改寫的mobile net 根本就沒法訓練了,正確率始終是10%左右,這根瞎猜沒區別了。

  2. 更換神經網路架構,後來還是採用了之前用的架構,如下:
    修改後的模型架構圖
    雖然直接跟換以後解決了上面模型的問題,但這個問題,我不是很能想明白,前面的網路的最前面的並行網路的設計,我是參照GoogLeNet設計的,為什麼效果這麼差。後來通過

  3. 資料集 :mnist手寫數字的資料集。

實驗詳情記錄

訓練難度對比

  1. 原始網路 的準確度曲線如下,我們可以看到不到2K 步,正確率就基本在0.9幾了。
    原始網路訓練過程
  2. mobile net 的準確度曲線如下,我們可以看到前期準確度一直很低很低,知道6k 步以後才慢慢上升,直到10k以後才達到90%多。
    mobilenet訓練過程
    對比以上兩個網路,我們可以明顯的看出,經過使用mobile net 改造後的網路訓練難度明顯很高,我的電腦跑得很慢,我前期差點就要放棄了。

訓練準確度,時間,模型大小

實驗步驟與資料
1. 先訓練具有上面架構的原始模型t001.pb,記錄模型的大小,預測速度,和準確度。
2. 對原始模型進行迭代剪枝7輪,得到t101.pb,t102.pb,…t107.pb七個模型。分別記錄模型的大小,預測速度,和準確度。(conv2,conv3,fc1,表示對應層的輸出)
剪枝
3.對原始模型和剪枝後的模型進行rebirth,由於rebirth後的模型大小不變,主要記錄其速度和準確度。(時間1和時間2,只是進行了分別跑了兩次,記錄下來的,可以證明預測速度受電腦狀態的影響還是有點大)
rebirth
4. 對前面rebirth 後的模型再進行剪枝。(使用前面rebirth後的結果進行剪枝,比如 t501 對應t401、t502對應t402)
這裡寫圖片描述


5. 根據原始模型t001.pb和剪枝一次後的模型t101.pb分別構建對應的mobile net 模型t201.pb(原始模型) 和 201.pb(剪枝一次後的模型)。
mobile資料
實驗結果分析
1. 剪枝對於縮小模型大小和加快速度都有很好的效果。
2. rebirth在加快速度方面有很好的效果。
3. 將 剪枝 和rebirth 結合起來,反覆迭代,可以得到一個很小,速度很快,準確度也不受影響的模型。
4. mobile net 可以加快速度,但訓練難度較大。
遺留問題
1, mobile net得到的模型的大小,反而比之前的大了。(這有可能是我的程式碼寫得有問題,我的mobile net 部分的程式碼如下,希望有人能幫忙指出問題。)
程式碼中主要是使用depseparable_conv3v3函式代替了之前的卷積函式,以實現mobilenet的結構。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim
import time

def depseparable_conv3v3(input_data,depthwise_filter, pointwise_filter,name):
    # input_data = tf.Variable(np.random.rand(1, 9, 9, 3), dtype=np.float32)
    # depthwise_filter = tf.Variable(np.random.rand(2, 2, 3, 4), dtype=np.float32)
    # pointwise_filter = tf.Variable(np.random.rand(1, 1, 12, 20), dtype=np.float32)
    y = tf.nn.separable_conv2d(input_data, depthwise_filter, pointwise_filter, strides=[1, 1, 1, 1], padding='SAME',name = name)
log_dir = 'F:'
modlefile = '201.pb'

def variable_summaries(var):
    """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
    with tf.name_scope('summaries'):
      # 計算引數的均值,並使用tf.summary.scaler記錄
      mean = tf.reduce_mean(var)
      tf.summary.scalar('mean', mean)
      # 計算引數的標準差
      with tf.name_scope('stddev'):
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
      # 使用tf.summary.scaler記錄記錄下標準差,最大值,最小值
      tf.summary.scalar('stddev', stddev)
      tf.summary.scalar('max', tf.reduce_max(var))
      tf.summary.scalar('min', tf.reduce_min(var))
      # 用直方圖記錄引數的分佈
      tf.summary.histogram('histogram', var)

def load_data():
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)
    return mnist


def weight_variable(shape,name):
    initial=tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial,name =name)
def bias_variable(shape,name):
    initial=tf.constant(0.1,shape=shape)
    return tf.Variable(initial,name=name)
def conv2d(x,W,name):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='VALID',name=name)
def conv2d_same(x,W,name):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME',name=name)
def depseparable_conv(input_data,depthwise_filter, pointwise_filter,name):
    # input_data = tf.Variable(np.random.rand(1, 9, 9, 3), dtype=np.float32)
    # depthwise_filter = tf.Variable(np.random.rand(2, 2, 3, 1), dtype=np.float32)
    # pointwise_filter = tf.Variable(np.random.rand(1, 1, 3, 20), dtype=np.float32)
    y = tf.nn.separable_conv2d(input_data, depthwise_filter, pointwise_filter, strides=[1, 1, 1, 1], padding='SAME',name = name+"_separa")
    return y
def max_pool_2x2(x,name):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID',name = name)
def max_pool_2x2_same(x,name):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME',name = name)
def variable_weight_loss(shape,stddev,w1):
    var=tf.Variable(tf.truncated_normal(shape,stddev=stddev))
    if w1 is not None:
        weight_loss=tf.multiply(tf.nn.l2_loss(var),w1,name="weight_loss")
        tf.add_to_collection("losses",weight_loss)
    return var

def evaluate_pictures(n_epochs=50,batch_size=50):
    def loss(logits, labels):
        labels = tf.cast(labels, tf.int64) # 型別轉換
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits), name='cross_entropy') # 內部執行了softmax。
        tf.add_to_collection('losses', cross_entropy)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')

    mnist = load_data()
    train_set_x  = mnist.train.images
    train_set_y = mnist.train.labels
    test_set_x = mnist.test.images
    test_set_y = mnist.test.labels



    # 計算各資料集的batch個數
    n_train_batches = train_set_x.shape[0]
    n_test_batches  = test_set_x.shape[0]
    n_train_batches = int(n_train_batches / batch_size)
    n_test_batches  = int(n_test_batches / batch_size)
    print("... building the model")

    # 搭建神經網路
    x = tf.placeholder(tf.float32, shape=[None, 784], name = 'input_x')
    y = tf.placeholder(tf.float32, shape=[None, 10], name = 'label_y')
    keep_prob = tf.placeholder(tf.float32,name = 'keep_prob')
    x_images = tf.reshape(x, [-1, 28, 28, 1], name = 'x_tensor')
    tf.summary.image('input', x_images, 10)


    with tf.name_scope("conv2"):
        wd_cov2 = weight_variable([5, 5, 1, 1],name = 'wd_conv2')
        wp_cov2 = weight_variable([1, 1, 1, 10],name = 'wp_conv2')
        b_cov2 = bias_variable([10],name = 'b_conv2')
        h_cov2 = tf.nn.relu(depseparable_conv(x_images, wd_cov2,wp_cov2, name = 'conv2') + b_cov2,name = 'relu_conv2')
        h_pool2 = max_pool_2x2(h_cov2,name = 'maxpooling_conv2')
    with tf.name_scope("conv3"):
        wd_cov3 = weight_variable([5, 5, 10, 1],name = 'wd_conv3')
        wp_cov3 = weight_variable([1, 1, 10, 27],name = 'wp_conv3')
        b_cov3 = bias_variable([27],name = 'b_conv3')
        h_cov3 = tf.nn.relu(depseparable_conv(h_pool2, wd_cov3,wp_cov3, name = 'conv3') + b_cov3,name = 'relu_conv3')
        h_pool3 = max_pool_2x2(h_cov3,name = 'maxpooling_conv3')
    with tf.name_scope("fc1"):
        h_pool2_reshape = tf.reshape(h_pool3, [-1, 7*7*27], name='cnn_fc_convert')
        w_fc1 = weight_variable([7*7*27,496],name = 'w_fc1')
        b_fc1 = bias_variable([496],name = 'b_fc1')
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_reshape, w_fc1) + b_fc1,name= 'relu_fc1')
    with tf.name_scope("dropout"):
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob,name= 'dropout')
    with tf.name_scope("fc2"):
        w_fc2 = weight_variable([496, 10],name = 'w_fc2')
        b_fc2 = bias_variable([10],name = 'b_dc2')
        y_conv = tf.nn.bias_add(tf.matmul(h_fc1_drop, w_fc2), b_fc2,name = 'y')
    with tf.name_scope("loss"):
        loss=loss(labels=y, logits=y_conv)
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    with tf.name_scope("accuracy"):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),name='accuracy')
        tf.summary.scalar('accuracy', accuracy)
    # 啟動session
    sess=tf.Session()
    sess.run(tf.global_variables_initializer()) #初始化graph的引數

    best_validation_acc = 0
    epoch = 0

    print("... training")

    print(tf.get_default_graph().get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

    # summaries合併
    merged = tf.summary.merge_all()
    # 寫到指定的磁碟路徑中
    train_writer = tf.summary.FileWriter('F:/sum/train_mo', sess.graph)
    test_writer = tf.summary.FileWriter(log_dir + '/test')
    while (epoch < n_epochs):
        epoch = epoch + 1
        for minibatch_index in range(n_train_batches):
            iter = (epoch - 1) * n_train_batches + minibatch_index
            summary,acc,_=sess.run([merged, accuracy,train_step],feed_dict={x: train_set_x[minibatch_index * batch_size: (minibatch_index + 1) * batch_size],
                y: train_set_y[minibatch_index * batch_size: (minibatch_index + 1) * batch_size], keep_prob: 0.5})
            print('epoch %i, step %d,minibatch %i / %i, train acc %f' % (epoch, iter, minibatch_index + 1, n_train_batches,acc))
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            train_writer.add_run_metadata(run_metadata, 'step%03d' % iter)
            train_writer.add_summary(summary, iter)
            # train_writer.add_summary(summary, iter)

            if (iter + 1) % 100 == 0:
                valid_acc=0
                for i in range(n_test_batches):
                    acc=sess.run([accuracy],feed_dict={x: test_set_x[i*batch_size:(i+1)*batch_size], y: test_set_y[i*batch_size:(i+1)*batch_size], keep_prob:1})
                    valid_acc =valid_acc+ acc[0]
                valid_acc=valid_acc/n_test_batches
                print('                         validation acc %g' %(valid_acc ))
                if valid_acc>best_validation_acc:
                    best_validation_acc=valid_acc
                    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["accuracy/accuracy"])
                    with tf.gfile.FastGFile(modlefile, mode = 'wb') as f:
                        f.write(output_graph_def.SerializeToString())

    train_writer.close()
    print('Optimization complete.')
    test_acc=0;
    start_time=time.time()
    valid_acc=0
    print(start_time)
    for i in range(n_test_batches):
        valid_acc =valid_acc+ sess.run(accuracy,feed_dict={x: test_set_x[i*batch_size:(i+1)*batch_size], y: test_set_y[i*batch_size:(i+1)*batch_size], keep_prob:1})
    end_time=time.time()
    test_acc=valid_acc/n_test_batches
    print("test accuracy %g" % test_acc)
    print((end_time - start_time)*1000/60)

if __name__ == '__main__':
    evaluate_pictures()

相關論文連結:
rebirth:論文連結
mobile net :論文連結
剪枝:沒找到。


更多案例請關注“思享會Club”公眾號或者關注思享會部落格:http://gkhelp.cn/

在這裡插入圖片描述