當我在邏輯回歸設置中使用平方損失時,這裡發生了什麼?
我正在嘗試使用平方損失對玩具數據集進行二元分類。
我正在使用
mtcars
數據集,使用每加侖英里數和重量來預測傳輸類型。下圖顯示了兩種不同顏色的傳輸類型數據,以及不同損失函數生成的決策邊界。平方損失為 $ \sum_i (y_i-p_i)^2 $ 在哪裡 $ y_i $ 是地面實況標籤(0 或 1)和 $ p_i $ 是預測概率 $ p_i=\text{Logit}^{-1}(\beta^Tx_i) $ . 換句話說,我在分類設置中將邏輯損失替換為平方損失,其他部分相同。對於一個帶有
mtcars
數據的玩具示例,在許多情況下,我得到了一個與邏輯回歸“相似”的模型(參見下圖,隨機種子為 0)。但在某些情況下(如果我們這樣做
set.seed(1)
),平方損失似乎效果不佳。 這裡發生了什麼?優化不收斂?與平方損失相比,邏輯損失更容易優化?任何幫助,將不勝感激。
代碼
d=mtcars[,c("am","mpg","wt")] plot(d$mpg,d$wt,col=factor(d$am)) lg_fit=glm(am~.,d, family = binomial()) abline(-lg_fit$coefficients[1]/lg_fit$coefficients[3], -lg_fit$coefficients[2]/lg_fit$coefficients[3]) grid() # sq loss lossSqOnBinary<-function(x,y,w){ p=plogis(x %*% w) return(sum((y-p)^2)) } # this random seed is important for reproducing the problem set.seed(0) x0=runif(3) x=as.matrix(cbind(1,d[,2:3])) y=d$am opt=optim(x0, lossSqOnBinary, method="BFGS", x=x,y=y) abline(-opt$par[1]/opt$par[3], -opt$par[2]/opt$par[3], lty=2) legend(25,5,c("logisitc loss","squared loss"), lty=c(1,2))
似乎您已經在特定示例中解決了這個問題,但我認為仍然值得更仔細地研究最小二乘和最大似然邏輯回歸之間的差異。
讓我們得到一些符號。讓 $ L_S(y_i, \hat y_i) = \frac 12(y_i - \hat y_i)^2 $ 和 $ L_L(y_i, \hat y_i) = y_i \log \hat y_i + (1 - y_i) \log(1 - \hat y_i) $ . 如果我們正在做最大似然(或我在這裡所做的最小負對數似然),我們有 $$ \hat \beta_L := \text{argmin}{b \in \mathbb R^p} -\sum{i=1}^n y_i \log g^{-1}(x_i^T b) + (1-y_i)\log(1 - g^{-1}(x_i^T b)) $$ 和 $ g $ 作為我們的鏈接功能。
或者我們有 $$ \hat \beta_S := \text{argmin}{b \in \mathbb R^p} \frac 12 \sum{i=1}^n (y_i - g^{-1}(x_i^T b))^2 $$ 作為最小二乘解。因此 $ \hat \beta_S $ 最小化 $ L_S $ 同樣對於 $ L_L $ .
讓 $ f_S $ 和 $ f_L $ 是對應於最小化的目標函數 $ L_S $ 和 $ L_L $ 分別為 $ \hat \beta_S $ 和 $ \hat \beta_L $ . 最後,讓 $ h = g^{-1} $ 所以 $ \hat y_i = h(x_i^T b) $ . 請注意,如果我們使用的是我們得到的規範鏈接 $$ h(z) = \frac{1}{1+e^{-z}} \implies h'(z) = h(z) (1 - h(z)). $$
對於常規邏輯回歸,我們有 $$ \frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n h'(x_i^T b)x_{ij} \left( \frac{y_i}{h(x_i^T b)} - \frac{1-y_i}{1 - h(x_i^T b)}\right). $$ 使用 $ h' = h \cdot (1 - h) $ 我們可以將其簡化為 $$ \frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n x_{ij} \left( y_i(1 - \hat y_i) - (1-y_i)\hat y_i\right) = -\sum_{i=1}^n x_{ij}(y_i - \hat y_i) $$ 所以 $$ \nabla f_L(b) = -X^T (Y - \hat Y). $$
接下來讓我們做二階導數。黑森州
$$ H_L:= \frac{\partial^2 f_L}{\partial b_j \partial b_k} = \sum_{i=1}^n x_{ij} x_{ik} \hat y_i (1 - \hat y_i). $$ 這意味著 $ H_L = X^T A X $ 在哪裡 $ A = \text{diag} \left(\hat Y (1 - \hat Y)\right) $ . $ H_L $ 確實取決於當前的擬合值 $ \hat Y $ 但 $ Y $ 已經退出,並且 $ H_L $ 是PSD。因此我們的優化問題是凸的 $ b $ .
讓我們將其與最小二乘法進行比較。
$$ \frac{\partial f_S}{\partial b_j} = - \sum_{i=1}^n (y_i - \hat y_i) h'(x^T_i b)x_{ij}. $$
這意味著我們有 $$ \nabla f_S(b) = -X^T A (Y - \hat Y). $$ 這是一個關鍵點:梯度幾乎相同,除了所有 $ i $ $ \hat y_i (1 - \hat y_i) \in (0,1) $ 所以基本上我們將漸變相對於 $ \nabla f_L $ . 這會使收斂速度變慢。
對於 Hessian,我們可以先寫 $$ \frac{\partial f_S}{\partial b_j} = - \sum_{i=1}^n x_{ij}(y_i - \hat y_i) \hat y_i (1 - \hat y_i) = - \sum_{i=1}^n x_{ij}\left( y_i \hat y_i - (1+y_i)\hat y_i^2 + \hat y_i^3\right). $$
這導致我們 $$ H_S:=\frac{\partial^2 f_S}{\partial b_j \partial b_k} = - \sum_{i=1}^n x_{ij} x_{ik} h'(x_i^T b) \left( y_i - 2(1+y_i)\hat y_i + 3 \hat y_i^2 \right). $$
讓 $ B = \text{diag} \left( y_i - 2(1+y_i)\hat y_i + 3 \hat y_i ^2 \right) $ . 我們現在有 $$ H_S = -X^T A B X. $$
對我們來說不幸的是,權重 $ B $ 不保證為非負數:如果 $ y_i = 0 $ 然後 $ y_i - 2(1+y_i)\hat y_i + 3 \hat y_i ^2 = \hat y_i (3 \hat y_i - 2) $ 這是正的iff $ \hat y_i > \frac 23 $ . 同樣,如果 $ y_i = 1 $ 然後 $ y_i - 2(1+y_i)\hat y_i + 3 \hat y_i ^2 = 1-4 \hat y_i + 3 \hat y_i^2 $ 當 $ \hat y_i < \frac 13 $ (這也是積極的 $ \hat y_i > 1 $ 但這是不可能的)。這意味著 $ H_S $ 不一定是 PSD,所以我們不僅壓縮了梯度,這會使學習更加困難,而且我們還搞砸了問題的凸性。
總而言之,最小二乘邏輯回歸有時會遇到困難也就不足為奇了,在您的示例中,您有足夠的擬合值接近 $ 0 $ 或者 $ 1 $ 以便 $ \hat y_i (1 - \hat y_i) $ 可以非常小,因此漸變非常平坦。
將其連接到神經網絡,儘管這只是一個不起眼的邏輯回歸,我認為平方損失你正在經歷 Goodfellow、Bengio 和 Courville 在他們的深度學習書中所指的內容,當時他們寫了以下內容:
整個神經網絡設計中反復出現的一個主題是成本函數的梯度必須足夠大且可預測,以作為學習算法的良好指南。飽和(變得非常平坦)的函數破壞了這個目標,因為它們使梯度變得非常小。在許多情況下,這是因為用於產生隱藏單元或輸出單元輸出的激活函數飽和。對於許多模型,負對數似然有助於避免這個問題。許多輸出單元包含一個 exp 函數,當它的參數非常負時,它會飽和。負對數似然成本函數中的對數函數撤消了一些輸出單元的 exp。我們將在第二節討論成本函數和輸出單元選擇之間的相互作用。6.2.2.
並且,在 6.2.2 中,
不幸的是,當與基於梯度的優化一起使用時,均方誤差和平均絕對誤差通常會導致較差的結果。一些飽和的輸出單元在與這些成本函數結合時會產生非常小的梯度。這是交叉熵成本函數比均方誤差或平均絕對誤差更受歡迎的原因之一,即使不需要估計整個分佈 $ p(y|x) $ .
(兩個節選都來自第 6 章)。