Tensorflow中tf.layers.dense()在RNN網路搭建中的使用
阿新 • • 發佈:2018-11-20
一般來說,tf.layers.dense()多用於CNN網路搭建時搭建全連線層使用。
但在RNN網路搭建,特別是涉及到用RNN做分類時,tf.layers.dense()可以用作RNN分類預測輸出。
分類預測輸出一般形式:
其中n_hidden為RNN cell的units個數,n_classes為labels的個數。
import tensorflow as tf weights = { 'out': tf.Variable(tf.random_normal([n_hidden, n_classes], mean=1.0)) } biases = { 'out': tf.Variable(tf.random_normal([n_classes])) } output = tf.matmul(lstm_last_output, weights['out']) + biases['out'] predication = tf.nn.softmax(output)
使用tf.layers.dense():
import tensorflow as tf
prediction = tf.layers.dense(
inputs=lstm_last_output,
units=n_classes,
activation=tf.nn.softmax
)