1. 程式人生 > >基於Keras的attention實戰

基於Keras的attention實戰

要點:
該教程為基於Kears的Attention實戰,環境配置:
Wn10+CPU i7-6700
Pycharm 2018
python 3.6
numpy 1.14.5
Keras 2.0.2
Matplotlib 2.2.2
強調:各種庫的版本型號一定要配置對,因為Keras以及Tensorflow升級更新比較頻繁,很多函式更新後要麼更換了名字,要麼沒有這個函數了,所以大家務必重視。
相關程式碼我放在了我的程式碼倉庫裡哈,歡迎大家下載,這裡附上地址:基於Kears的Attention實戰
筆者資訊:Next_Legend QQ:1219154092 人工智慧 自然語言處理 影象處理 神經網路
——2018.8.21於天津大學

一、導讀

最近兩年,尤其在今年,注意力機制(Attention)及其變種Attention逐漸熱了起來,在很多頂會Paper中都或多或少的用到了attention,所以小編出於好奇,整理了這篇基於Kears的Attention實戰,本教程僅從程式碼的角度來看Attention。通過一個簡單的例子,探索Attention機制是如何在模型中起到特徵選擇作用的。

二、程式碼實戰(一)

1、匯入相關庫檔案

import numpy as np
from attention_utils import get_activations, get_data

np.random.seed(1337
) # for reproducibility from keras.models import * from keras.layers import Input, Dense, merge import tensorflow as tf

2、資料生成函式

def get_data(n, input_dim, attention_column=1):
    """
    Data generation. x is purely random except that it's first value equals the target y.
    In practice, the network should learn that the target = x[attention_column].
    Therefore, most of its attention should be focused on the value addressed by attention_column.
    :param n: the number of samples to retrieve.
    :param input_dim: the number of dimensions of each element in the series.
    :param attention_column: the column linked to the target. Everything else is purely random.
    :return: x: model inputs, y: model targets
    """
x = np.random.standard_normal(size=(n, input_dim)) y = np.random.randint(low=0, high=2, size=(n, 1)) x[:, attention_column] = y[:, 0] return x, y

3、模型定義函式

將輸入進行一次變換後,計算出Attention權重,將輸入乘上Attention權重,獲得新的特徵。

def build_model():
    inputs = Input(shape=(input_dim,))

    # ATTENTION PART STARTS HERE
    attention_probs = Dense(input_dim, activation='softmax', name='attention_vec')(inputs)
    attention_mul =merge([inputs, attention_probs], output_shape=32, name='attention_mul', mode='mul')
    # ATTENTION PART FINISHES HERE

    attention_mul = Dense(64)(attention_mul)
    output = Dense(1, activation='sigmoid')(attention_mul)
    model = Model(input=[inputs], output=output)
    return model

4、主函式

if __name__ == '__main__':
    N = 10000
    inputs_1, outputs = get_data(N, input_dim)

    m = build_model()
    m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    print(m.summary())

    m.fit([inputs_1], outputs, epochs=20, batch_size=64, validation_split=0.5)

    testing_inputs_1, testing_outputs = get_data(1, input_dim)

    # Attention vector corresponds to the second matrix.
    # The first one is the Inputs output.
    attention_vector = get_activations(m, testing_inputs_1,
                                       print_shape_only=True,
                                       layer_name='attention_vec')[0].flatten()
    print('attention =', attention_vector)

    # plot part.
    import matplotlib.pyplot as plt
    import pandas as pd

    pd.DataFrame(attention_vector, columns=['attention (%)']).plot(kind='bar',
                                                                   title='Attention Mechanism as '
                                                                         'a function of input'
                                                                         ' dimensions.')
    plt.show()

5、執行結果

程式碼中,attention_column為1,也就是說,label只與資料的第1個特徵相關。從執行結果中可以看出,Attention權重成功地獲取了這個資訊。

三、程式碼實戰(二)

1、匯入相關庫檔案

from keras.layers import merge
from keras.layers.core import *
from keras.layers.recurrent import LSTM
from keras.models import *

from attention_utils import get_activations, get_data_recurrent
INPUT_DIM = 2
TIME_STEPS = 20
# if True, the attention vector is shared across the input_dimensions where the attention is applied.
SINGLE_ATTENTION_VECTOR = False
APPLY_ATTENTION_BEFORE_LSTM = False

2、資料生成函式

