一個簡單的邏輯回歸模型如何在 MNIST 上實現 92% 的分類準確率?
儘管 MNIST 數據集中的所有圖像都是居中的,具有相似的比例,並且沒有旋轉,但它們具有顯著的筆跡變化,這讓我感到困惑,線性模型如何實現如此高的分類精度。
據我所知,考慮到明顯的手寫變化,數字在 784 維空間中應該是線性不可分的,即,應該有一點複雜(雖然不是很複雜)的非線性邊界來分隔不同的數字,類似於被廣泛引用的 $ XOR $ 任何線性分類器都不能區分正類和負類的例子。多類邏輯回歸如何在完全線性特徵(沒有多項式特徵)的情況下產生如此高的準確度,這讓我感到莫名其妙。
例如,給定圖像中的任何像素,數字的不同手寫變化 $ 2 $ 和 $ 3 $ 可以使該像素發光或不發光。因此,通過一組學習的權重,每個像素可以使一個數字看起來像 $ 2 $ 以及ASA $ 3 $ . 只有通過像素值的組合才能判斷一個數字是否是 $ 2 $ 或一個 $ 3 $ . 大多數數字對都是如此。那麼,邏輯回歸如何盲目地將其決策獨立地基於所有像素值(根本不考慮任何像素間依賴關係),能夠實現如此高的精度。
我知道我在某個地方錯了,或者只是高估了圖像的變化。但是,如果有人可以幫助我直觀地了解數字如何“幾乎”線性分離,那就太好了。
tl;dr即使這是一個圖像分類數據集,它仍然是一項非常簡單的任務,人們可以輕鬆地找到從輸入到預測的直接映射。
回答:
這是一個非常有趣的問題,並且由於邏輯回歸的簡單性,您實際上可以找到答案。
邏輯回歸所做的是對每個圖像接受 $ 784 $ 輸入並將它們與權重相乘以生成其預測。有趣的是,由於輸入和輸出之間的直接映射(即沒有隱藏層),每個權重的值對應於每個權重的多少 $ 784 $ 在計算每個類別的概率時會考慮輸入。現在,通過獲取每個類的權重並將它們重塑為 $ 28 \times 28 $ (即圖像分辨率),我們可以知道哪些像素對每個類的計算最重要。
再次注意,這些是權重。
現在看一下上面的圖像並關注前兩位數(即零和一)。藍色權重意味著該像素的強度對該類的貢獻很大,紅色值意味著它的貢獻是負面的。
現在想像一下,一個人如何畫一個 $ 0 $ ? 他畫了一個圓形,中間是空的。這正是重量增加的原因。事實上,如果有人畫了圖像的中間,它就會被視為負數為零。因此,要識別零,您不需要一些複雜的過濾器和高級功能。您可以只看繪製的像素位置並據此進行判斷。
同樣的事情 $ 1 $ . 它總是在圖像中間有一條垂直的直線。其他的都是負數。
其餘的數字稍微複雜一些,但你可以想像一下 $ 2 $ , 這 $ 3 $ , 這 $ 7 $ 和 $ 8 $ . 其餘的數字有點困難,這實際上限制了邏輯回歸達到 90 年代的高位。
通過這一點,您可以看到邏輯回歸很有可能獲得很多正確的圖像,這就是它得分如此之高的原因。
重現上圖的代碼有點過時了,但在這裡:
import tensorflow as tf import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data # Load MNIST: mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # Create model x = tf.placeholder(tf.float32, shape=(None, 784)) y = tf.placeholder(tf.float32, shape=(None, 10)) W = tf.Variable(tf.zeros((784,10))) b = tf.Variable(tf.zeros((10))) z = tf.matmul(x, W) + b y_hat = tf.nn.softmax(z) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1])) optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # Train model batch_size = 64 with tf.Session() as sess: loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], [] sess.run(tf.global_variables_initializer()) for step in range(1, 1001): x_batch, y_batch = mnist.train.next_batch(batch_size) sess.run(optimizer, feed_dict={x: x_batch, y: y_batch}) l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch}) l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels}) loss_tr.append(l_tr) acc_tr.append(a_tr) loss_ts.append(l_ts) acc_ts.append(a_ts) weights = sess.run(W) print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) # Plotting: for i in range(10): plt.subplot(2, 5, i+1) weight = weights[:,i].reshape([28,28]) plt.title(i) plt.imshow(weight, cmap='RdBu') # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more frame1 = plt.gca() frame1.axes.get_xaxis().set_visible(False) frame1.axes.get_yaxis().set_visible(False)