tf.nn.dynamic_rnn() 的輸出是什麼?
我不確定我從官方文檔中了解的內容,其中說:
返回:一對(輸出,狀態),其中:
outputs
:RNN 輸出張量。如果
time_major == False
(默認),這將是一個張量形狀:[batch_size, max_time, cell.output_size]
。如果
time_major == True
,這將是一個張量形狀:[max_time, batch_size, cell.output_size]
。請注意,如果
cell.output_size
是整數或 TensorShape 對象的(可能嵌套的)元組,則輸出將是與 cell.output_size 具有相同結構的元組,其中包含具有與 中的形狀數據對應的形狀的張量cell.output_size
。
state
: 最終狀態。如果 cell.state_size 是一個 int,這將是 shape[batch_size, cell.state_size]
。如果它是一個 TensorShape,這將是 shape[batch_size] + cell.state_size
。如果它是整數或 TensorShape 的(可能是嵌套的)元組,這將是一個具有相應形狀的元組。如果單元格是 LSTMCells,則狀態將是一個元組,其中包含每個單元格的 LSTMStateTuple。]是否
output[-1
總是(在所有三種單元類型中,即 RNN、GRU、LSTM)等於狀態(返回元組的第二個元素)?我想各地的文獻在使用“隱藏狀態”一詞時都過於自由了。所有三個單元格中的隱藏狀態是否得分出來(為什麼它被稱為隱藏超出了我的範圍,它會出現 LSTM 中的單元格狀態應該稱為隱藏狀態,因為它沒有暴露)?
是的,單元輸出等於隱藏狀態。在 LSTM 的情況下,它是元組的短期部分( 的第二個元素
LSTMStateTuple
),如下圖所示:但是對於
tf.nn.dynamic_rnn
,當序列較短(參數)時,返回的狀態可能不同。sequence_length
看看這個例子:n_steps = 2 n_inputs = 3 n_neurons = 5 X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs]) seq_length = tf.placeholder(tf.int32, [None]) basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons) outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32) X_batch = np.array([ # t = 0 t = 1 [[0, 1, 2], [9, 8, 7]], # instance 0 [[3, 4, 5], [0, 0, 0]], # instance 1 [[6, 7, 8], [6, 5, 4]], # instance 2 [[9, 0, 1], [3, 2, 1]], # instance 3 ]) seq_length_batch = np.array([2, 1, 2, 2]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch}) print(outputs_val) print() print(states_val)
這裡輸入批次包含 4 個序列,其中一個很短並用零填充。運行後你應該是這樣的:
[[[ 0.2315362 -0.37939444 -0.625332 -0.80235624 0.2288385 ] [ 0.9999524 0.99987394 0.33580178 -0.9981791 0.99975705]] [[ 0.97374666 0.8373545 -0.7455188 -0.98751736 0.9658986 ] [ 0. 0. 0. 0. 0. ]] [[ 0.9994331 0.9929737 -0.8311569 -0.99928087 0.9990415 ] [ 0.9984355 0.9936006 0.3662448 -0.87244385 0.993848 ]] [[ 0.9962312 0.99659646 0.98880637 0.99548346 0.9997809 ] [ 0.9915743 0.9936939 0.4348318 0.8798458 0.95265496]]] [[ 0.9999524 0.99987394 0.33580178 -0.9981791 0.99975705] [ 0.97374666 0.8373545 -0.7455188 -0.98751736 0.9658986 ] [ 0.9984355 0.9936006 0.3662448 -0.87244385 0.993848 ] [ 0.9915743 0.9936939 0.4348318 0.8798458 0.95265496]]
…這確實表明
state == output[1]
對於完整序列和state == output[0]
短序列。也是output[1]
這個序列的零向量。LSTM 和 GRU 單元也是如此。所以
state
是一個方便的張量,它保存最後一個實際的RNN 狀態,忽略零。output
張量包含所有單元格的輸出,因此它不會忽略零。這就是他們兩個都退貨的原因。