Time-Series

注意力機制究竟是什麼?

  • May 4, 2018

在過去的幾年中,注意力機制已被用於各種深度學習論文中。Open AI 研究負責人 Ilya Sutskever 熱情地稱讚了他們: https ://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Purdue 大學的 Eugenio Culurciello 聲稱應該放棄 RNN 和 LSTM,取而代之的是純粹的基於注意力的神經網絡:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

這似乎有些誇張,但不可否認的是,純粹基於注意力的模型在序列建模任務中做得相當不錯:我們都知道谷歌的那篇名副其實的論文,Attention is all you need

然而,究竟什麼基於注意力的模型?我還沒有找到對這些模型的明確解釋。假設我想在給定歷史值的情況下預測多元時間序列的新值。很清楚如何使用具有 LSTM 單元的 RNN 來做到這一點。我如何對基於注意力的模型做同樣的事情?

Attention 是一種聚合一組向量的方法 $ v_i $ 進入一個向量,通常通過一個查找向量 $ u $ . 通常, $ v_i $ 要么是模型的輸入,要么是先前時間步的隱藏狀態,要么是向下一級的隱藏狀態(在堆疊 LSTM 的情況下)。

結果通常稱為上下文向量 $ c $ ,因為它包含與當前時間步相關的上下文。

這個額外的上下文向量 $ c $ 然後也被輸入到 RNN/LSTM 中(它可以簡單地與原始輸入連接)。因此,上下文可用於幫助預測。

最簡單的方法是計算概率向量 $ p = \text{softmax}(V^Tu) $ 和 $ c = \sum_i p_i v_i $ 在哪裡 $ V $ 是所有先前的串聯 $ v_i $ . 一個常見的查找向量 $ u $ 是當前隱藏狀態 $ h_t $ .

這有很多變化,您可以根據需要使事情變得複雜。例如,改為使用 $ v_i^T u $ 作為logits,可以選擇 $ f(v_i, u) $ 相反,在哪裡 $ f $ 是一個任意的神經網絡。

序列到序列模型的常見註意機制使用 $ p = \text{softmax}(q^T \tanh(W_1 v_i + W_2 h_t)) $ , 在哪裡 $ v $ 是編碼器的隱藏狀態,並且 $ h_t $ 是解碼器的當前隱藏狀態。 $ q $ 和兩者 $ W $ s 是參數。

一些論文展示了注意力概念的不同變化:

指針網絡使用對參考輸入的注意力來解決組合優化問題。

循環實體網絡在閱讀文本時為不同的實體(人/對象)維護單獨的記憶狀態,並使用注意力更新正確的記憶狀態。

變壓器模型也廣泛使用注意力。他們對注意力的表述稍微更籠統,並且還涉及關鍵向量 $ k_i $ :注意力權重 $ p $ 實際上是在鍵和查找之間計算的,然後使用 $ v_i $ .


這是一種注意力形式的快速實現,儘管除了它通過了一些簡單的測試之外,我不能保證正確性。

基本 RNN:

def rnn(inputs_split):
   bias = tf.get_variable('bias', shape = [hidden_dim, 1])
   weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
   weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

   hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
   for i, input in enumerate(inputs_split):
       input = tf.reshape(input, (batch, in_dim, 1))
       last_state = hidden_states[-1]
       hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
       hidden_states.append(hidden)
   return hidden_states[-1]

注意,我們在計算新的隱藏狀態之前只添加了幾行代碼:

       if len(hidden_states) > 1:
           logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
           probs = tf.nn.softmax(logits)
           probs = tf.reshape(probs, (batch, -1, 1, 1))
           context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
       else:
           context = tf.zeros_like(last_state)

       last_state = tf.concat([last_state, context], axis = 1)

       hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

完整代碼

引用自:https://stats.stackexchange.com/questions/344508

comments powered by Disqus