Loss Function - BD-SEARCH/MLtutorial GitHub Wiki
1. Loss Function
- ๋ชจ๋ธ์์ ์์ฑ๋ ๊ฐ๊ณผ ์ค์ ๋ฐ์ดํฐ์ ๊ฐ์ด ์ฐจ์ด๋๋ ์ ๋๋ฅผ ๋ํ๋ด๋ ํจ์
- loss function์ ๊ฐ์ ์ต์ํํ๋ ๋ฐฉ์์ผ๋ก ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต
- neural network์์์ ์ต์ ํ : output๊ณผ label ์ฐจ์ด๋ฅผ Error๋ก ์ ์ํ ํ, ์ด ๊ฐ์ ์ค์ด๋๋ก parameter๋ฅผ ๋ฐ๊พธ์ด ๋๊ฐ๋ ๊ฒ
2. ์ข ๋ฅ
2-1. Linear regression (MSE, Mean Square Error)
(1) ์ ์
(2) ํน์ง
- ์ฃผ๋ก regression ๋ฌธ์ ์ ์ฌ์ฉํ๋ค.
(3) regression์ MSE๋ฅผ ์ฐ๋ ์ด์
Classification ๊ณผ ๊ฐ์ ๊ฒฝ์ฐ, ๋ง๋ค/์๋๋ค๊ฐ ํ๋ณ์ด ๊ฐ๋ฅํ์ง๋ง, ์ฃผ์ ๊ฐ๊ฒฉ ์์ธก๊ณผ ๊ฐ์ ์์น ํ๋จ์ ์ ๋งคํ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค.
-
ex. ์ฃผ์
- GT: 100,000์, output: 95,000์
- output์ด GT์ ๋์ผํ์ง ์์ง๋ง, ์ด ๋ชจ๋ธ์ด ์ผ๋ง๋ ์ ํ๋จํ ๊ฒ์ธ์ง ์ ๋งคํ๋ค.
- ๋ฐ๋ผ์ ์ค์ ๊ฐ๊ณผ ์์ธก๊ฐ์ ์ฐจ์ด๋ฅผ ๊ธฐ์ค์ผ๋ก ์ค์ฐจ๋ฅผ ํ๋จํด์ผ ํ๋ค.
-
MSE ๊ฐ์ด ์์ ๋ฐ๋์งํ ์ถ์ ๋์ด๋, ๋ถํธ์ฑ๊ณผ ํจ์จ์ฑ์ ๋ง์กฑํ๋ ๊ฐ์ ์๋ฏธํ๋ค.
- ๋ถํธ์ฑ(unbiasedness): ์ถ์ ๋์ ํ๊ท ์ด ๊ฐ๋ฅํ ํ ๋ชจ์์ ํ๊ท ์ ๊ทผ์
- ํจ์จ์ฑ(efficiency): ์ถ์ ๋์ ๋ถ์ฐ์ด ๋์์ ์์์ผ ํจ
-
์ฆ MSE๋ ์ ๋ต์ ๊ณผ์ ํ๊ท ์ ์ธ ์ฐจ์ด ๋ฟ๋ง ์๋๋ผ ๊ฐ๊ฐ์ ์ถ๋ ฅ๊ฐ์ด ์ ๋ต๊ณผ ์ผ๋ง๋ ์ฐจ์ด๊ฐ ํฌ๊ฒ ๋๋์ง๋ ๋ฐ์ํ๋ค.
2-2. Cross-Entropy Error (CEE)
(1) ์ ์
์ฃผ์ด์ง ํ๋ฅ ๋ณ์ X์ ๋ํด, ํ๋ฅ ๋ถํฌ p๋ฅผ ์ฐพ์๋ณด์. ํ๋ฅ ๋ถํฌ p๋ฅผ ์ ์ ์๊ธฐ ๋๋ฌธ์ p๋ฅผ ์์ธกํ ๊ทผ์ฌ ๋ถํฌ q๋ฅผ ์๊ฐํ๋ค. ์ ํํ ํ๋ฅ ๋ถํฌ๋ฅผ ์ป๊ธฐ ์ํด q์ parameter๋ค์ updateํ๋ฉด์ q๋ฅผ p์ ๊ทผ์ฌํ ๊ฒ์ด๋ค. ์ฆ ๋ ๋ถํฌ์ ์ฐจ์ด๋ฅผ ์ธก์ ํ๋ KL(p|q)๊ฐ ์ต์๊ฐ ๋๋ q๋ฅผ ์ฐพ๋ ๋ฌธ์ ๊ฐ ๋๋ค.
๋๋ฒ์งธ ํญ์ ๊ทผ์ฌ ๋ถํฌ q์ ๋ฌด๊ดํ ํญ์ด๋ฏ๋ก KL Divergence๋ฅผ ์ต์ํํ๋ ๊ฒ์ ๊ฒฐ๊ตญ ์ฒซ๋ฒ์งธ ํญ์ด๋ค. ๊ทธ๋ฌ๋ฏ๋ก ์ฒซ๋ฒ์งธ ํญ์ ์ต์ํํ๋ q๋ฅผ ์ฐพ์์ผ ํ๋ค.
p_i
: ์ค์ ํ๋ฅ ๋ถํฌq_i
: p๋ฅผ ๊ทผ์ฌํ ๋ถํฌ
(2) ํน์ง
- classification(๋ถ๋ฅ๋ฌธ์ )์๋ ACE(Average cross-entropy)๋ฅผ ์ฌ์ฉํ๋ค.
(3) classification์ ACE(Average cross-entropy)๋ฅผ ์ฌ์ฉํ๋ ์ด์
Model X, Y๊ฐ ์๊ณ , class๋ A,B,C 3๊ฐ๊ฐ ์๋ค.
Model X์ output
output | label | A | B | C | correct? | |||
---|---|---|---|---|---|---|---|---|
0.3 | 0.3 | 0.4 | 0 | 0 | 1 | Y | ||
0.3 | 0.4 | 0.3 | 0 | 1 | 0 | Y | ||
0.1 | 0.2 | 0.7 | 1 | 0 | 0 | N |
- 1,2๋ ๊ฒจ์ฐ ๋ง์ท๊ณ 3์ ์์ ํ ํ๋ ธ๋ค.
Model Y์ output
output | label | A | B | C | correct? | |||
---|---|---|---|---|---|---|---|---|
0.1 | 0.2 | 0.7 | 0 | 0 | 1 | Y | ||
0.1 | 0.7 | 0.2 | 0 | 1 | 0 | Y | ||
0.3 | 0.4 | 0.3 | 1 | 0 | 0 | N |
- 1,2๋ ํ์คํ ๋ง์ท์ผ๋ 3์ ์์ฝ๊ฒ ํ๋ ธ๋ค.
[๋จ์ ๋ถ๋ฅ ์ค์ฐจ]
* model X : 1/3 = 0.33
* model Y : 1/3 = 0.33
[๋ถ๋ฅ ์ ํ๋]
* model X : 2/3 = 0.67
* model Y : 2/3 = 0.67
- ๋จ์ ๋ถ๋ฅ ์ค์ฐจ ๊ณ์ฐ์ ํ๋ฆฐ ๊ฐ์์ ๋ํ ๊ฒฐ๊ณผ๋ง ์์ ๋ฟ, label๊ณผ ๋น๊ตํ์ฌ ์ผ๋ง๋ ๋ง์ด ํ๋ ธ๋ ์ง๋ ์ ๊ณตํ์ง ์๋๋ค.
cross entropy๋ก ๊ณ์ฐํ ๊ฒฝ์ฐ
model X
* ์ฒซ๋ฒ์งธ sample : -( (ln(0.3)*0) + (ln(0.3)*0) + (ln(0.4)*1) ) = -ln(0.4)
* 3๊ฐ sample ๋ชจ๋์ ๋ํ ๊ณ์ฐ ๋ฐ ACE (Average cross-entropy)
* -(ln(0.4) + ln(0.4) + ln(0.1)) / 3 = 1.38
model Y
* 3๊ฐ sample ๋ชจ๋์ ๋ํ ๊ณ์ฐ ๋ฐ ACE (Average cross-entropy)
* -(ln(0.7) + ln(0.7) + ln(0.3)) / 3 = 0.64
Model X๋ณด๋ค Y๊ฐ ์ค์ฐจ๊ฐ ๋ ์๋ค. ์ฆ ์ด๋ค model์ด ๋ ์ ํ์ต ๋์๋ ์ง๋ฅผ ์ ์ ์๋ค.
MSE๋ก ๊ณ์ฐํ ๊ฒฝ์ฐ
model X
* ์ฒซ๋ฒ์งธ sample : (0.3 - 0)^2 + (0.3 - 0)^2 + (0.4 - 1)^2 = 0.09 + 0.09 + 0.36 = 0.54
* 3๊ฐ sample ๋ชจ๋์ ๋ํ ๊ณ์ฐ ๋ฐ MSE(Mean squared error)
* (0.54 + 0.54 + 1.34) / 3 = 0.81
model Y
* (0.14 + 0.14 + 0.74) / 3 = 0.34
MSE๋ ํ๋ฆฐ sample์ ๋ํด ๋ ์ง์คํ๋ค. ๋ง์ ๊ฒ๊ณผ ํ๋ฆฐ ๊ฒ ๋ชจ๋์ ๋๊ฐ์ด focusํด์ผ ํ๋๋ฐ ์ฌ๊ธฐ์๋ ๊ทธ๋ ์ง ์๋ค.
ACE์ MSE ๋น๊ต (activation์ softmax๋ก ํ์ ๊ฒฝ์ฐ)
- background : backpropagation ์ค์ label์ ๋ฐ๋ผ output์ 1.0 ๋๋ 0.0์ผ๋ก ์ค์ ํ๋ ค๊ณ ํ๋ค.
- MSE๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ
- ๊ฐ์ค์น ๊ณ์ฐ์์ ๊ธฐ์ธ๊ธฐ ๊ฐ์ (output) * (1 - output)์ด๋ผ๋ ์กฐ์ ์์๊ฐ ํฌํจ๋๋ค.
- ๊ณ์ฐ ๋ ์ถ๋ ฅ์ด 0.0 ๋๋ 1.0์ ๊ฐ๊น๊ฑฐ๋ ๊ฐ๊น์์ง์ ๋ฐ๋ผ (output) * (1 - output)์ ๊ฐ์ ์ ์ ์์์ง๋ค.
- ex) output = 0.6์ด๋ผ๋ฉด (output) * (1 - output) = 0.24์ด์ง๋ง ์ถ๋ ฅ์ด 0.95์ด๋ฉด (output) * (1 - output) = 0.0475์ด๋ค.
- ๊ทธ๋ ๊ฒ ๋๋ฉด ์กฐ์ ์์๊ฐ ์ ์ ์์์ง๋ฉด์ ๊ฐ์ค์น ๋ณํ๋ ์ ์ ์์์ง๊ณ ํ์ต ์งํ์ด ๋ฉ์ถ ์๋ ์๋ค.
- ACE๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ
- (output) * (1 - output) ํญ์ด ์ฌ๋ผ์ง๋ค.
- ๋ฐ๋ผ์ ๊ฐ์ค์น ๋ณํ๋ ์ ์ ์์์ง์ง ์์ผ๋ฏ๋ก ํ์ต์ด ๋ฉ์ถ๋ ์ผ์ด ๋ฐ์ํ์ง ์๋๋ค.
2-3. Logistic regression (binary cross-entropy)
2-4. Hinge Loss
SVM ๋ฑ์์ Maximum-margin classification์ ์ฌ์ฉํ๊ธฐ ์ํ Loss function