Loss-Functions

為什麼在非二進制數據的自動編碼器中使用二進制交叉熵(或對數損失)

  • February 26, 2019

我正在研究用於非二進制數據的自動編碼器,在[0,1]探索現有解決方案時,我注意到很多人(例如,關於自動編碼器的keras 教程這個人)在這種情況下使用二進制交叉熵作為損失函數。雖然自動編碼器工作,但它會產生略微模糊的重建,其中有很多原因可能是因為非二進制數據的二進制交叉熵對 0 和 1 的誤差比對 0.5 的誤差更懲罰(正如這裡很好解釋的那樣)。

例如,給定真值為 0.2,自動編碼器 A 預測為 0.1,而自動編碼器 2 預測為 0.3。A 的損失為

−(0.2 * log(0.1) + (1−0.2) * log(1−0.2)) = .27752801

而 B 的損失為

−(0.2 * log(0.3) + (1−0.2) * log(1−0.3)) = .228497317

因此,B 被認為是比 A 更好的重建;如果我一切都正確。但這對我來說並不完全有意義,因為我不確定為什麼非對稱比其他對稱損失函數(如 MSE)更受歡迎。

這段視頻中,Hugo Larochelle 認為最小值仍將處於完美重建點,但損失永遠不會為零(這是有道理的)。這在這個優秀的答案中得到了進一步的解釋,這證明了為什麼[0,1]當預測等於真實值時給出了非二進制值的二進制交叉熵的最小值。

所以,我的問題是:為什麼二進制交叉熵用於非二進制值[0,1],為什麼與其他對稱損失函數(如 MSE、MAE 等)相比,非對稱損失是可以接受的?它是否有更好的損失情況,即它是凸的而其他不是,還是有其他原因?

你的問題啟發了我從數學分析的角度來看待損失函數。這是一個免責聲明——我的背景是物理學,而不是統計學。

讓我們重寫 $ -loss $ 作為 NN 輸出的函數 $ x $ 並找到它的導數:

$ f(x) = a \ln x + (1-a) \ln (1-x) $

$ f'(x) = \frac{a-x}{x(1-x)} $

在哪裡 $ a $ 是目標值。現在我們把 $ x = a + \delta $ 並假設 $ \delta $ 很小,我們可以忽略條款 $ \delta^2 $ 為了清楚起見:

$ f'(\delta) = \frac{\delta}{a(a-1) + \delta(2a-1)} $

這個等式讓我們直觀地了解損失的行為方式。當目標值 $ a $ 是(接近)零或一,導數是常數 $ -1 $ 或者 $ +1 $ . 為了 $ a $ 大約 0.5 導數是線性的 $ \delta $ .

換句話說,在反向傳播過程中,這種損失更關心非常亮和非常暗的像素,但在優化中間色調上投入的精力較少。

關於不對稱性 - 當 NN 遠非最佳時,這可能並不重要,因為您會更快或更慢地收斂。當NN接近最優時( $ \delta $ 小)不對稱消失。

引用自:https://stats.stackexchange.com/questions/394582

comments powered by Disqus