1. 程式人生 > >官網例項詳解4.6(cifar10_cnn_capsule.py)-keras學習筆記四

官網例項詳解4.6(cifar10_cnn_capsule.py)-keras學習筆記四

基於CIFAR10(小批量圖片)資料集訓練簡單的膠囊(組神經元)深度卷積神經網路

程式碼註釋

"""Train a simple CNN-Capsule Network on the CIFAR10 small images dataset.
基於CIFAR10(小批量圖片)資料集訓練簡單的膠囊(組神經元)深度卷積神經網路
Without Data Augmentation:
It gets to 75% validation accuracy in 10 epochs,
and 79% after 15 epochs, and overfitting after 20 epochs.
不擴充套件資料集,10個週期後達到7%%的準確率,15個週期後達到79%的準確率,20個週期後過擬合。

With Data Augmentation:
It gets to 75% validation accuracy in 10 epochs,
and 79% after 15 epochs, and 83% after 30 epcohs.
In my test, highest validation accuracy is 83.79% after 50 epcohs.
不擴充套件資料集,10個週期後達到7%%的準確率,15個週期後達到79%的準確率,30個週期達到83%的準確率。
測試中,最好成績是50個週期達到83.79%的準確率

This is a fast Implement, just 20s/epcoh with a gtx 1070 gpu.
快速實現,基於一個gtx 1070 GPU(影象處理器,顯示卡)每個週期執行需要20秒
"""

from __future__ import print_function
from keras import backend as K
from keras.engine.topology import Layer
from keras import activations
from keras import utils
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import *
from keras.preprocessing.image import ImageDataGenerator


# the squashing function.
# 壓平函式
# we use 0.5 in stead of 1 in hinton's paper.
# 在hinton的論文中,使用0.5代替1
# if 1, the norm of vector will be zoomed out.
# 如果是1,向量範數將縮小
# if 0.5, the norm will be zoomed in while original norm is less than 0.5
# 如果是1,原範數小於0.5,向量範數將放大
# and be zoomed out while original norm is greater than 0.5.
# 原範數大於0.5,向量範數將縮小
def squash(x, axis=-1):
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x


# define our own softmax function instead of K.softmax
# because K.softmax can not specify axis.
# 自定義softmax函式,替換K.softmax函式,因為K.softmax函式不能指定軸
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)


# define the margin loss like hinge loss
# 定義利潤邊緣損失,如hinge損失
def margin_loss(y_true, y_pred):
    lamb, margin = 0.5, 0.1
    return y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
        1 - y_true) * K.square(K.relu(y_pred - margin))


class Capsule(Layer):
    """A Capsule Implement with Pure Keras
    基於純Keras的膠囊(組神經元)實現
    There are two vesions of Capsule.
    以下是2個版本的膠囊(組神經元)
    One is like dense layer (for the fixed-shape input),
    一種類似與全連線層(對於固定形狀輸入),
    and the other is like timedistributed dense (for various length input).
    一種類似時間分佈的全連線層(對於變長輸入),

    The input shape of Capsule must be (batch_size,
                                        input_num_capsule,
                                        input_dim_capsule
                                       )
    膠囊(組神經元)輸入(資料)形狀為:

    and the output shape is (batch_size,
                             num_capsule,
                             dim_capsule
                            )
     輸出(資料)形狀為:

    Capsule Implement is from https://github.com/bojone/Capsule/
    膠囊(組神經元)實現見:https://github.com/bojone/Capsule/
    Capsule Paper: https://arxiv.org/abs/1710.09829
    膠囊(組神經元)論文:https://arxiv.org/abs/1710.09829
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 share_weights=True,
                 activation='squash',
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(1, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)
        else:
            input_num_capsule = input_shape[-2]
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(input_num_capsule, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)

    def call(self, inputs):
        """Following the routing algorithm from Hinton's paper,
        根據Hinton論文的路由演算法,
        but replace b = b + <u,v> with b = <u,v>.
        但是 用b = <u,v>替換b = b + <u,v>

        This change can improve the feature representation of Capsule.
        這種改變可以改善膠囊(組神經元)的特徵表示。

        However, you can replace
        而且,可以把替換
            b = K.batch_dot(outputs, hat_inputs, [2, 3])
        with
        使用
            b += K.batch_dot(outputs, hat_inputs, [2, 3])
        to realize a standard routing.
        實現標準路由。
        """

        if self.share_weights:
            hat_inputs = K.conv1d(inputs, self.kernel)
        else:
            hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1])

        batch_size = K.shape(inputs)[0]
        input_num_capsule = K.shape(inputs)[1]
        hat_inputs = K.reshape(hat_inputs,
                               (batch_size, input_num_capsule,
                                self.num_capsule, self.dim_capsule))
        hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))

        b = K.zeros_like(hat_inputs[:, :, :, 0])
        for i in range(self.routings):
            c = softmax(b, 1)
            if K.backend() == 'theano':
                o = K.sum(o, axis=1)
            o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
            if i < self.routings - 1:
                b = K.batch_dot(o, hat_inputs, [2, 3])
                if K.backend() == 'theano':
                    o = K.sum(o, axis=1)

        return o

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)


batch_size = 128
num_classes = 10
epochs = 100
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)

# A common Conv2D model
# 2維卷積模型
input_image = Input(shape=(None, None, 3))
x = Conv2D(64, (3, 3), activation='relu')(input_image)
x = Conv2D(64, (3, 3), activation='relu')(x)
x = AveragePooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu')(x)
x = Conv2D(128, (3, 3), activation='relu')(x)


"""now we reshape it as (batch_size, input_num_capsule, input_dim_capsule)
then connect a Capsule layer.
重塑(資料)形狀為 (batch_size, input_num_capsule, input_dim_capsule)
然後,連結一個膠囊(組神經元)層

