Logistic

一個簡單的邏輯回歸模型如何在 MNIST 上實現 92% 的分類準確率?

  • September 11, 2019

儘管 MNIST 數據集中的所有圖像都是居中的,具有相似的比例,並且沒有旋轉,但它們具有顯著的筆跡變化,這讓我感到困惑,線性模型如何實現如此高的分類精度。

據我所知,考慮到明顯的手寫變化,數字在 784 維空間中應該是線性不可分的,即,應該有一點複雜(雖然不是很複雜)的非線性邊界來分隔不同的數字,類似於被廣泛引用的 $ XOR $ 任何線性分類器都不能區分正類和負類的例子。多類邏輯回歸如何在完全線性特徵(沒有多項式特徵)的情況下產生如此高的準確度,這讓我感到莫名其妙。

例如,給定圖像中的任何像素,數字的不同手寫變化 $ 2 $ 和 $ 3 $ 可以使該像素發光或不發光。因此,通過一組學習的權重,每個像素可以使一個數字看起來像 $ 2 $ 以及ASA $ 3 $ . 只有通過像素值的組合才能判斷一個數字是否是 $ 2 $ 或一個 $ 3 $ . 大多數數字對都是如此。那麼,邏輯回歸如何盲目地將其決策獨立地基於所有像素值(根本不考慮任何像素間依賴關係),能夠實現如此高的精度。

我知道我在某個地方錯了,或者只是高估了圖像的變化。但是,如果有人可以幫助我直觀地了解數字如何“幾乎”線性分離,那就太好了。

tl;dr即使這是一個圖像分類數據集,它仍然是一項非常簡單的任務,人們可以輕鬆地找到從輸入到預測的直接映射。


回答:

這是一個非常有趣的問題,並且由於邏輯回歸的簡單性,您實際上可以找到答案。

邏輯回歸所做的是對每個圖像接受 $ 784 $ 輸入並將它們與權重相乘以生成其預測。有趣的是,由於輸入和輸出之間的直接映射(即沒有隱藏層),每個權重的值對應於每個權重的多少 $ 784 $ 在計算每個類別的概率時會考慮輸入。現在,通過獲取每個類的權重並將它們重塑為 $ 28 \times 28 $ (即圖像分辨率),我們可以知道哪些像素對每個類的計算最重要

再次注意,這些是權重

現在看一下上面的圖像並關注前兩位數(即零和一)。藍色權重意味著該像素的強度對該類的貢獻很大,紅色值意味著它的貢獻是負面的。

現在想像一下,一個人如何畫一個 $ 0 $ ? 他畫了一個圓形,中間是空的。這正是重量增加的原因。事實上,如果有人畫了圖像的中間,它就會被視為負數為零。因此,要識別零,您不需要一些複雜的過濾器和高級功能。您可以只看繪製的像素位置並據此進行判斷。

同樣的事情 $ 1 $ . 它總是在圖像中間有一條垂直的直線。其他的都是負數。

其餘的數字稍微複雜一些,但你可以想像一下 $ 2 $ , 這 $ 3 $ , 這 $ 7 $ 和 $ 8 $ . 其餘的數字有點困難,這實際上限制了邏輯回歸達到 90 年代的高位。

通過這一點,您可以看到邏輯回歸很有可能獲得很多正確的圖像,這就是它得分如此之高的原因。


重現上圖的代碼有點過時了,但在這裡:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

   loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

   sess.run(tf.global_variables_initializer()) 

   for step in range(1, 1001):

       x_batch, y_batch = mnist.train.next_batch(batch_size) 
       sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

       l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
       l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
       loss_tr.append(l_tr)
       acc_tr.append(a_tr)
       loss_ts.append(l_ts)
       acc_ts.append(a_ts)

   weights = sess.run(W)      
   print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
   plt.subplot(2, 5, i+1)
   weight = weights[:,i].reshape([28,28])
   plt.title(i)
   plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
   frame1 = plt.gca()
   frame1.axes.get_xaxis().set_visible(False)
   frame1.axes.get_yaxis().set_visible(False)

引用自:https://stats.stackexchange.com/questions/426873

comments powered by Disqus