閱讀理解任務中的Attention-over-Attention神經網路模型原理及實現
本文是“Attention-over-Attention Neural Networks for Reading Comprehension”的閱讀筆記。這篇論文所處理的任務是閱讀理解裡面的完形填空問題。其模型架構是建立在“Text Understanding with the Attention Sum Reader Network”這篇論文至上。該論文首先提出了將Attention用於完形填空任務,本篇論文則在其基礎之上添加了一個額外的Attention層,可以免去啟發式的演算法和一些超引數調整等問題。我們接下來結合兩篇論文進行介紹。
資料集
首先介紹一下資料集,目前常用的大規模資料集主要包括
CBT資料集是從兒童讀物中獲取,由於其沒有摘要,所以採用,前面連續的21句話作為Document,第22句話作為Query等方式構建。然後其還根據答案的詞性分為四個子集:命名實體(NE)、公共名詞(CN)、動詞、介詞。但是由於後面兩種答案與文字並沒有十分緊密的關係,比如人們常常不需要讀文字就可以判斷出介詞填空等,所以常用的是前面兩種。
最終每條資料被構建為如下三元組:
<D, Q, A>
模型
首先我們可以看一下“Text Understanding with the Attention Sum Reader Network”這篇論文所提出的模型架構,如下圖所示:
從上圖可以看出,模型首先通過嵌入矩陣V得到Document和Query中每個單詞的詞向量e(w)。接下來分別使用兩個encoder網路獲得文字中每個單詞的向量contextual embedding
和Query的表示向量。這裡的encoder使用的是雙向GRU迴圈神經網路。然後使用點積的方式將Query向量和每一個單詞的contextual embedding
相乘,得到的結果可以視為每個單詞對於該查詢的權重,亦可理解為attention。最後使用softmax函式將權重轉化為歸一化的概率,將概率最大的結果視為該query的答案。
接下來我們再看一下本文提出的模型架構,如下圖所示:
模型的前半部分與上面完全一樣,差別在於本文提出了一種“Attention over Attention”的機制,也就是獲得Document和Query的向量之後,不將Query的所有單詞合為一個向量,而是直接以矩陣的形式與Document矩陣相乘,然後分別從行和列兩個維度對相乘後的矩陣進行softmax操作得到document的注意力矩陣和query的注意力矩陣。在對query矩陣每一列的元素進行求和當做權重,對document的attention矩陣進行點積即可。
模型的程式碼實現
其實模型使用tensorflow實現的時候十分簡單,直接呼叫tf.contrib.rnn下面的GRUCell即可,難點在於資料的處理和讀取操作。這裡我們可以參考github上面的兩個實現方案:OlavHN,marshmelloX。第一個使用了TF內建的讀取資料的API,程式碼十分簡潔明瞭,我有時間需要研究一下其實現原理整理出一份部落格來。第二個使用的是傳統的資料處理方式,也可以參考,此外在github上面應該可以找到CNN等資料集的處理程式碼結合著一起學習。但是上面兩個程式碼實現都用的是比較老的版本,如果用的是tf1.0及以上的版本可能會出現一些函式的不相容問題,我參照第一份程式碼實現進行了一定的修改,可以再1。0的版本上執行。程式碼後續會放到我的github上面,歡迎檢視。在伺服器上跑需要四五天的樣子,現在還沒跑完==下圖是結果截圖:
四個引數分別代表步數,錯誤率,準確度,時間。可以看到準確度不是十分穩定,但是基本上達到了論文裡面提到的效果。可以看一下我修改過之後的model的程式碼,特別是模型構建部分還是比較簡單的,只用了幾行命令就實現了:
import os
import time
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import sparse_ops
from util import softmax, orthogonal_initializer
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('vocab_size', 119662, 'Vocabulary size')
flags.DEFINE_integer('embedding_size', 384, 'Embedding dimension')
flags.DEFINE_integer('hidden_size', 256, 'Hidden units')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('epochs', 2, 'Number of epochs to train/test')
flags.DEFINE_boolean('training', True, 'Training or testing a model')
flags.DEFINE_string('name', 'lc_model', 'Model name (used for statistics and model path')
flags.DEFINE_float('dropout_keep_prob', 0.9, 'Keep prob for embedding dropout')
flags.DEFINE_float('l2_reg', 0.0001, 'l2 regularization for embeddings')
model_path = 'models/' + FLAGS.name
if not os.path.exists(model_path):
os.makedirs(model_path)
def read_records(index=0):
train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs)
queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(queue)
features = tf.parse_single_example(
serialized_example,
features={
'document': tf.VarLenFeature(tf.int64),
'query': tf.VarLenFeature(tf.int64),
'answer': tf.FixedLenFeature([], tf.int64)
})
document = sparse_ops.serialize_sparse(features['document'])
query = sparse_ops.serialize_sparse(features['query'])
answer = features['answer']
document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
[document, query, answer], batch_size=FLAGS.batch_size,
capacity=2000,
min_after_dequeue=1000)
sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64)
document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.dense_shape, 1)
query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.dense_shape, 1)
return document_batch, document_weights, query_batch, query_weights, answer_batch
def inference(documents, doc_mask, query, query_mask):
embedding = tf.get_variable('embedding',
[FLAGS.vocab_size, FLAGS.embedding_size],
initializer=tf.random_uniform_initializer(minval=-0.05, maxval=0.05))
regularizer = tf.nn.l2_loss(embedding)
doc_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, documents), FLAGS.dropout_keep_prob)
doc_emb.set_shape([None, None, FLAGS.embedding_size])
query_emb = tf.nn.dropout(tf.nn.embedding_lookup(embedding, query), FLAGS.dropout_keep_prob)
query_emb.set_shape([None, None, FLAGS.embedding_size])
with tf.variable_scope('document', initializer=orthogonal_initializer()):
fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)
h, _ = tf.nn.bidirectional_dynamic_rnn(
fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
#h_doc = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
h_doc = tf.concat(h, 2)
with tf.variable_scope('query', initializer=orthogonal_initializer()):
fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
query_len = tf.reduce_sum(query_mask, reduction_indices=1)
h, _ = tf.nn.bidirectional_dynamic_rnn(
fwd_cell, back_cell, query_emb, sequence_length=tf.to_int64(query_len), dtype=tf.float32)
#h_query = tf.nn.dropout(tf.concat(2, h), FLAGS.dropout_keep_prob)
h_query = tf.concat(h, 2)
M = tf.matmul(h_doc, h_query, adjoint_b=True)
M_mask = tf.to_float(tf.matmul(tf.expand_dims(doc_mask, -1), tf.expand_dims(query_mask, 1)))
alpha = softmax(M, 1, M_mask)
beta = softmax(M, 2, M_mask)
#query_importance = tf.expand_dims(tf.reduce_mean(beta, reduction_indices=1), -1)
query_importance = tf.expand_dims(tf.reduce_sum(beta, 1) / tf.to_float(tf.expand_dims(doc_len, -1)), -1)
s = tf.squeeze(tf.matmul(alpha, query_importance), [2])
unpacked_s = zip(tf.unstack(s, FLAGS.batch_size), tf.unstack(documents, FLAGS.batch_size))
y_hat = tf.stack([tf.unsorted_segment_sum(attentions, sentence_ids, FLAGS.vocab_size) for (attentions, sentence_ids) in unpacked_s])
return y_hat, regularizer
def train(y_hat, regularizer, document, doc_weight, answer):
# Trick while we wait for tf.gather_nd - https://github.com/tensorflow/tensorflow/issues/206
# This unfortunately causes us to expand a sparse tensor into the full vocabulary
index = tf.range(0, FLAGS.batch_size) * FLAGS.vocab_size + tf.to_int32(answer)
flat = tf.reshape(y_hat, [-1])
relevant = tf.gather(flat, index)
# mean cause reg is independent of batch size
loss = -tf.reduce_mean(tf.log(relevant)) + FLAGS.l2_reg * regularizer
global_step = tf.Variable(0, name="global_step", trainable=False)
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(y_hat, 1), answer)))
optimizer = tf.train.AdamOptimizer()
grads_and_vars = optimizer.compute_gradients(loss)
capped_grads_and_vars = [(tf.clip_by_value(grad, -5, 5), var) for (grad, var) in grads_and_vars]
train_op = optimizer.apply_gradients(capped_grads_and_vars, global_step=global_step)
tf.summary.scalar('loss', loss)
tf.summary.scalar('accuracy', accuracy)
return loss, train_op, global_step, accuracy
def main():
dataset = tf.placeholder_with_default(0, [])
document_batch, document_weights, query_batch, query_weights, answer_batch = read_records(dataset)
y_hat, reg = inference(document_batch, document_weights, query_batch, query_weights)
loss, train_op, global_step, accuracy = train(y_hat, reg, document_batch, document_weights, answer_batch)
summary_op = tf.summary.merge_all()
with tf.Session() as sess:
summary_writer = tf.summary.FileWriter(model_path, sess.graph)
saver_variables = tf.all_variables()
if not FLAGS.training:
saver_variables = filter(lambda var: var.name != 'input_producer/limit_epochs/epochs:0', saver_variables)
saver_variables = filter(lambda var: var.name != 'smooth_acc:0', saver_variables)
saver_variables = filter(lambda var: var.name != 'avg_acc:0', saver_variables)
saver = tf.train.Saver(saver_variables)
sess.run([
tf.initialize_all_variables(),
tf.initialize_local_variables()])
model = tf.train.latest_checkpoint(model_path)
if model:
print('Restoring ' + model)
saver.restore(sess, model)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
start_time = time.time()
accumulated_accuracy = 0
try:
if FLAGS.training:
while not coord.should_stop():
loss_t, _, step, acc = sess.run([loss, train_op, global_step, accuracy], feed_dict={dataset: 0})
elapsed_time, start_time = time.time() - start_time, time.time()
print(step, loss_t, acc, elapsed_time)
if step % 100 == 0:
summary_str = sess.run(summary_op)
summary_writer.add_summary(summary_str, step)
if step % 1000 == 0:
saver.save(sess, model_path + '/aoa', global_step=step)
else:
step = 0
while not coord.should_stop():
acc = sess.run(accuracy, feed_dict={dataset: 2})
step += 1
accumulated_accuracy += (acc - accumulated_accuracy) / step
elapsed_time, start_time = time.time() - start_time, time.time()
print(accumulated_accuracy, acc, elapsed_time)
except tf.errors.OutOfRangeError:
print('Done!')
finally:
coord.request_stop()
coord.join(threads)
'''
import pickle
with open('counter.pickle', 'r') as f:
counter = pickle.load(f)
word, _ = zip(*counter.most_common())
'''
if __name__ == "__main__":
main()
相關推薦
閱讀理解任務中的Attention-over-Attention神經網路模型原理及實現
本文是“Attention-over-Attention Neural Networks for Reading Comprehension”的閱讀筆記。這篇論文所處理的任務是閱讀理解裡面的完形填空問題。其模型架構是建立在“Text Understanding
MLP多層感知機(人工神經網路)原理及程式碼實現
一、多層感知機(MLP)原理簡介多層感知機(MLP,Multilayer Perceptron)也叫人工神經網路(ANN,Artificial Neural Network),除了輸入輸出層,它中間可以有多個隱層,最簡單的MLP只含一個隱層,即三層的結構,如下圖:從上圖可以看
深入理解Java中的底層阻塞原理及實現
更多 安全 posix pla static events time() 方便 原理 談到阻塞,相信大家都不會陌生了。阻塞的應用場景真的多得不要不要的,比如 生產-消費模式,限流統計等等。什麽 ArrayBlockingQueue、 LinkedBlockingQueue、
《TensorFlow實戰》中AlexNet卷積神經網路的訓練中
TensorFlow實戰中AlexNet卷積神經網路的訓練 01 出錯 TypeError: as_default() missing 1 required positional argument: 'self' 經過百度、谷歌的雙重查詢,沒找到就具體原因。後面去TensorFlow官方文件中發現,tf
介紹一個快速確定神經網路模型中各層矩陣維度的方法
深度學習在設計神經網路結構時需要指定每個隱藏層權重矩陣、偏置項的維度,有的時候容易搞混淆,導致執行時產生維度不對的錯誤,特別是對於隱藏層既深又寬的網路,更容易出錯。下面以一個例項來說明怎麼簡單快速確定每一層各個矩陣的維度。 假設需要擬合的函式為:y=f(x)=WX+b。 損失函式:J(W,b) 其中 X
影象風格轉移中神經網路模型的使用
Neural Style 中使用其他的神經網路模型 影象風格轉移中VGG-19和NIN模型都可以用,還有其他的模型也可以使用。這裡就把它們列出來,說明它們能做什麼,什麼時候使用它們,怎麼樣使用,它們的效能怎麼樣,去哪下載。 神經網路模型 神經網路模型比
【論文閱讀】利用深度自編碼器神經網路預測藥物相似度
論文來源: Prediction of Drug-Likeness Using Deep Autoencoder Neural Networks 摘要 由於各種各樣的原因,大多數候選藥物最終不能成為上市藥物。建立可靠的預測候選化合物是否具有藥物相似度對提高藥物的發現
深入理解Oracle中的shared pool與library cache元件及相關等待事件
傳統的’library cache pin’在10.2.0.2之後預設被取代, 此處PIN被Mutex及其ref count取代。 當程序執行遊標語句時或者需要PIN,或者需要hard parse一個子遊標heap。在版本10.2.0.1中, 使用mutex部分程式碼替代PIN的功能預設是不啟用的,
深度學習【1】ubuntu中利用h5py儲存訓練好的keras 神經網路模型
利用h5py儲存的模型所佔的空間非常小。在利用h5py儲存keras訓練好的模型之前需要先安裝h5py,具體的安裝過程可參考我關於h5py安裝的博文:http://blog.csdn.net/linmingan/article/details/50736300 利用h5py
【Unity3D Shader程式設計】之十 深入理解Unity5中的Standard Shader(二)&螢幕油畫特效的實現
本系列文章由出品,轉載請註明出處。 本文工程使用的Unity3D版本:5.2.1 概要:本文講解了Unity中著色器編譯多樣化的思路,並對Standard Shader中正向基礎渲染通道的原始碼進行了分析,以及對螢幕油畫特效進行了實現。眾所周知,Unity官方文件對Shade
卷積神經網路模型概念與理解
博文的編寫,主要參考網上資料並結合個人見解僅供學習交流使用,如有侵權請聯絡博主刪除,原創文章轉載請註明出處。 本文簡單介紹 AlexNet、GoogLeNet、VGG、ResNet四個模型 AlexNet AlexNet (2012 Ima
怎麼理解tf.nn,conv2d(卷積神經網路)的影象通道數
其實"影象通道數"就是影象的xx,呵呵..,其實這裡的影象通道數其實是叫做“影象的色彩通道數”,還是來看例子理解吧! 如每一張圖片的大小為28*28*1,則表示圖片的大小為28*28的畫素,*1表示是黑白的,即為一個色彩通道 同理,28*28*3,則表示圖片的大小為28*2
Python中的卷積神經網路
本教程將 主要面向程式碼, 旨在幫助您 深入學習和卷積神經網路。由於這個意圖,我 不會花很多時間討論啟用功能,池層或密集/完全連線的層 - 將來會有 很多教程在PyImageSearch部落格上將覆蓋 每個層型別/概念 在很多細節。 再次,本
深度學習在資料中的應用---迴圈神經網路
第一次接觸迴圈神經網路。。。。。。 迴圈神經網路是針對序列資料建模的深度學習模型,在自然語言處理、語音識別等領域中應用較為廣泛。(以前接觸的基本上都是影象方面的東西) 先總結一下文章的內容,下次主要介紹一下迴圈神經網路的工作機理: 看完文章得到的結果就是:CNN
理解RabbitMQ中的AMQP-0-9-1模型
前提 之前有個打算在學習RabbitMQ之前,把AMQP詳細閱讀一次,挑出裡面的重點內容。後來找了下RabbitMQ的官方文件,發現了有一篇文件專門介紹了RabbitMQ中實現的AMQP模型部分,於是直接基於此文件和個人理解寫下這篇文章。 AMQP協議 AMQP全稱是Advanced Message Queu
深度學習中神經網路模型的量化
深度學習神經網路模型中的量化是指浮點數用定點數來表示,也就是在DSP技術中常說的Q格式。我在以前的文章(Android手機上Audio DSP頻率低 memory小的應對措施 )中簡單講過Q格式,網上也有很多講Q格式的,這裡就不細講了。神經網路模型在訓練時都是浮點運算的,得到的模型引數也是浮點的。通常模型引數
網站統計中的數據收集原理及實現
fun 美的 置配 客戶 etc 分析 獲取 固定 open 網站統計中的數據收集原理及實現 網站數據統計分析工具是網站站長和運營人員經常使用的一種工具,比較常用的有谷歌分析、百度統計和騰訊分析等等。所有這些統計分析工具的第一步都是網站訪問數據的收集。目前主流的數據收
請教一下linux中程序重定向輸出到文...[模式及實現]
zcl inux log www. tar xxd html lin scm 56e3xa陌緣唾詠爍蹤http://blog.sina.com.cn/s/blog_17da0698f0102xdu0.html8vbt4y轎匣彩凸莆庇http://blog.sina.com.
【火爐煉AI】深度學習002-構建並訓練單層神經網路模型
【火爐煉AI】深度學習002-構建並訓練單層神經網路模型 (本文所使用的Python庫和版本號: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 ) 前面我們介紹了神經網路的基本結構單元-感知器,現在我們再升一級,看看神經網路的基本結構和
【火爐煉AI】深度學習003-構建並訓練深度神經網路模型
【火爐煉AI】深度學習003-構建並訓練深度神經網路模型 (本文所使用的Python庫和版本號: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 ) 前面我們講解過單層神經網路模型,發現它結構簡單,難以解決一些實際的比較複雜的問題,故而現