1. 程式人生 > >Tensorflow中tf.layers.dense()在RNN網路搭建中的使用

Tensorflow中tf.layers.dense()在RNN網路搭建中的使用

一般來說,tf.layers.dense()多用於CNN網路搭建時搭建全連線層使用。

但在RNN網路搭建,特別是涉及到用RNN做分類時,tf.layers.dense()可以用作RNN分類預測輸出

 

分類預測輸出一般形式:

其中n_hidden為RNN cell的units個數,n_classes為labels的個數。

import tensorflow as tf
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes], mean=1.0))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}
output = tf.matmul(lstm_last_output, weights['out']) + biases['out']
predication = tf.nn.softmax(output)

使用tf.layers.dense():

import tensorflow as tf
prediction = tf.layers.dense(
    inputs=lstm_last_output,
    units=n_classes,
    activation=tf.nn.softmax
)