the output of final model is the lengths of 10 Capsule, whose dim=16.
最終模型輸出為16維的10個膠囊(組神經元)

the length of Capsule is the proba,
膠囊(組神經元)的長度是proba
so the problem becomes a 10 two-classification problem.
因此問題成為10個兩分類問題。
"""

x = Reshape((-1, 128))(x)
capsule = Capsule(10, 16, 3, True)(x)
output = Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)))(capsule)
model = Model(inputs=input_image, outputs=output)

# we use a margin loss
# 使用邊緣損失(函式)
model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])
model.summary()

# we can compare the performance with or without data augmentation
# 比較效果,有和沒有資料集(擴大)
data_augmentation = True

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
        shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    #  預處理和實時資料擴大(通過平移、翻轉等影象變換增加影象樣本數量)。
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset  # 基於資料集,使輸入資料平均值為0
        samplewise_center=False,  # set each sample mean to 0 # 使樣本平均值為0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset # 通過資料標準化劃分輸入資料
        samplewise_std_normalization=False,  # divide each input by its std # 通過標準化劃分輸入資料
        zca_whitening=False,  # apply ZCA(Zero-phase Component Analysis) whitening # 對輸入資料施加ZCA白化
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180) # 旋轉影象0-180度
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width) # 水平平移影象(基於影象寬度比例)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height) # 垂直平移影象(基於影象高度比例)
        horizontal_flip=True,  # randomly flip images # 水平翻轉影象
        vertical_flip=False)  # randomly flip images # 垂直翻轉影象

    # Compute quantities required for feature-wise normalization
    # 特徵歸一化的計算量
    # (std, mean, and principal components if ZCA whitening is applied).
    # (如果ZCA白化(一種降維方法)會使用標準化、均值和主成分方法)
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    # 使用datagen.flow()生成的批次資料在模型訓練
    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        workers=4)

程式碼執行

Keras詳細介紹

例項下載

完整專案下載

方便沒積分童鞋,請加企鵝452205574,共享資料夾。

包括:程式碼、資料集合(圖片)、已生成model、安裝庫檔案等。

 

相關推薦

例項4.6cifar10_cnn_capsule.py-keras學習筆記

基於CIFAR10(小批量圖片)資料集訓練簡單的膠囊(組神經元)深度卷積神經網路程式碼註釋"""Train a simple CNN-Capsule Network on the CIFAR10 small images dataset. 基於CIFAR10(小批量圖片)資料

例項4.15imdb_cnn_lstm.py-keras學習筆記

程式碼註釋'''Train a recurrent convolutional network on the IMDB sentiment classification task. 為IMDB(資料集)情感分類任務訓練迴圈卷積網路 Gets to 0.8498 test ac

例項4.30mnist_siamese.py-keras學習筆記

基於MNIST資料集上從一對數字中訓練一個 Siamese MLP。Siamese ,連體的,相似的。Siamese Net,孿生網路、連體網路MLP,多層感知機,(多個隱藏層的全連線的神經網路)詳解程式碼註釋'''Trains a Siamese MLP on pairs

keras\preprocessing目錄檔案5.2sequence.py-keras學習筆記

功能:用於預處理序列(例如一篇文章,句子)資料的實用工具。 keras-master\keras\preprocessing\sequence.py 建立詞向量嵌入層,把輸入文字轉為可以進一步處理的資料格式(例如,矩陣) 程式碼註釋 # -*- coding:

免費的HTML5連載來了《HTML5網頁開發例項》連載

    最近新浪、百度、騰訊、京東、大眾點評、淘寶等流行的網站都加大了招聘HTML5的力度,HTML5開發人員成了搶手貨,本次連載的是由大眾點評前端工程師和一淘網前端工程師基情奉獻的《HTML5網頁開發例項詳解》,喜歡本書的人可以關注連載,後續會更精彩! 2.1.1  最