def attention_3d_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)
    a = Reshape((input_dim, TIME_STEPS))(a) # this line is not useful. It's just to know which dimension is what.
    a = Dense(TIME_STEPS, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
    return output_attention_mul

def model_attention_applied_after_lstm():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIM,))
    lstm_units = 32
    lstm_out = LSTM(lstm_units, return_sequences=True)(inputs)
    attention_mul = attention_3d_block(lstm_out)
    attention_mul = Flatten()(attention_mul)
    output = Dense(1, activation='sigmoid')(attention_mul)
    model = Model(input=[inputs], output=output)
    return model

def model_attention_applied_before_lstm():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIM,))
    attention_mul = attention_3d_block(inputs)
    lstm_units = 32
    attention_mul = LSTM(lstm_units, return_sequences=False)(attention_mul)
    output = Dense(1, activation='sigmoid')(attention_mul)
    model = Model(input=[inputs], output=output)
    return model

4、主函式

 if __name__ == '__main__':

    N = 300000
    # N = 300 -> too few = no training
    inputs_1, outputs = get_data_recurrent(N, TIME_STEPS, INPUT_DIM)

    if APPLY_ATTENTION_BEFORE_LSTM:
        m = model_attention_applied_before_lstm()
    else:
        m = model_attention_applied_after_lstm()

    m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    print(m.summary())

    m.fit([inputs_1], outputs, epochs=1, batch_size=64, validation_split=0.1)

    attention_vectors = []
    for i in range(300):
        testing_inputs_1, testing_outputs = get_data_recurrent(1, TIME_STEPS, INPUT_DIM)
        attention_vector = np.mean(get_activations(m,
                                                   testing_inputs_1,
                                                   print_shape_only=True,
                                                   layer_name='attention_vec')[0], axis=2).squeeze()
        print('attention =', attention_vector)
        assert (np.sum(attention_vector) - 1.0) < 1e-5
        attention_vectors.append(attention_vector)

    attention_vector_final = np.mean(np.array(attention_vectors), axis=0)
    # plot part.
    import matplotlib.pyplot as plt
    import pandas as pd

    pd.DataFrame(attention_vector_final, columns=['attention (%)']).plot(kind='bar',
                                                                         title='Attention Mechanism as '
                                                                               'a function of input'
                                                                               ' dimensions.')
    plt.show()

相關推薦

AngularJS進階(三十九)基於專案實戰解析ng啟動載入過程

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

AngularJS進階 三十九 基於專案實戰解析ng啟動載入過程

基於專案實戰解析ng啟動載入過程 前言       在AngularJS專案開發過程中,自己將遇到的問題進行了整理。回過頭來總結一下angular的啟動過程。       下面以實際專案為例進行簡要講解。 1.載入ng庫 &

基於OpenLayers實戰地理資訊系統(離線地圖,通過基站轉經緯度,Quartz深入,軌跡實戰

我這裡有套課程想和大家分享,需要的朋友可以加我qq和我聯絡。QQ2059055336. 一、本課程是怎麼樣的一門課程(全面介紹)    1.1、課程的背景        OpenLayers是一個用於開發WebGIS客戶端的JavaScript包。        地理地圖眾多方案實現的對比:        

基於OpenLayers實戰地理資訊系統視訊

看到大家都在找尋關於基於Openlayers實戰地理資訊系統的視訊,小編在此共享,但是由於可能會涉及版權的問題,請勿廣泛傳播,謝謝!我將視訊上傳到了360雲盤上,需要的朋友請留言...  第一講:概述      第二講:龐雜的GIS體系概覽      第三講:專案快速實戰

電子書 flaskweb開發:基於Python的Web應用開發實戰.pdf

商業 機器 免費 影評 而且 視頻軟件 python程序 規範 初級 作為PythonWeb開發的微框架,Flask獨樹一幟。它不會強迫開發者遵循預置的開發規範,為開發者提供了自由度和創意空間。   《圖靈程序設計叢書·Flask Web開發:基於Python的Web應用開

Android實戰簡易教程-第二十六槍(基於ViewPager實現微信頁面切換效果)

stat addview data android tid des viewpage 聊天 == 1.頭部布局文件top.xml:<?xml version="1.0" encoding="utf-8"?> <LinearLayout xmlns:and

【推薦系統實戰】:C++實現基於用戶的協同過濾(UserCollaborativeFilter)

color style popu ted std 相似度 abi ear result 好早的時候就打算寫這篇文章,可是還是參加阿裏大數據競賽的第一季三月份的時候實驗就完畢了。硬生生是拖到了十一假期。自己也是醉了。。。找工作不是非常順利,希望寫點東西回想一下知識。然後再

Linux實戰第五篇:RHEL7.3下Nginx虛擬主機配置實戰基於別名)

虛擬主機 nginx個人筆記分享(在線閱讀):http://note.youdao.com/noteshare?id=05daf711c28922e50792c4b09cf63c58PDF版本下載http://down.51cto.com/data/2323313本文出自 “人才雞雞” 博客,請務必保留此出處

