Optimization

反向模式自動微分的逐步示例

  • July 16, 2016

不確定這個問題是否屬於這裡,但它與優化中的梯度方法密切相關,這似乎是這裡的主題。無論如何,如果您認為其他社區在該主題上有更好的專業知識,請隨時遷移。

簡而言之,我正在尋找反向模式自動微分的逐步示例。沒有太多關於該主題的文獻,現有的實現(如TensorFlow中的實現)在不了解其背後的理論的情況下很難理解。因此,如果有人能詳細說明我們傳入了什麼、我們如何處理它以及我們計算圖中得到了什麼,我將非常感激。

我最難解決的幾個問題:

  • 種子——我們為什麼需要它們?
  • 反向微分規則- 我知道如何進行正向微分,但我們如何倒退?例如,在本節的示例中,我們如何知道?
  • 我們是只使用符號還是傳遞實際?例如,在同一個例子中,是和符號或值?

假設我們有表達 z=x1x2+sin(x1) 並想找到衍生品 dzdx1dzdx2 . 反向模式 AD 將此任務分為兩部分,即正向和反向傳遞。

前傳

首先,我們將復雜表達式分解為一組原始表達式,即最多包含單個函數調用的表達式。請注意,我還重命名了輸入和輸出變量以保持一致性,儘管這不是必需的:

w1=x1

w2=x2
w3=w1w2
w4=sin(w1)
w5=w3+w4
z=w5

這種表示的優點是每個單獨的表達式的微分規則是已知的。例如,我們知道 sincos , 所以 dw4dw1=cos(w1) . 我們將在下面的反向傳遞中使用這個事實。

本質上,前向傳遞包括評估每個表達式並保存結果。比如說,我們的輸入是: x1=2x2=3 . 然後我們有:

w1=x1=2

w2=x2=3
w3=w1w2=6
w4=sin(w1) =0.9
w5=w3+w4=6.9
z=w5=6.9

反向傳球

這是神奇的開始,它從鍊式法則開始。鍊式法則的基本形式是,如果你有變量 t(u(v)) 這取決於 u 而這又取決於 v , 然後:

dtdv=dtdududv

或者如果 t 依賴於取決於 v 通過多個路徑/變量 ui ,例如:

u1=f(v)

u2=g(v)
t=h(u1,u2)

然後(參見此處的證明):

dtdv=idtduiduidv

就表達式圖而言,如果我們有一個最終節點 z 和輸入節點 wi , 和路徑 zwi 通過中間節點 wp (IE z=g(wp) 在哪裡 wp=f(wi) ),我們可以求導 dzdwi 作為

dzdwi=pparents(i)dzdwpdwpdwi

換句話說,計算輸出變量的導數 z wrt任何中間或輸入變量 wi ,我們只需要知道它的雙親的導數和計算原始表達式導數的公式 wp=f(wi) .

反向傳球從最後開始(即 dzdz ) 並向後傳播到所有依賴項。這裡我們有(“種子”的表達):

dzdz=1

這可以理解為“改變 z 導致完全相同的變化 z “,這是很明顯的。

然後我們知道 z=w5 所以:

dzdw5=1

w5 線性依賴於 w3w4 , 所以 dw5dw3=1dw5dw4=1 . 使用鍊式法則我們發現:

dzdw3=dzdw5dw5dw3=1×1=1

dzdw4=dzdw5dw5dw4=1×1=1

從定義 w3=w1w2 和偏導數規則,我們發現 dw3dw2=w1 . 因此:

dzdw2=dzdw3dw3dw2=1×w1=w1

正如我們從前向傳球中已經知道的那樣,它是:

dzdw2=w1=2

最後, w1 有助於 z 通過 w3w4 . 再一次,從偏導數的規則我們知道 dw3dw1=w2dw4dw1=cos(w1) . 因此:

dzdw1=dzdw3dw3dw1+dzdw4dw4dw1=w2+cos(w1)

同樣,給定已知的輸入,我們可以計算它:

dzdw1=w2+cos(w1)=3+cos(2) =2.58

自從 w1w2 只是別名 x1x2 ,我們得到答案:

dzdx1=2.58

dzdx2=2

就是這樣!


此描述僅涉及標量輸入,即數字,但實際上它也可以應用於多維數組,例如向量和矩陣。在用此類對象區分錶達式時應牢記兩件事:

  1. 導數可能比輸入或輸出具有更高的維度,例如向量 wrt 的導數是一個矩陣,而矩陣 wrt 的導數是一個 4 維數組(有時稱為張量)。在許多情況下,此類導數非常稀疏。
  2. 輸出數組中的每個分量都是輸入數組的一個或多個分量的獨立函數。例如,如果 y=f(x) 和兩者 xy 是向量, yi 從不依賴 yj ,但僅在 xk . 特別是,這意味著找到導數 dyidxj 歸結為跟踪如何 yi 依賴於取決於 xj .

自動微分的強大之處在於它可以處理來自編程語言的複雜結構,例如條件和循環。但是,如果您只需要代數表達式,並且您有足夠好的框架來處理符號表示,那麼構建完全符號表達式是可能的。事實上,在這個例子中,我們可以產生表達式 dzdw1=w2+cos(w1)=x2+cos(x1) 並為我們想要的任何輸入計算這個導數。

引用自:https://stats.stackexchange.com/questions/224140