Neural-Networks

RNN 或 LSTM 中的可變重要性

  • January 21, 2016

已經設計了幾種方法來訪問或量化 MLP 神經網絡模型中的變量重要性(即使只是相對於彼此):

  • 連接權重
  • 加森算法
  • 偏導數
  • 輸入擾動
  • 敏感性分析
  • 正向逐步加法
  • 向後逐步消除
  • 改進的逐步選擇 1
  • 改進的逐步選擇 2

(這些在http://dx.doi.org/10.1016/j.ecolmodel.2004.03.013中有描述)

是否有任何方法可以應用於 RNN 或 LSTM 神經網絡?

簡而言之,的,您可以對基於 RNN 的模型的變量重要性進行一些度量。我不會遍歷問題中列出的所有建議,但我將深入介紹敏感性分析的示例。

數據

我的 RNN 的輸入數據將由具有三個特徵的時間序列組成, $ x_1 $ , $ x_2 $ , $ x_3 $ . 每個特徵都將從隨機均勻分佈中抽取。我的 RNN 的目標變量將是一個時間序列(我輸入中每個時間步長的一個預測):

$$ y = \left{\begin{array}{lr} 0, & \text{if } x_1 x_2 \geq 0.25\ 1, & \text{if } x_1 x_2 < 0.25 \end{array}\right. $$

正如我們所見,目標僅依賴於前兩個特徵。因此,一個好的變量重要性度量應該顯示前兩個變量很重要,而第三個變量不重要。

該模型

該模型是一個簡單的三層 LSTM,在最後一層有一個 sigmoid 激活。該模型將在 5 個 epoch 中進行訓練,每個 epoch 有 1000 個批次。

可變重要性

為了衡量變量的重要性,我們將對數據進行大樣本(250 個時間序列) $ \hat{x} $ 併計算模型的預測 $ \hat{y} $ . 然後,對於每個變量 $ x_i $ 我們將通過以 0 為中心、尺度為 0.2 的隨機正態分佈擾動該變量(並且僅該變量)併計算預測 $ \hat{y_i} $ . 我們將通過計算原始數據之間的均方根差來測量這種擾動的影響 $ \hat{y} $ 和不安的 $ \hat{y_i} $ . 較大的均方根差異意味著變量“更重要”。

顯然,用於擾動數據的確切機制,以及如何測量擾動和未擾動輸出之間的差異,將高度依賴於您的特定數據集。

結果

完成上述所有操作後,我們看到以下重要性:

Variable 1, perturbation effect: 0.1162
Variable 2, perturbation effect: 0.1185
Variable 3, perturbation effect: 0.0077

正如我們所料,發現變量 1 和 2 比變量 3 更重要(大約 15 倍)!

重現的 Python 代碼

from tensorflow import keras  # tensorflow v1.14.0 was used
import numpy as np            # numpy v1.17.1 was used

np.random.seed(2019)

def make_model():
   inp = keras.layers.Input(shape=(10, 3))
   x = keras.layers.LSTM(10, activation='relu', return_sequences=True)(inp)
   x = keras.layers.LSTM(5, activation='relu', return_sequences=True)(x)
   x = keras.layers.LSTM(1, activation='sigmoid', return_sequences=True)(x)
   out = keras.layers.Flatten()(x)
   return keras.models.Model(inp, out)

def data_gen():
   while True:
       x = np.random.rand(5, 10, 3)  # batch x time x features
       yield x, x[:, :, 0] * x[:, :, 1] < 0.25

def var_importance(model):
   g = data_gen()
   x = np.concatenate([next(g)[0] for _ in range(50)]) # Get a sample of data
   orig_out = model.predict(x)
   for i in range(3):  # iterate over the three features
       new_x = x.copy()
       perturbation = np.random.normal(0.0, 0.2, size=new_x.shape[:2])
       new_x[:, :, i] = new_x[:, :, i] + perturbation
       perturbed_out = model.predict(new_x)
       effect = ((orig_out - perturbed_out) ** 2).mean() ** 0.5
       print(f'Variable {i+1}, perturbation effect: {effect:.4f}')

def main():
   model = make_model()
   model.compile('adam', 'binary_crossentropy')
   print(model.summary())
   model.fit_generator(data_gen(), steps_per_epoch=500, epochs=10)
   var_importance(model)

if __name__ == "__main__":
   main()

代碼的輸出:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 10, 3)]           0
_________________________________________________________________
lstm (LSTM)                  (None, 10, 10)            560
_________________________________________________________________
lstm_1 (LSTM)                (None, 10, 5)             320
_________________________________________________________________
lstm_2 (LSTM)                (None, 10, 1)             28
_________________________________________________________________
flatten (Flatten)            (None, 10)                0
=================================================================
Total params: 908
Trainable params: 908
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.6261
Epoch 2/5
1000/1000 [==============================] - 12s 12ms/step - loss: 0.4901
Epoch 3/5
1000/1000 [==============================] - 13s 13ms/step - loss: 0.4631
Epoch 4/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.4480
Epoch 5/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.4440
Variable 1, perturbation effect: 0.1162
Variable 2, perturbation effect: 0.1185
Variable 3, perturbation effect: 0.0077

引用自:https://stats.stackexchange.com/questions/191855

comments powered by Disqus