免費的HTML5連載來了《HTML5網頁開發例項》連載媒體查詢

響應式設計的另一個重要技術手段是媒體查詢。如果只是簡單的設計一個流式佈局系統,那麼可以保證每個網格按比例的放大和縮小,但有可能會使得在小螢幕下(如手機裝置)網格太小而嚴重影響閱讀,這樣的設計稱不上響應式設計。媒體查詢可以來解決這一問題。媒體查詢可以為特定的瀏覽器和裝置提供特

機器學習中的概率模型和概率密度估計方法及VAE生成式模型之二作者簡介

AR aca rtu href beijing cert school start ica Brief Introduction of the AuthorChief Architect at 2Wave Technology Inc. (a startup company

Hibernate--Query物件 分頁

package com.itheima.b_api; import java.util.Arrays; import java.util.List; import org.hibernate.Criteria; import org.hibernate.Query; import org.h

分散式系統--基礎知識執行緒

                     分散式系統詳解--基礎知識(執行緒) 一、導讀           前面跟大家講了一下&n

分散式系統--架構簡介微服務

                       分散式系統詳解--架構簡介(微服務)           前面的一個集合我們

Java_51_組合_內部類_字串String類_equals和==的區別

組合 使用組合,可以獲得更多的靈活性,你甚至可以在執行的時候才決定哪幾個類組合在一起。 使用繼承,他是一種高度耦合,派生類和基類被緊緊的綁在一起,靈活性大大降低,而且,濫用繼承,也會使繼承樹變得又大又複雜,很難理解和維護。 如果是is-a關係,用繼承。【是一個[物件]】 如果是h

遞迴基礎篇———函式棧、階乘、Fibonacci數列

一、遞迴的基本概念 遞迴函式:在定義的時候,自己呼叫了自己的函式。 注意:遞迴函式定義的時候一定要明確結束這個函式的條件! 二、函式棧 棧:一種資料結構,它僅允許棧頂進,棧頂出,先進後出,後進先出。我們可以簡單的理解為棧就是一個杯子,這個杯子裡面有很多隔層,每一層都可以放東西,第一個放入的東西就在杯子

Elasticsearch深入-知識圖譜每週更新

1、題記 Elasticsearch技術已經燃爆到飛的感覺。 為了方便訂閱Elasticsearch深入詳解的博友們第一時間獲取最新經驗分享,和大家一起成長,特將本專欄內容製作為Elasticsearch深入詳解知識圖譜。 並承諾【銘毅天下】微信公眾號每

FreeRTOS之全配置項、裁剪FreeRTOSConfig.h

簡介   首先,我們需要明確一個問題,FreeRTOSConfig.h是一個使用者級別的檔案,不屬於核心檔案。每個使用者可以有不同的FreeRTOSConfig.h。   FreeRTOS作為一個可高度配置的實時核心,其絕大多數配置選項都體現在FreeRTO

《Python程式設計從入門到實踐》學習筆記-專案篇下載資料

上兩篇文章分別介紹了《Python程式設計從入門到實踐》的語法篇和專案篇(資料視覺化),這篇文 #專案二 下載資料 #訪問並可視化csv和json這兩種常見格式儲存的資料 #csv #提取並讀取資料 import csv filename='filename.c

PyQt5基本控制元件之QDialog十二

QDialog 前言 為了更好的實現人機互動,比如window和linux等系統均會提供一系列的標準對話方塊來完成特定場景下的功能,比如選擇字號大小。字型顏色等,在PyQt5中定義了一系列的標準對話方塊類,讓使用者能夠方便快捷地通過各個類完成字號大

資料結構之折半插入排序圖文及程式碼C++實現

問題:對待排序的陣列r[1..n]中的元素進行直接插入排序,得到一個有序的(從小到大)的陣列r[1..n]。演算法思想:1、設待排序的記錄存放在陣列r[1..n]中,r[1]是一個有序序列。2、迴圈n-1次,每次使用折半查詢法,查詢r[i](i=2,..,n)在已排好的序列r

java Comparable 和Comparator及 區別附程式碼

java中,對集合物件或者陣列物件排序,有兩種實現方式。 即:(1)物件實現Comparable 介面         (2)定義比較器,實現Comparator介面。 下面會簡要介紹這兩種方法的區別,並附上實現程式碼,供大家參考。 Comparable介紹 Compar

git rebase變基—— Git 學習筆記 19

git rebase(變基) 認識 git rebase 假設你現在基於遠端分支"origin",建立一個叫"mywork"的分支。 $ git checkout -b mywork origin 現在我們在 mywork 分支做一些修改,生成兩個提交(C5和C6