機器學習之分類問題實戰(基於UCI Bank Marketing Dataset)

表示 般的 機構 文件 cnblogs opened csv文件 mas htm 導讀: 分類問題是機器學習應用中的常見問題,而二分類問題是其中的典型,例如垃圾郵件的識別。本文基於UCI機器學習數據庫中的銀行營銷數據集,從對數據集進行探索,數據預處理和特征工程,到學習

selenium自動化實戰-基於python語言(二: 編寫腳本)

獲取 pat 打開 border 命令 需要 框架 attribute 一個 上一篇文章說到顯示等待和隱式等待語句,我們繼續學習下面的命令方法。 8. 定位一組元素 這裏書上是自己寫了一個頁面代碼,通過訪問本地這個頁面來舉例。但我覺得找一個現有的頁面自己琢磨更有意思,而且

下載基於大數據技術推薦系統實戰教程(Spark ML Spark Streaming Kafka Hadoop Mahout Flume Sqoop Redis)

大數據技術推薦系統 推薦系統實戰 地址:http://pan.baidu.com/s/1c2tOtwc 密碼:yn2r82課高清完整版,轉一播放碼。互聯網行業是大數據應用最前沿的陣地,目前主流的大數據技術,包括 hadoop,spark等,全部來自於一線互聯網公司。從應用角度講,大數據在互聯網領域主

基於ASP.NET WebAPI OWIN實現Self-Host項目實戰

hosting 知識 工作 develop plist 簡單 eba 直接 sock 引用 寄宿ASP.NET Web API 不一定需要IIS 的支持,我們可以采用Self Host 的方式使用任意類型的應用程序(控制臺、Windows Forms 應用、WPF 應

nginx基於域名的虛擬主機配置實戰

linux背景: 在www虛擬主機站點基礎上新增一個bbs虛擬主機站點。1 備份配置文件[[email protected]/* */ conf]# pwd /application/nginx/conf [[email protected]/* */ conf]#

基於centos7.3安裝部署jewel版本ceph集群實戰演練

集群 ceph 一、環境準備安裝centos7.3虛擬機三臺由於官網源與網盤下載速度都非常的慢,所以給大家提供了國內的搜狐鏡像源:http://mirrors.sohu.com/centos/7.3.1611/isos/x86_64/CentOS-7-x86_64-DVD-1611.iso在三臺裝好的

企業實戰-實現基於LVS負載均衡集群的電商網站架構

企業實戰 lvs lnmp 實現LVS-DR工作模式:環境準備:一臺centos系統做DR、兩臺實現過基於LNMP的電子商務網站機器名稱IP配置服務角色備註lvs-serverVIP:172.17.252.110DIP:172.17.250.223負載均衡器開啟路由功能(VIP橋接)rs01RIP

企業實戰(4)-實現基於Haproxy負載均衡集群的電子商務網站架構

haproxy keepalived 企業實戰:逐步實現企業各種情景下的需求企業情景四:隨著公司業務的發展,公司負載均衡服務已經實現四層負載均衡,但業務的復雜程度提升,公司要求把mobile手機站點作為單獨的服務提供,不在和pc站點一起提供服務,此時需要做7層規則負載均衡,運維總監要求,能否用一種服務

iKcamp|基於Koa2搭建Node.js實戰(含視頻)? 代碼分層

如果 讓我 span module input 數據 listen else nod 視頻地址:https://www.cctalk.com/v/15114923889408 文章 在前面幾節中,我們已經實現了項目中的幾個常見操作:啟動服務器、路由中間件、Get 和 Po

Linux實戰第八篇:CentOS7.3下Nginx虛擬主機配置實戰基於端口)

基於 sub 主機配置 centos7.3 entos ada .com 版本 fad 個人筆記分享(在線閱讀): http://note.youdao.com/noteshare?id=9a8b56ec54800ccf197eb6c23de55a85&sub=2E3048

PK2153-BAT大牛親授 基於ElasticSearch的搜房網實戰

height ear 希望 data- arch http package support nta PK2153-BAT大牛親授 基於ElasticSearch的搜房網實戰 新年伊始,學習要趁早,點滴記錄,學習就是進步! 隨筆背景:在很多時候,很多入門不久的朋友都會問

基於Storm構建實時熱力分布項目實戰

解析 cat django ron 優化 Redis分布式 java並發編程 body code 詳情請交流 QQ 709639943 01、基於Storm構建實時熱力分布項目實戰 02、以慕課網日誌分析為例 進入大數據 Spark SQL 的世界 03、Spri