變分自動編碼器的反向傳播
再次,在線教程深入描述了變分自編碼器 (VAE) 的統計解釋;然而,我發現這個算法的實現有很大的不同,並且與常規的神經網絡相似。
網上典型的 vae 圖片是這樣的:
作為一個愛好者,我發現這個解釋非常混亂,尤其是在主題介紹在線帖子中。
無論如何,首先讓我嘗試解釋一下我如何理解常規前饋神經網絡上的反向傳播。
例如,導數的鍊式法則 E (總誤差)相對於重量 w1 如下:
∂E∂W1=∂E∂HA1…∂HA1∂H1∂H1∂w1
現在讓我們看看VAE等價物併計算鍊式法則 E (總誤差)相對於重量 W16 (只是編碼器端的任意權重 - 它們都是相同的)。
請注意,編碼器端的每個權重,包括 w16 ,取決於解碼器端的所有連接;因此,突出顯示的連接。鍊式規則如下所示:
$$ \frac{\partial E}{\partial w_{16}} = \frac{\partial E}{\partial OA_1} \frac{\partial OA_1}{\partial O_1} \frac{\partial O_1}{\partial HA_4} \frac{\partial HA_4}{\partial H_4} \color{red}{\frac{\partial H_4}{\partial Z} \frac{\partial Z}{\partial \mu} \frac{\partial \mu}{\partial w_{16}}} \
- \frac{\partial E}{\partial OA_2}… \
- \frac{\partial E}{\partial OA_3}… \
- \frac{\partial E}{\partial OA_4}… \ $$
請注意,紅色部分是我不打算在這裡介紹的重新參數化技巧。
但是,這還不是全部——假設對於常規神經網絡,批次等於 1——算法如下所示:
- 傳遞輸入並執行前饋傳遞。
- 計算總誤差並為網絡中的每個權重取導數
- 更新網絡權重並重複…
然而,在 VAE 中,算法有點不同:
- 傳遞輸入並為編碼器執行前饋並停止。
- 對潛在空間進行採樣 ( Z ) 說 n -times 並使用採樣的隨機變量執行前饋步驟 n -次
- 計算所有輸出和样本的總誤差,並對網絡中的每個權重求導
- 更新網絡權重並重複…
好的,好的,是的,我的問題是什麼!
問題 1
我對 VAE 的描述是否正確?
問題2
我將嘗試一步一步地走過潛在空間的採樣 (Z) 和象徵性的反向傳播。
讓我們假設 VAE 輸入是一個一維數組(所以即使它是一個圖像 - 它已經被展平了)。此外,潛在空間 (Z) 是一維的;因此,它包含一個單一的平均值 (μ) 和 std.var (σ) 假設正態分佈。
- 為簡單起見,讓單個輸入的誤差 xi 是 ei=(xi−¯xi) 在哪裡 ¯xi 是等效的 vae 輸出。
- 另外,讓我們假設有 m 此 vae 示例中的輸入和輸出。
- 最後讓我們假設 mini-batch 是一個,所以我們在 wach backprop 之後更新權重;因此,我們不會看到小批量 b 梯度公式中的索引。
在常規的前饋神經網絡中,給定上述設置,總誤差如下所示:
E=1mm∑i=1ei
因此,從上面的例子中,
∂E∂w1=∂(1m∑mi=1ei)∂w1
並使用梯度下降輕鬆更新權重。非常直截了當。請注意,我們有每個偏導數的單個值,即: ∂HA1∂H1 - 這是一個重要的區別。
選項1
現在對於 VAE,如在線帖子中所述,我們必須採樣 n 來自潛在空間的時間,以獲得良好的期望表示。
所以給定上面的例子和假設,總誤差 n 樣品和 m 輸出是:
E=1n1mn∑i=im∑j=1eij
如果我理解正確 - 我們必須至少有 n 樣本以求導數 ∂E∂w16 . 在一個樣本中取導數(反向傳播)是沒有意義的。
因此,在 VAE 中,導數將如下所示:
∂E∂w16=∂(1n1m∑ni=i∑mj=1eij)∂w16
這意味著在導數鏈中,我們必須計算並添加變量或函數的導數 n 時間即:
…∂Z1∂μ+…+∂Z2∂μ+…∂Zn∂μ
最後,我們使用梯度體面更新權重:
wk+116=wk16−η∂E∂w16
選項 2
我們保持總誤差公式與常規神經網絡中的相同,只是現在我們必須索引,因為我們最終會得到 n 其中:
Ei=1mm∑j=1ej
並在每個潛在空間樣本之後進行反向傳播 Z 但不要更新權重:
∂Ei∂w16=∂(1m∑mj=1ej)∂w16
即:現在我們只有一個 z - 鏈中的導數不像 n 在選項 1
…∂Z∂μ+…
最後通過平均梯度來更新權重:
wk+116=wk16−ηnn∑i=1∂Ei∂w16
那麼在問題 2 中 -選項 1或選項 2是否正確?我錯過了什麼嗎?
太感謝了!
Q1:您的描述似乎非常正確。
Q2:這兩個選項是相等的:
∂E∂w=∂1n∑ni=1Ei∂w=1nn∑i=1∂Ei∂w
另外,請注意 n=1 是一個有效的選擇:
在我們的實驗中,我們發現樣本數 L 只要小批量大小,每個數據點都可以設置為 1 M 足夠大,例如 M=100 .
Kingma、Diederik P. 和 Max Welling。“自動編碼變分貝葉斯。” arXiv 預印本 arXiv:1312.6114 (2013)。