Deep-Learning

tf.nn.dynamic_rnn() 的輸出是什麼?

  • February 23, 2018

我不確定我從官方文檔中了解的內容,其中說:

返回:一對(輸出,狀態),其中:

outputs:RNN 輸出張量。

如果time_major == False(默認),這將是一個張量形狀: [batch_size, max_time, cell.output_size]

如果time_major == True,這將是一個張量形狀:[max_time, batch_size, cell.output_size]

請注意,如果cell.output_size是整數或 TensorShape 對象的(可能嵌套的)元組,則輸出將是與 cell.output_size 具有相同結構的元組,其中包含具有與 中的形狀數據對應的形狀的張量cell.output_size

state: 最終狀態。如果 cell.state_size 是一個 int,這將是 shape [batch_size, cell.state_size]。如果它是一個 TensorShape,這將是 shape [batch_size] + cell.state_size。如果它是整數或 TensorShape 的(可能是嵌套的)元組,這將是一個具有相應形狀的元組。如果單元格是 LSTMCells,則狀態將是一個元組,其中包含每個單元格的 LSTMStateTuple。

]是否output[-1總是(在所有三種單元類型中,即 RNN、GRU、LSTM)等於狀態(返回元組的第二個元素)?我想各地的文獻在使用“隱藏狀態”一詞時都過於自由了。所有三個單元格中的隱藏狀態是否得分出來(為什麼它被稱為隱藏超出了我的範圍,它會出現 LSTM 中的單元格狀態應該稱為隱藏狀態,因為它沒有暴露)?

是的,單元輸出等於隱藏狀態。在 LSTM 的情況下,它是元組的短期部分( 的第二個元素LSTMStateTuple),如下圖所示:

長短期記憶體

但是對於tf.nn.dynamic_rnn,當序列較短(參數)時,返回的狀態可能不同。sequence_length看看這個例子:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
 # t = 0 t = 1
 [[0, 1, 2], [9, 8, 7]], # instance 0
 [[3, 4, 5], [0, 0, 0]], # instance 1
 [[6, 7, 8], [6, 5, 4]], # instance 2
 [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 outputs_val, states_val = sess.run([outputs, states], 
                                    feed_dict={X: X_batch, seq_length: seq_length_batch})

 print(outputs_val)
 print()
 print(states_val)

這裡輸入批次包含 4 個序列,其中一個很短並用零填充。運行後你應該是這樣的:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
 [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

[[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.          0.          0.          0.          0.        ]]

[[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

[[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
[ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
[ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

…這確實表明state == output[1]對於完整序列和state == output[0]短序列。也是output[1]這個序列的零向量。LSTM 和 GRU 單元也是如此。

所以state是一個方便的張量,它保存最後一個實際的RNN 狀態,忽略零。output張量包含所有單元格的輸出,因此它不會忽略零。這就是他們兩個都退貨的原因。

引用自:https://stats.stackexchange.com/questions/330176

comments powered by Disqus