Caffe Tutorial : 4.Solver (Kor) - ys7yoo/BrainCaffe GitHub Wiki
Solver)
ํด๊ฒฐ์ฌ (ํด๊ฒฐ์ฌ(solver)๋ ์์ค์ ํฅ์์ํค๋ ค๋ ์๋๋ฅผ ํ๋ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ๋ฅผ ํ์ฑํ๊ธฐ์ํด ๋คํธ์ํฌ์ ์ ๋ฐฉํฅ ์ถ์ธก๊ณผ ์ญ๋ฐฉํฅ ๊ทธ๋๋์ธํธ๋ฅผ ์กฐ์งํด์ ๋ชจ๋ธ ์ต์ ํ๋ฅผ ์กฐ์ ํ๋ค. ํ์ต์ ํ์ ์ฌํญ๋ค์ ์ต์ ํ๋ฅผ ๊ฐ๋ ํ๊ณ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ๋ฅผ ์์ฑํ๊ธฐ์ํ ํด๊ฒฐ์ฌ์, ์์ค๊ณผ ๊ทธ๋๋์ธํธ๋ฅผ ์ฐ์ถํ๊ธฐ์ํ ๋ง์ผ๋ก ๋๋์ด์ง๋ค.
Caffe์ ํด๊ฒฐ์ฌ๋ ๋ค์๊ณผ ๊ฐ๋ค.
- Stochastic Gradient Descent ( type : "SGD" )
- AdaDelta ( type : "AdaDelta" )
- Adaptive Gradient (type: "AdaGrad"),
- Adam (type: "Adam"),
- Nesterovโs Accelerated Gradient (type: "Nesterov") and
- RMSprop (type: "RMSProp")
- The Caffe solvers are:
ํด๊ฒฐ์ฌ๋
- ์ต์ ํ ๊ณผ์ ๊ธฐ๋ก์ ๋ฐํ์ ๋ง๋ จํด์ฃผ๊ณ ํ์ต์ ์ํ ํ๋ จ ๋คํธ์ํฌ์ ํ๊ฐ๋ฅผ ์ํ ์คํ ๋คํธ์ํฌ๋ฅผ ์์ฑํด์ค๋ค.
- ๋ฐ๋ณต์ ์ผ๋ก ์ ๋ฐฉํฅ / ์ญ๋ฐฉํฅ์ ํธ์ถํ๊ณ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํจ์ผ๋ก์จ ์ต์ ํ๋ฅผ ์งํํ๋ค.
- (์ฃผ๊ธฐ์ ์ผ๋ก) ํ ์คํธ ๋คํธ์ํฌ๋ค์ ํ๊ฐํ๋ค.
- ์ต์ ํ ๋ด๋ด ๋ชจ๋ธ๊ณผ ํด๊ฒฐ์ฌ ์ํ์ ์ค๋ ์ท์ ์ฐ๋๋ค.
๊ฐ ๋ฐ๋ณต๋ง๋ค ์ด๊ธฐํ๋ถํฐ ํ์ต๋ ๋ชจ๋ธ๊น์ง ๋ชจ๋ ๋ฐฉ๋ฒ์ ๊ฐ์ค์น๋ฅผ ์ทจํ๊ธฐ ์ํด
- ์์ค๊ณผ ์ถ๋ ฅ์ ๊ณ์ฐํ๊ธฐ์ํด ์ ๋ฐฉํฅ ๋คํธ์ํฌ๋ฅผ ํธ์ถํ๋ค.
- ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ญ๋ฐฉํฅ ๋คํธ์ํฌ๋ฅผ ํธ์ถํ๋ค.
- ํด๊ฒฐ์ฌ ๋ฉ์๋์ ๋ฐ๋ผ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ ์์ ๊ทธ๋๋์ธํธ๊ฐ ํฌํจ๋๋ค.
- ํ์ต๋ฅ , ๊ธฐ๋ก, ๊ทธ๋ฆฌ๊ณ ๋ฉ์๋์ ๋ฐ๋ผ์ ํด๊ฒฐ์ฌ ์ํ๊ฐ ์ ๋ฐ์ดํธ ๋๋ค.
Caffe ๋ชจ๋ธ๋ค๊ณผ ๊ฐ์ด, Caffe ํด๊ฒฐ์ฌ๋ CPU์ GPU ๋ชจ๋์์ ์๋ํ๋ค.
1.๋ฉ์๋ (Method)
ํด๊ฒฐ์ฌ ๋ฉ์๋๋ ์์ค ์ต์ํ์ ์ผ๋ฐ์ ์ต์ ํ ๋ฌธ์ ๋ฅผ ๋ค๋ฃฌ๋ค. ๋ฐ์ดํฐ์ธํธ D์ ๋ํ์ฌ , ์ต์ ํ ๋ชฉ์ ์ ๋ฐ์ดํฐ ์ ์ ๊ฑธ์ณ ๋ชจ๋ |D| ๋ฐ์ดํฐ ์ฌ๋ก์ ๋ํ ์ ์ฒด ํ๊ท ์์ค์ด๋ค.
L(W) = \frac{1}{|D|} \sum_i^{|D|} f_W\left(X^{(i)}\right) + \lambda r(W) <-- TeX
์ฌ๊ธฐ์ fW(X(i))๋ ๋ฐ์ดํฐ ๊ฒฝ์ฐ์ ์์ ๋ํ ์์ค์ด๊ณ r(W)๋ ๊ฐ์ค์น ฮป๋ฅผ ๊ฐ์ง ์กฐ์งํ ํญ(regularization term)์ด๋ค. |D|๋ ๋งค์ฐ ํด ์ ์์ง๋ง, ๊ทธ๋์ ์ค์ ๋ก๋, ์ฐ๋ฆฌ๊ฐ ์ด ๋ชฉํ์ ํ์จ์ ๊ทผ์ฌ์น๋ฅผ ์ฌ์ฉํ๋ ๊ฐ๊ฐ์ ํด๊ฒฐ์ฌ ๋ฐ๋ณต์ ์์ด, N<<|D| ๊ฒฝ์ฐ์ ์ต์ ์ผํ ์ฒ๋ฆฌ๋์ ๊ทธ๋ฆฐ๋ค.
L(W) \approx \frac{1}{N} \sum_i^N f_W\left(X^{(i)}\right) + \lambda r(W) <-- TeX
๋ชจ๋ธ์ ์ ๋ฐฉํฅ๊ณผ์ ์์๋ fw๋ฅผ ์ฐ์ฐํ๊ณ , ์ญ๋ฐฉํฅ ๊ณผ์ ์์๋ ๊ทธ๋๋์ธํธ โfw๋ฅผ ์ฐ์ฐํ๋ค. ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ ฮW๋ ์๋ฌ ๊ทธ๋๋์ธํธ โfw, ์กฐ์งํ ๊ทธ๋๋์ธํธ(regularization gradient)โr(W), ๊ทธ๋ฆฌ๊ณ ๋ค๋ฅธ ํน์ ํ ๊ฐ๊ฐ์ ๋ฉ์๋ ๋ถํฐ์ ํด๊ฒฐ์ฌ์ ์ํด ์์ฑ๋๋ค.
1. ํ์จ๊ฒฝ์ฌํ๊ฐ SGD
ํ์จ๊ฒฝ์ฌํ๊ฐ("SGD" ๋ผ๊ณ ์น๋ค.)๋ ๋ค๊ฑฐํฐ๋ธ ๊ทธ๋๋์ธํธ โL(W)์ ์ด์ ์ ๊ฐ์ค์น ์ ๋ฐ์ดํธ Vt์ ์ ํ ํฉ์ฑ์ ์ํด ๊ฐ์ค์น W๋ฅผ ์ ๋ฐ์ดํธํ๋ค. ํ์ต์จ ฮฑ๋ ๋ค๊ฑฐํฐ๋ธ ๊ทธ๋๋์ธํธ์ ๊ฐ์ค์น์ด๋ฉฐ ๋ชจ๋ฉํ ฮผ์ ์ด์ ์ ๋ฐ์ดํธ์ ๊ฐ์ค์น์ด๋ค. ํ์์ ์ผ๋ก, ์ด์ ๊ฐ์ค์น ์ ๋ฐ์ดํธ Vt์ ํ์ฌ ๊ฐ์ค์น Wt๋ฅผ ๊ณ ๋ คํ์ฌ, ๋ฐ๋ณต t+1์์ ์ ๋ฐ์ดํธ ๋ ๊ฐ์ค์น Wt+1์ ์ ๋ฐ์ดํธ ๊ฐ Vt+1์ ์ฐ์ฐํ๊ธฐ์ํ ๋ค์๊ณผ ๊ฐ์ ๊ณต์์ด ์๋ค.
Vt+1=ฮผVtโฮฑโL(Wt)
Wt+1=Wt+Vt+1
"ํ์ดํผํ๋ผ๋ฏธํฐ" (ฮฑ ์ ฮผ)๋ฅผ ํ์ตํ๋๊ฒ์ ์ต๋์ ๊ฒฐ๊ณผ์ ๋ํ ์ฝ๊ฐ์ ์กฐ์จ์ด ์๊ตฌ๋ ์ง ๋ชจ๋ฅธ๋ค. ๋ง์ฝ ์ด๋์ ์์ํ ์ง์ ๋ํ ํ์ ์ด ์๋ค๋ฉด, ์๋ "์์ง์๊ฐ๋ฝ์ ๊ท์น"์ ๋ณด๊ณ ์ค๋ผ, ๊ทธ๋ฆฌ๊ณ ๋ ๋ง์ ์ ๋ณด๊ฐ ํ์ํ๋ค๋ฉด Leon Bottou ์ ์ ํ์จ์ ๊ธฐ์ธ๊ธฐ ๊ฐํ ์์์ (Stochastic Gradient Descent Tricks)๋ฅผ ์ฐธ๊ณ ํ ์ ์๋ค. #######[1] L. Bottou. Stochastic Gradient Descent Tricks. Neural Networks: Tricks of the Trade: Springer, 2012.
ํ์ต์จฮฑ ์ ๋ชจ๋ฉํ ฮผ ์ค์ ์ ์ํ ์์ง์๊ฐ๋ฝ ๊ท์น (Rules of thumb for setting the learning rate ฮฑ and momentum ฮผ)
SGD๋ก ํ๋ ์ฌ์ธตํ์ต์ ์ํ ์ข์ ์ ๋ต์ ์์ค์ด ํ์คํ "์์ ๊ธฐ"์ ๋ค๊ฐ๊ฐ๊ธฐ ์์ํ ๋ ํ์ต์ํค๋ ๋ด๋ด ์์ ์์ (10 ๊ฐ์)์ ์ํด ํ์ต์จ์ ๋ฎ์ถ๋ฉด์ ฮฑโ0.01=10^(โ2) ์ฃผ์ ๊ฐ์ ํ์ต์จฮฑ ์ ์ด๊ธฐํ ์ํค๋ ๊ฒ์ด๋ค. ์ผ๋ฐ์ ์ผ๋ก ์๋ง ๋ชจ๋ฉํ ฮผ=0.9์ด๋ ์ด์ ๋น์ทํ ๊ฐ์ ์ฌ์ฉํ๊ธธ ์ํ ์๋ ์๋ค. ๋ฐ๋ณต์ ํตํ ๊ฐ์ค์น ์ ๋ฐ์ดํธ๋ฅผ ๊ณ ๋ฃจ๊ฒ ํจ์ ์ํด, ๋ชจ๋ฉํ ์ ๋ ์์ ์ ์ด๊ณ ๋ ๋น ๋ฅธ SGD๋ก ํ๋ ์ฌ์ธตํ์ต์ ์ด๋ฃจ๋ ๊ฒฝํฅ์ด ์๋ค. ์ด๊ฒ์ Krizhevsky์ ๋ฑ๋ฑ์ ์ํด ์ฌ์ฉ๋ ์ ๋ต์ด๋ค. ILSVRC-2012๋ํ์์ CNN ์ํธ๋ฆฌ๋ก๋ถํฐ ์น๋ฆฌํ [1]์ Caffe๋ ์ด ์ ๋ต์ SolverParameter์์ ์ฝ๊ฒ ์ํํ๋ค. ์ด์ ๊ฐ์ ํ์ต์จ ์ ์ฑ ์ ์ฌ์ฉํ๊ธฐ ์ํด, solver prototxt ํ์ผ์์ ์ด๋ค ๊ณณ์ด๋ ๋ค์๊ณผ ๊ฐ์ ๋ผ์ธ์ ์ถ๊ฐํ ์ ์๋ค.
base_lr: 0.01 # 0.01 = 1e-2์ ํ์ต์จ๋ก ํ๋ จ์ ์์ํ๋ค.
lr_policy: "step" # ํ์ต์จ ๊ท์น : "๋จ๊ณ์ ์ผ๋ก" ํ์ต์จ์ ํ๋ฝ์ํจ๋ค.
# ๋ชจ๋ ๋จ๊ณ ํฌ๊ธฐ ๋ฐ๋ณต ๊ฐ๋ง์ ์์์ ์ํด
gamma: 0.1 # 10 ์์์ ์ํด ํ์ต์จ์ ํ๊ฐ์ํจ๋ค.
# (i.e., multiply it by a factor of gamma = 0.1)
stepsize: 100000 # ๋งค 10๋ง๋ฒ ๋ฐ๋ณตํ ๋๋ง๋ค ํ์ต์จ์ ํ๊ฐ์ํจ๋ค.
max_iter: 350000 # ์ ์ฒด 35๋ง๋ฒ ๋ฐ๋ณตํ์ฌ ํ๋ จํ๋ค.
momentum: 0.9
์์ ์ค์ ํ์, ์ฐ๋ฆฌ๋ ํญ์ ๋ชจ๋ฉํ
ฮผ=0.9์ ์ฌ์ฉํ ๊ฒ์ด๋ค. ์ฐ๋ฆฌ๋ ์ฒซ 10๋ง๋ฒ ๋ฐ๋ณต์ ๋ํด ฮฑ=0.01=10^(โ2)์ "base_lr"์์ ํ์ต์ ์์ํ ๊ฒ์ด๊ณ , ๊ทธ๋ฆฌ๊ณ ๋์ ๊ฐ๋ง(ฮณ)๋ฅผ ํ์ต์จ์ ๊ณฑ์
ํ๊ณ 10๋ง๋ฒ20๋ง๋ฒ ๋ฐ๋ณต์ ๋ํ์ฌ ฮฑโฒ=ฮฑฮณ=(0.01)(0.1)=0.001=10โ3์์ ํ์ต์ ํ๊ณ , 20๋ง๋ฒ30๋ง๋ฒ ๋ฐ๋ณต์ ๋ํด์๋ ฮฑโฒโฒ=10^(โ4)์์, ๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง์ผ๋ก 350๋ฒ์งธ ๋ฐ๋ณต๊น์ง๋ (์ฐ๋ฆฌ๊ฐ max_iter: 350000๋ก ์ค์ ํด ๋์๊ธฐ์) ฮฑโฒโฒโฒ=10^(โ5)์์ ํ์ตํ๋ค.
๋ชจ๋ฉํ ์ธํ ฮผ๊ฐ ์๋ง์ ํ์ต์ ๋ฐ๋ณต ํ์ 11โฮผ์ ์์์ ์ํด ์ ๋ฐ์ดํธ ์ฌ์ด์ฆ๋ฅผ ๊ณฑ์ ํ๋๋ฐ, ๊ทธ๋์ ๋ง์ฝ ฮผ๋ฅผ ์ฌ๋ฆฌ๊ธฐ๋ฅผ ์ํ๋ค๋ฉด, ฮฑ ์ ์์ํ์ฌ ๊ฐ์ํ๋ ๊ฒ์ ์ข์ ์๊ฐ์ด๋ค. (์ญ์ผ๋ก๋ ๊ฐ์) ์๋ฅผ๋ค๋ฉด, ฮผ=0.9๋ก, ์ฐ๋ฆฌ๋ 11โ0.9=10์ ํจ์จ์ ์ ๋ฐ์ดํธ ์ฌ์ด์ฆ ์น์๋ฅผ ๊ฐ์ง๋ค. ๋ง์ฝ ์ฐ๋ฆฌ๊ฐ ๋ชจ๋ฉํ ์ ฮผ=0.99๋ก ์ฌ๋ฆฐ๋ค๋ฉด, ์ฐ๋ฆฌ๋ ์ฐ๋ฆฌ์ ์ ๋ฐ์ดํธ ํฌ๊ธฐ ์น์๋ฅผ 100๊น์ง ์ฌ๋ฆฌ๋ฏ๋ก, ์ฐ๋ฆฌ๋ 10 ์์์ ์ํด (base_lr) ฮฑ๋ฅผ ํ๋ฝ์์ผ์ผ๋งํ๋ค.
๋ํ ์์ ์ค์ ์ ๋จ์ง ๊ฐ์ด๋๋ผ์ธ์ด๋ฉฐ, ๋ถ๋ช ํ ๋ชจ๋ ์ํฉ์์ ์ ์ค์ ์ด ์ต์ ์ด๋ผ๋ ๋ณด์ฅ์ด์๋ค. ๋ง์ฝ ํ์ตํ๋๊ฒ์ด ๋๋๋ฉด base_lr(์๋ฅผ๋ค๋ฉด base_lr: 0.001)๋ฅผ ๋ฎ์ถ๊ฑฐ๋ ์๋๋ฉด ์ฌ ํ๋ จ์ํค๋ ๊ฒ์ด๋ ์ ๋นํ base_lr ๊ฐ์ ์ฐพ์ ๋๊น์ง ๋ฐ๋ณตํด๋ณด์๋ผ.
#######[1] A. Krizhevsky, I. Sutskever, and G. Hinton. ImageNet Classification with Deep Convolutional Neural Networks. Advances in Neural Information Processing Systems, 2012.
2. AdaDelta
The AdaDelta ("AdaDelta"๋ผ๊ณ ์ ๋ ฅํ๋) ๋ฉ์๋ (M. Zeiler [1])๋ ํ๋ฐํ ํ์ต์จ ๋ฉ์๋(robust learning rate method)์ด๋ค. (SGD ๊ฐ์ด) ์ด๊ฒ์ ๊ทธ๋๋์ธํธ ๊ธฐ๋ฐ์ ์ต์ ํ ๋ฉ์๋์ด๋ค. ์ ๋ฐ์ดํธ ๊ณต์์ ๋ค์๊ณผ ๊ฐ๋ค.
% <![CDATA[
\begin{align}
(v_t)_i &= \frac{\operatorname{RMS}((v_{t-1})_i)}{\operatorname{RMS}\left( \nabla L(W_t) \right)_{i}} \left( \nabla L(W_{t'}) \right)_i
\\
\operatorname{RMS}\left( \nabla L(W_t) \right)_{i} &= \sqrt{E[g^2] + \varepsilon}
\\
E[g^2]_t &= \delta{E[g^2]_{t-1} } + (1-\delta)g_{t}^2
\end{align} %]]>
(W_{t+1})_i = (W_t)_i - \alpha (v_t)_i.
#######[1] M. Zeiler ADADELTA: AN ADAPTIVE LEARNING RATE METHOD. arXiv preprint, 2012.
3. AdaGrad
์กฐ์ ํ๋ ๊ทธ๋๋์ธํธ ๋ฉ์๋(adaptive gradient method) ("AdaGrad"๋ผ๊ณ ์น๋ค.)๋ Duchi์ ๊ทธ์ ๋๋ฃ๋ค์ ๋ง์ ๋ฐ๋ฅด๋ฉด "์์ธก์ด ๋งค์ฐ ๋ฐ์ด๋์ง๋ง ๊ฑฐ์ ํน์ง์ด ๋ณด์ด์ง ์๋ ํํ์์์ ๊ฑด์ด๋๋ฏธ์์ ๋ฐ๋๋ฅผ ์ฐพ๋ ๊ฒ"๊ณผ๊ฐ์ ์ ์๋ํ๋ (SGD์ ๊ฐ์) ๊ทธ๋๋์ธํธ ๊ธฐ๋ฐ์ ์ต์ ํ ๋ฉ์๋์ด๋ค. tโฒโ{1,2,...,t}์ ๋ํ (โL(W))tโฒ์ธ ์ ์ฒด ์ด์ ์ ๋ฐ๋ณต๋ค๋ก ๋ถํฐ ์ ๋ฐ์ดํธ ์ ๋ณด๋ฅผ ๊ณ ๋ คํ์๋ฉด, ๊ฐ๊ฐ์ ๊ฐ์ค์นW์ ์์ i์ ๋ํด ๋ช ์๋ [1]์ ์ํด ์ ์๋ ๊ณต์์ด ๋ค์๊ณผ ๊ฐ๋ค.
(W_{t+1})_i =
(W_t)_i - \alpha
\frac{\left( \nabla L(W_t) \right)_{i}}{
\sqrt{\sum_{t'=1}^{t} \left( \nabla L(W_{t'}) \right)_i^2}
}
์ค์ ๋ก๋, ๊ฐ์ค์น WโRd์ ๋ํ์ฌ, (Caffe์์ ์ ๊ณตํ๋ ๊ฒ๋ ํฌํจํด์) AdaGrad ์ํ๋ค์ ๊ธฐ๋ก๋ ๊ทธ๋๋์ธํธ ์ ๋ณด์ ์ฅ์ ๋ํ ์ถ๊ฐ์ ์ ์ฅ์ธ ์ค์ง O(d)๋ฅผ ์ฌ์ฉํ๋ค. (๊ฐ๊ฐ์ ๊ธฐ๋ก๋ ๊ทธ๋๋์ธํธ๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ ์ฅํด์ผ๋ง ํ๋ O(dt)๋ณด๋ค๋ ) #######[1] J. Duchi, E. Hazan, and Y. Singer. Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. The Journal of Machine Learning Research, 2011.
4. Adam
kingma์ ๊ทธ์ ๋๋ฃ๋ค์ด ์ ์ํ [1], Adam ("Adam"์ด๋ผ๊ณ ์น๋ค) ์ SGD๊ฐ์ด ๊ทธ๋๋์ธํธ ๊ธฐ๋ฐ์ ์ต์ ํ ๋ฉ์๋์ด๋ค. ์ด ๋ฐฉ๋ฒ์ "์กฐ์ ํ๋ ๋ชจ๋ฉํธ ํ๊ฐ์น(adaptive moment estimation)" (mt,vtmt,vt)๋ฅผ ํฌํจํ๋ฉฐ AdaGrad์ ์ผ๋ฐํ๋ก์จ ๊ฐ์ฃผ๋ ์ ์๋ค. ์ ๋ฐ์ดํธ ๊ณต์์ ๋ค์๊ณผ ๊ฐ๋ค.
(m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i,\\
(v_t)_i = \beta_2 (v_{t-1})_i + (1-\beta_2)(\nabla L(W_t))_i^2
(W_{t+1})_i =
(W_t)_i - \alpha \frac{\sqrt{1-(\beta_2)_i^t}}{1-(\beta_1)_i^t}\frac{(m_t)_i}{\sqrt{(v_t)_i}+\varepsilon}.
Kingma ์ ๊ทธ์ ๋๋ฃ๋ค์ด ์ ์ํ [1]์์๋ ฮฒ1=0.9,ฮฒ2=0.999,ฮต=10โ8 ๋ฅผ ๋ํดํธ ๊ฐ์ผ๋ก ์ฌ์ฉํ๋ผ๊ณ ์ ์ํ๋ค. Caffe๋ ๊ฐ๊ฐ ฮฒ1,ฮฒ2,ฮตฮฒ1,ฮฒ2,ฮต์ ๋ํ์ฌ ๋ชจ๋ฉํ , ๋ชจ๋ฉํ 2 ๋ธํ๋ฅผ ์ฌ์ฉํ๋ค.
#######[1] D. Kingma, J. Ba. Adam: A Method for Stochastic Optimization. International Conference for Learning Representations, 2015.
4. NAG (Nesterovโs accelerated gradient)
๋ค์คํธ๋ก๋ธ์ ๊ฐ์๋ ๊ทธ๋๋์ธํธ ("Nesterov"๋ผ๊ณ ์น๋ค.)๋ O(1/t)๋ณด๋ค O(1/(t^2))์ ์๋ ด๋ฅ ์ ๋ฌ์ฑํ๋ฉด์ ๋ณผ๋กํ ์ต์ ํ(convex optimization)์ "์ต์ ์" ๋ฐฉ๋ฒ์ผ๋ก์จ Nesterov๋ [1]์ ์ ์ํ๋ค. ๋น๋ก ์๋ ด O(1/t2)๋ฅผ ๋ฌ์ฑํ๊ธฐ์ํด ํ์๋กํ๋ ์๋น๊ฐ ์ผ๋ฐ์ ์ผ๋ก Caffe๋ก ํ๋ จ์ํจ ์ฌ์ธต ๋คํธ์ํฌ๋ค์ ์๋ฆฌ์ก์ง๋ ์๋๋ผ๋, Sutskever์ ๊ทธ์ ๋๋ฃ๋ค์ด deep MNIST autoencoders [2]๋ฅผ ๋ฌ์ฌํ๋ ๊ฒ์ฒ๋ผ, ์ค์ NAG๋ ์ฌ์ธตํ์ต ๊ตฌ์กฐ์ ํน์ ํ ํ์ ๋ค์ ์ต์ ํํ๋๋ฐ ๋งค์ฐ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ด๋ค. ๊ฐ์ค์น ์ ๋ฐ์ดํธ ๊ณต์์ ์์ SGD ์ ๋ฐ์ดํธ์์ ๋ณด์ธ ๊ฒ๊ณผ ๋งค์ฐ ์ ์ฌํ๋ค.
V_{t+1} = \mu V_t - \alpha \nabla L(W_t + \mu V_t)
W_{t+1} = W_t + V_{t+1}
SGD ๋ฉ์๋์ ๊ตฌ๋ณ๋๋ ์ด ๋ฐฉ๋ฒ์ ์ฐ๋ฆฌ๊ฐ ๊ฐ๋จํ ํ์ฌ ๊ฐ์ค์น ๊ทธ๋ค ์์ฒด์์ ๊ทธ๋๋์ธํธโL(Wt)๋ฅผ ์ทจํ๋ SGD์์, ํน์ ์ฐ๋ฆฌ๊ฐ ์ถ๊ฐ๋ ๋ชจ๋ฉํ โL(Wt+ฮผVt)์ผ๋ก ๊ฐ์ค์น์ ๋ํ ๊ทธ๋๋์ธํธ๋ฅผ ์ทจํ๋ NAG์์, ์ฐ๋ฆฌ๊ฐ ์๋ฌ ๊ทธ๋๋์ธํธ โL(W)๋ฅผ ๊ณ์ฐํ ๊ฒ์ ๋ํ W๋ฅผ ์ค์ ํ๋ ๊ฐ์ค์น์ด๋ค.
######[1] Y. Nesterov. A Method of Solving a Convex Programming Problem with Convergence Rate O(1/kโโโ)O(1/k). Soviet Mathematics Doklady, 1983.
######[2] I. Sutskever, J. Martens, G. Dahl, and G. Hinton. On the Importance of Initialization and Momentum in Deep Learning. Proceedings of the 30th International Conference on Machine Learning, 2013.
5. RMSprop
RMSprop("RMSProp"๋ผ๊ณ ์น๋ค)๋ ์ฝ์ธ๋ผ ๊ณผ์ ๊ฐ์(Coursera course lecture)์์ Tieleman ์ด ์ ์ํ ๊ฒ์ด๋ฉฐ ์ด๋ SGD ์ฒ๋ผ ๊ทธ๋๋์ธํธ ๊ธฐ๋ฐ์ ์ต์ ํ์ด๋ฉฐ ์ ๋ฐ์ดํธ ๊ณต์์ ๋ค์๊ณผ ๊ฐ๋ค.
\operatorname{MS}((W_t)_i)= \delta\operatorname{MS}((W_{t-1})_i)+ (1-\delta)(\nabla L(W_t))_i^2 \\
(W_{t+1})_i= (W_{t})_i -\alpha\frac{(\nabla L(W_t))_i}{\sqrt{\operatorname{MS}((W_t)_i)}}
(rms_decay) ฮด์ ๋ํดํธ ๊ฐ์ ฮด=0.99๋ก ์ค์ ๋์ด ์๋ค.
[1] T. Tieleman, and G. Hinton. RMSProp: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural Networks for Machine Learning.Technical report, 2012.
2. ๋ฐํ ๋ง๋ จํ๊ธฐ (Scaffolding)
๋ฐํ์ ๋ง๋ จํ๋ ํด๊ฒฐ์ฌ๋ "Solver::Presolve()"์์ ํ์ต๋์ด์ง๊ธฐ ์ํ ๋ชจ๋ธ์ ์ด๊ธฐํํ๊ณ ๋ฉ์๋ ์ต์ ํ๋ฅผ ์ค๋นํ๋ค.
> caffe train -solver examples/mnist/lenet_solver.prototxt
I0902 13:35:56.474978 16020 caffe.cpp:90] Starting Optimization
I0902 13:35:56.475190 16020 solver.cpp:32] Initializing solver from parameters:
test_iter: 100
test_interval: 500
base_lr: 0.01
display: 100
max_iter: 10000
lr_policy: "inv"
gamma: 0.0001
power: 0.75
momentum: 0.9
weight_decay: 0.0005
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_mode: GPU
net: "examples/mnist/lenet_train_test.prototxt"
- ๋ง ์ด๊ธฐํ (Net initialization)
I0902 13:35:56.655681 16020 solver.cpp:72] Creating training net from net file: examples/mnist/lenet_train_test.prototxt
[...]
I0902 13:35:56.656740 16020 net.cpp:56] Memory required for data: 0
I0902 13:35:56.656791 16020 net.cpp:67] Creating Layer mnist
I0902 13:35:56.656811 16020 net.cpp:356] mnist -> data
I0902 13:35:56.656846 16020 net.cpp:356] mnist -> label
I0902 13:35:56.656874 16020 net.cpp:96] Setting up mnist
I0902 13:35:56.694052 16020 data_layer.cpp:135] Opening lmdb examples/mnist/mnist_train_lmdb
I0902 13:35:56.701062 16020 data_layer.cpp:195] output data size: 64,1,28,28
I0902 13:35:56.701146 16020 data_layer.cpp:236] Initializing prefetch
I0902 13:35:56.701196 16020 data_layer.cpp:238] Prefetch initialized.
I0902 13:35:56.701212 16020 net.cpp:103] Top shape: 64 1 28 28 (50176)
I0902 13:35:56.701230 16020 net.cpp:103] Top shape: 64 1 1 1 (64)
[...]
I0902 13:35:56.703737 16020 net.cpp:67] Creating Layer ip1
I0902 13:35:56.703753 16020 net.cpp:394] ip1 <- pool2
I0902 13:35:56.703778 16020 net.cpp:356] ip1 -> ip1
I0902 13:35:56.703797 16020 net.cpp:96] Setting up ip1
I0902 13:35:56.728127 16020 net.cpp:103] Top shape: 64 500 1 1 (32000)
I0902 13:35:56.728142 16020 net.cpp:113] Memory required for data: 5039360
I0902 13:35:56.728175 16020 net.cpp:67] Creating Layer relu1
I0902 13:35:56.728194 16020 net.cpp:394] relu1 <- ip1
I0902 13:35:56.728219 16020 net.cpp:345] relu1 -> ip1 (in-place)
I0902 13:35:56.728240 16020 net.cpp:96] Setting up relu1
I0902 13:35:56.728256 16020 net.cpp:103] Top shape: 64 500 1 1 (32000)
I0902 13:35:56.728270 16020 net.cpp:113] Memory required for data: 5167360
I0902 13:35:56.728287 16020 net.cpp:67] Creating Layer ip2
I0902 13:35:56.728304 16020 net.cpp:394] ip2 <- ip1
I0902 13:35:56.728333 16020 net.cpp:356] ip2 -> ip2
I0902 13:35:56.728356 16020 net.cpp:96] Setting up ip2
I0902 13:35:56.728690 16020 net.cpp:103] Top shape: 64 10 1 1 (640)
I0902 13:35:56.728705 16020 net.cpp:113] Memory required for data: 5169920
I0902 13:35:56.728734 16020 net.cpp:67] Creating Layer loss
I0902 13:35:56.728747 16020 net.cpp:394] loss <- ip2
I0902 13:35:56.728767 16020 net.cpp:394] loss <- label
I0902 13:35:56.728786 16020 net.cpp:356] loss -> loss
I0902 13:35:56.728811 16020 net.cpp:96] Setting up loss
I0902 13:35:56.728837 16020 net.cpp:103] Top shape: 1 1 1 1 (1)
I0902 13:35:56.728849 16020 net.cpp:109] with loss weight 1
I0902 13:35:56.728878 16020 net.cpp:113] Memory required for data: 5169924
- ์์ค (Loss)
I0902 13:35:56.728893 16020 net.cpp:170] loss needs backward computation.
I0902 13:35:56.728909 16020 net.cpp:170] ip2 needs backward computation.
I0902 13:35:56.728924 16020 net.cpp:170] relu1 needs backward computation.
I0902 13:35:56.728938 16020 net.cpp:170] ip1 needs backward computation.
I0902 13:35:56.728953 16020 net.cpp:170] pool2 needs backward computation.
I0902 13:35:56.728970 16020 net.cpp:170] conv2 needs backward computation.
I0902 13:35:56.728984 16020 net.cpp:170] pool1 needs backward computation.
I0902 13:35:56.728998 16020 net.cpp:170] conv1 needs backward computation.
I0902 13:35:56.729014 16020 net.cpp:172] mnist does not need backward computation.
I0902 13:35:56.729027 16020 net.cpp:208] This network produces output loss
I0902 13:35:56.729053 16020 net.cpp:467] Collecting Learning Rate and Weight Decay.
I0902 13:35:56.729071 16020 net.cpp:219] Network initialization done.
I0902 13:35:56.729085 16020 net.cpp:220] Memory required for data: 5169924
I0902 13:35:56.729277 16020 solver.cpp:156] Creating test net (#0) specified by net file: examples/mnist/lenet_train_test.prototxt
- Completion
I0902 13:35:56.806970 16020 solver.cpp:46] Solver scaffolding done.
I0902 13:35:56.806984 16020 solver.cpp:165] Solving LeNet
3. ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธํ๊ธฐ (Updating Parameters)
์ค์ ๊ฐ์ค์น ์ ๋ฐ์ดํธ๋ ํด๊ฒฐ์ฌ์ ์ํด ๋ง๋ค์ด์ง ๋ค, "Solver::ComputeUpdateValue()"์์ ๋ง ํ๋ผ๋ฏธํฐ๊ฐ ์ ์ฉ๋๋ค. "ComputeUpdateValue" ๋ฉ์๋๋ ๊ฐ๊ฐ์ ๋คํธ์ํฌ ๊ฐ์ค์น์ ๋ํ์ฌ ์ต์ข ์ ๊ทธ๋๋์ธํธ๋ฅผ ์ทจํ๋ (ํ์ฌ ์๋ฌ ๊ทธ๋๋์ธํธ๋ฅผ ํฌํจํ๊ณ ์๋ )๊ฐ์ค์น ๊ทธ๋๋์ธํธ์์ ์ด๋ค ์ค๋์น ๊ฐ์ r(W)๋ฅผ ํฌํจํ๋ค. ๊ทธ๋ฆฌ๊ณ ๋์ ์ด๋ฌํ ๊ทธ๋๋์ธํธ๋ ๊ฐ๊ฐ์ Bolb์ diff ํ๋ ํ๋ผ๋ฏธํฐ์์ ์ ์ฅ๋ ๋บ์ ์ ๋ฐ์ดํธ์ ํ์ต์จ ฮฑ์ ์ํด ์์น๋์ด์ง๋ค. ์ต์ข ์ ์ผ๋ก "Blob::Update" ๋ฉ์๋๋ ๊ฐ๊ฐ์ blob ํ๋ผ๋ฏธํฐ์ ํธ์ถ๋๋ฉฐ, ์ด๋ ์ต์ข ์ ๋ฐ์ดํธ๋ฅผ ์ํํ๋ค. (๋ฐ์ดํฐ๋ก ๋ถํฐ Blob์ diff๋ฅผ ๋นผ๋ฉด์)
4. Snapshotting and Resuming
ํด๊ฒฐ์ฌ๋ "Solver::Snapshot()"์ "Solver::SnapshotSolverState()"์์ ํ์ตํ๋ ๋์ ๊ฐ์ค์น์ ๊ฐ์ค์น์ ์ํ๋ฅผ ์ค๋ ์ท์ผ๋ก ์ฐ๋๋ค. ํด๊ฒฐ์ฌ ์ค๋ ์ท์ด ์ฃผ์ด์ง ์ง์ ์ผ๋ก๋ถํฐ ์ฌํ์ตํ๊ธฐ ์ํ ํ๋ จ์ ๊ฐ๋ฅํ๊ฒํ๋ ๋์์ ๊ฐ์ค์น ์ค๋ ์ท์ ํ์ต๋ ๋ชจ๋ธ์ ๋ด๋ณด๋ธ๋ค. ํ๋ จ์ "Solver::Restore()"์ "Solver::RestoreSolverState()"์ ์ํด ์ฌํ์ต๋์ด์ง๋ค. ํด๊ฒฐ์ฌ ์ํ๊ฐ ".solverstate" ํ์ฅ์ ์ ์ฅ๋๋ ๋์ ๊ฐ์ค์น๋ค์ ํ์ฅ์์ด ์ ์ฅ๋๋ค. ์์ชฝ ํ์ผ ๋ชจ๋ ์ค๋ ์ท ๋ฐ๋ณต ์์ ๋ํ์ฌ ์ ๋ฏธ์ฌ "_iter_N"๋ฅผ ๊ฐ์ง๋ค. ์ค๋ ์ท์ ๋ค์๊ณผ ๊ฐ์ด ์ค์ ๋์ด์ง๋ค.
# ๋ฐ๋ณต์์ ์ค๋
์ท ๊ฐ๊ฒฉ
snapshot: 5000
# ๋ชจ๋ธ ๊ฐ์ค์น์ ํด๊ฒฐ์ฌ ์ํ๋ฅผ ์ค๋
์ท์ผ๋ก ์ฐ์ด๋์ ๊ฒ์ ๋ํ ํ์ผ ๊ฒฝ๋ก ์ ๋ฏธ์ฌ
# ์ด๋ 'Caffe' ๋๊ตฌ๊ฐ ๋์ํ๋ ๊ฒ๊ณผ ๊ด๋ จ์์ผ๋ฉฐ ํด๊ฒฐ์ฌ ์ ์ ํ์ผ๊ณผ๋ ๋ฌด๊ดํ๋ค.
snapshot_prefix: "/path/to/model"
# ๊ฐ์ค์น์๋ฐ๋ผ diff๋ฅผ ์ค๋
์ท์ผ๋ก ์ฐ์ผ๋ฉฐ ์ด๋ ํ์ต์ ๋๋ฒ๊น
์ ๋์์ ์ฃผ์ง๋ง ์ ์ฅ์ฉ๋์ด ์ฆ๊ฐํ๋ค.
# ์ต์ข
์ค๋
์ท์ ์ด ํ๋๊ทธ๊ฐ false๋ผ๊ณ ์ค์ ํ์ง ์๋ํ ํ์ต์ ๋์ ์ ์ฅ๋ ๊ฒ์ด๋ค. ๋ํดํธ๊ฐ์๋ true๋ค.
snapshot_after_train: true