softmax - cccbook/py2gpt GitHub Wiki
分類問題以 CrossEntropy + Softmax 作為損失函數,其反傳遞公式如何計算
神經網路的反傳遞算法,通常用在有標準答案的問題上面,這種問題有兩類:
- Regression (回歸問題)
- Classification (分類問題)
第一種 Regression 回歸問題,輸出是一個實數 (浮點數),像是:
- 明天台積電的收盤股價會是多少錢?
- 明天的平均氣溫會是攝氏幾度?
第二種的 Classification 分類問題,輸出通常是個整數,是從 n 類中選 1 類的問題,像是:
- GPT 預測下一個詞,是從所有可能的詞當中選一個 (一萬個詞就是一萬類,選其中一類)
- MNIST 手寫數字辨識,是從十種數字當中選一個 (十類選一類)
而對於 Regression (回歸問題),您可以參考下列文章
對於分類問題,我們通常會用 Softmax + CrossEntropy 兩個函數串接,作為最後的損失函數。
假如果 Softmax 的輸入是 x, 而正確答案是 y ,那麼整個損失函數可以寫成下列算式
def loss(y, x):
s = softmax(x)
return cross_entropy(y, s)
問題是,這樣的函數,其 x 的反傳遞公式是什麼呢?
答案出奇的簡單,就是 s - y ,其中的 s = softmax(x)
為何是這樣呢?
請看下列文章
但是這些關於 Jocobian 的數學我看得似懂非懂,為了驗證這個公式是對的,於是我寫了一個程式來驗證,網址如下:
該驗證程式的輸出為
$ python test1.py
x = [0.3 0.5 0.2]
y = [0. 1. 0.]
s = softmax(x) = [0.31987306 0.39069383 0.28943311]
jacobian_softmax(s)=
[[ 0.21755428 -0.12497243 -0.09258185]
[-0.12497243 0.23805216 -0.11307973]
[-0.09258185 -0.11307973 0.20566159]]
cross_entropy(y, s)= [0.93983106]
gradient_cross_entropy(y, s)= [-0. -2.55954897 -0. ]
num_gradient_cross_entropy(y, s)= [ 0. -2.55627891 0. ]
error_softmax_input(y, s)= [ 0.31987306 -0.60930617 0.28943311]
num_error_softmax_input(y, x)= [ 0.31998185 -0.60918713 0.28953596]
其中最後的 error_softmax_input(y, s) 就是 s - y
def error_softmax_input(y, s):
return s - y
而 num_error_softmax_input(y,x) 則是我們用 numgrad 套件的 grad 函數,用數值方法計算出來的梯度。
def loss(y, x):
s = so.softmax(x)
return so.cross_entropy(y, s)
def num_error_softmax_input(y, x):
return ngd.grad(lambda x:loss(y, x), x)
你可以看到該測試程式的最後,輸出的 error_softmax_input(y, s) 與 num_error_softmax_input(y, x) 非常接近,因此我們用程式驗證 x 的梯度就是 s-y 這個數學推導出來的公式。
$ python test1.py
x = [0.3 0.5 0.2]
y = [0. 1. 0.]
s = softmax(x) = [0.31987306 0.39069383 0.28943311]
jacobian_softmax(s)=
[[ 0.21755428 -0.12497243 -0.09258185]
[-0.12497243 0.23805216 -0.11307973]
[-0.09258185 -0.11307973 0.20566159]]
cross_entropy(y, s)= [0.93983106]
gradient_cross_entropy(y, s)= [-0. -2.55954897 -0. ]
num_gradient_cross_entropy(y, s)= [ 0. -2.55627891 0. ]
error_softmax_input(y, s)= [ 0.31987306 -0.60930617 0.28943311]
num_error_softmax_input(y, x)= [ 0.31998185 -0.60918713 0.28953596]