Caffe Example : 8.Siamese Network Training with Caffe (Kor) - ys7yoo/BrainCaffe GitHub Wiki
(Siamese Network Training with Caffe)
Caffe๋ฅผ ์ด์ฉํ ์ด ๋คํธ์ํฌ ํ์ตํ๊ธฐ์ด ์์๋ Caffe์์ ์ด ๋คํธ์ํฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ตํ๊ธฐ์ํด ์ฐ๋ฆฌ๊ฐ ๊ฐ์ค์น ๊ณต์ ์ ๋์กฐํ๋ ์์คํจ์์ ์ด๋ป๊ฒ ์ฌ์ฉํ ์ ์๋์ง๋ฅผ ๋ณด์ฌ์ค๋ค.
์ฐ๋ฆฌ๋ Caffe๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ์ปดํ์ผ ํ๋ค๊ณ ๊ฐ์ ํ ๊ฒ์ด๋ค. ์๋๋ผ๋ฉด, ์ค์นํ์ด์งใน๋ฅผ ์ฐธ๊ณ ํ๊ธธ ๋ฐ๋๋ค. ์ด ์์๋ MNIST ์ ๋ฌธ์์ ๊ธฐ๋ฐ์ผ๋ก ๊ตฌ์ถํ๋ค. ๊ทธ๋์ ๊ณ์ํ๊ธฐ์ ์, ์ด๊ฑธ ๋จผ์ ์ฝ๊ณ ์ค๋๊ฒ ์ข์๊ฒ์ด๋ค.
์ด ๊ฐ์ด๋๋ ๋ชจ๋ ๊ฒฝ๋ก๋ค์ ๋ช ์ํ๊ณ ๋ชจ๋ ๋ช ๋ น์ด๋ค์ Caffe root ๋๋ ํ ๋ฆฌ์์ ์คํ๋์ด์ง๋ค๊ณ ๊ฐ์ ํ๋ค.
๋ฐ์ดํฐ์ธํธ ์ค๋นํ๊ธฐ(Prepare Datasets)
๋น์ ์ ์ฒ์์ MNIST ์น์ฌ์ดํธ์์ ๋ฐ์ดํฐ๋ฅผ ๋ค์ด๋ก๋๋ฐ๊ณ ์ ํํ ํ์๊ฐ ์๋ค. ์ด๋ฅผ์ํด์๋, ๊ฐ๋จํ ๋ค์๊ณผ ๊ฐ์ ๋ช ๋ น์ด๋ฅผ ์คํํ๋ผ:
./data/mnist/get_mnist.sh
./examples/siamese/create_mnist_siamese.sh
์ด ๋ช ๋ น์ด๋ฅผ ์คํํํ ๋๊ฐ์ ๋ฐ์ดํฐ์ธํธ ./examples/siamese/mnist_siamese_train_leveldb์ ./examples/siamese/mnist_siamese_test_leveldb๊ฐ ์์ ๊ฒ์ด๋ค.
๋ชจ๋ธ(The Model)
์ฒ์์, ์ฐ๋ฆฌ๋ ์ด ๋คํธ์ํฌ๋ฅผ์ฌ์ฉํด ํ์ตํ๊ธฐ๋ฅผ ์ํ๋ ๋ชจ๋ธ์ ์ ์ํ ๊ฒ์ด๋ค. ์ฐ๋ฆฌ๋ ./examples/siamese/mnist_siamese.prototxt์ ์ ์๋ ์ปจ๋ณผ๋ฃจ์ ๋ง์ ์ฌ์ฉํ ๊ฒ์ด๋ค. ์ด ๋ชจ๋ธ์ ๊ฑฐ์ ์ ํํ๊ฒ LeNet ๋ชจ๋ธ๊ณผ ์ผ์นํ๋๋ฐ, ์ค์ง ํ๊ฐ์ง ๋ค๋ฅธ์ ์ 2์ฐจ์ ๋ฒกํฐ๋ฅผ ์์ฑํ๋ ์ ํ "ํน์ง" ๊ณ์ธต์ผ๋ก 10๊ฐ์ ์ซ์ ํด๋์ค ์ด์์ ๊ฐ๋ฅ์ฑ์ ์์ฐํ๋ ์์ ๊ณ์ธต์ ๋์ฒดํด์จ๊ฒ์ด๋ค.
layer {
name: "feat"
type: "InnerProduct"
bottom: "ip2"
top: "feat"
param {
name: "feat_w"
lr_mult: 1
}
param {
name: "feat_b"
lr_mult: 2
}
inner_product_param {
num_output: 2
}
}
์ด ๋คํธ์ํฌ ์ ์ (Define the Siamese Network)
์ด ์น์ ์์๋ ์ฐ๋ฆฐ ํ์ต์ ์ฌ์ฉ๋ ์ด ๋คํธ์ํฌ๋ฅผ ์ ์ํ ๊ฒ์ด๋ค. ๋์ถ๋๋ ๋คํธ์ํฌ๋ ./examples/siamese/mnist_siamese_train_test.prototxt์ ์ ์๋์ด์ ธ์๋ค. In this section we will define the siamese network used for training. The resulting network is defined in ./examples/siamese/mnist_siamese_train_test.prototxt.
๋ฐ์ดํฐ ์ ์ฝ์ด๋ค์ด๊ธฐ(Reading in the Pair Data)
์ฐ๋ฆฌ๋ ์ด๊ธฐ์ ์ฐ๋ฆฌ๊ฐ ์์ฑํ LevelDB ๋ฐ์ดํฐ ๋ฒ ์ด์ค๋ก๋ถํฐ ์ฝ์ ๋ฐ์ดํฐ๊ณ์ธต์ ๊ฐ์ง๊ณ ์์ํ๋ค. ์ด ๋ฐ์ดํฐ ๋ฒ ์ด์ค์ ๊ฐ๊ฐ ๋ด์ฉ์ ํ ์์ ์ด๋ฏธ์ง๋ค (pair_data)์ ๋ํ ์ด๋ฏธ์ง์ ์ด๋ฏธ์ง๋ค์ด ๊ฐ์ ํด๋์ค์์ธ์ง ๋ค๋ฅธ ํด๋์ค๋ค์ ๊ฑธ๋ ค์๋์ง ๋งํด์ฃผ๋ ์ด์ง์ ๋ผ๋ฒจ(sim)์ ํฌํจํ๊ณ ์๋ค.
layer {
name: "pair_data"
type: "Data"
top: "pair_data"
top: "sim"
include { phase: TRAIN }
transform_param {
scale: 0.00390625
}
data_param {
source: "examples/siamese/mnist_siamese_train_leveldb"
batch_size: 64
}
}
๋ฐ์ดํฐ๋ฒ ์ด์ค์ ๊ฐ์ blob์์๋ค ํ ์์ ์ด๋ฏธ์ง๋ฅผ ํฌ์ฅํ๊ธฐ์ํด ์ฐ๋ฆฌ๋ ํ ์ฑ๋๋น ํ๋์ ์ด๋ฏธ์ง๋ก ํฌ์ฅํ๋ค. ์ฐ๋ฆฌ๋ ์ด๋ฌํ ๋๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ๋๋์ด์ ์์ ํ ์ ์๋๊ฒ์ ์ํ๊ธฐ ๋๋ฌธ์, ์ฐ๋ฆฌ๋ ๋ฐ์ดํฐ ๊ณ์ธต ๋ค์์ ์๋ฅด๊ธฐ ๊ณ์ธต(slice layer)์ ์ถ๊ฐํ๋ค. ์ด๋ pair_data๋ฅผ ๊ฐ์ง๊ณ ์์ ์ฑ๋์ ์ฐจ์์์ ๋ฐ๋ผ pair_data๋ฅผ ์๋ผ์ ์ฐ๋ฆฌ๊ฐ ๋ฐ์ดํฐ๋ด ๋จ์ผ ์ด๋ฏธ์ง์ data_p์ ๋ฐ์ดํฐ์ ์์ผ๋ก ๋ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ง๋ค
layer {
name: "slice_pair"
type: "Slice"
bottom: "pair_data"
top: "data"
top: "data_p"
slice_param {
slice_dim: 1
slice_point: 1
}
}
์ด ๋ง์ ๋๋ฒ์งธ ์ฌ์ด๋ ๊ตฌ์ถ(Building the Second Side of the Siamese Net)
์ด์ ์ฐ๋ฆฌ๋ data_p์์ ์๋ํ๊ณ feat_p๋ฅผ ์ฒ๋ฆฌํ๋ ๋๋ฒ์งธ ๊ฒฐ๋ก๋ฅผ ์์ฑํ ํ์๊ฐ์๋ค. ์ด ๊ฒฐ๋ก๋ ์ ํํ ์ฒ์๊ณผ ๋ค๋ฅด์ง ์๋ค. ๊ทธ๋์ ์ฐ๋ฆฌ๋ ๊ทธ๋ฅ ์ด๋ฅผ ๋ณต์ฌํ๊ณ ๋ถ์ฌ๋ฃ๊ธฐํ ์ ์๋ค. ๊ทธ๋ฌ๊ณ ๋๋ฉด ๊ธฐ์กด์ ๊ฒ๊ณผ ์์ผ๋ก ๋(paired) ๊ณ์ธต๋ค์ ์ฐจ๋ณํ๊ธฐ์ํด _p๋ฅผ ๋ถ์ฌ์ค์ผ๋ก์จ ์ฐ๋ฆฌ๋ ๊ฐ๊ฐ์ ๊ณ์ธต, ์ ๋ ฅ, ๊ทธ๋ฆฌ๊ณ ์ถ๋ ฅ์ ์ด๋ฆ์ ๋ฐ๊ฟ์ผํ๋ค.
๋์กฐํ๋ ์์ค ํจ์ ์ถ๊ฐํ๊ธฐ(Adding the Contrastive Loss Function)
๋คํธ์ํฌ๋ฅผ ํ์ตํ๊ธฐ์ํด์ ์ฐ๋ฆฌ๋ Raia Hadsell, Sumit Chopra, ๊ทธ๋ฆฌ๊ณ Yann LeCun์ด ์ ์ ํ "๋ณํ์ง์๋ ๋งตํ์ ํ์ตํจ์ผ๋ก ์ฐจ์๊ฐ๋ฅ์ฑ ์ ๊ฑฐํ๊ธฐ(Dimensionality Reduction by Learning an Invariant Mapping)"์์ ์ ์ํ๋ ๋์กฐํ๋ ์์คํจ์๋ฅผ ์ต์ ํํ ๊ฒ์ด๋ค. ์ด ์์คํจ์๋ ๋งค์น๋์ง ๋ชปํ๋ ์๋ค์ ๋ฐ์ด๋๊ณผ ๋์์ ํน์ง ๊ณต๊ฐ์์ ์๋ก๋ฅผ ๊ฐ๊น๊ฒํ๊ธฐ์ํด ์์ ๋งค์นญํ๋๋ก ๋๋๋ค. ์ด๋ฌํ ๋น์ฉ ํจ์๋ CONTRASTIVE_LOSS๊ณ์ธต์์ ์ํ๋๋ค.
layer {
name: "loss"
type: "ContrastiveLoss"
contrastive_loss_param {
margin: 1.0
}
bottom: "feat"
bottom: "feat_p"
bottom: "sim"
top: "loss"
}
ํด๊ฒฐ์ฌ ์ ์ํ๊ธฐ(Define the Solver)
์๋ง๋ ๋ชจ๋ธ ํ์ผ์ ์ด๋ฅผ ์ง์ ํ๊ธฐ๋ง ํด์ฃผ๋ฉด ํด๊ฒฐ์ฌ๋ฅผ ์ ์ํ๊ธฐ์ํด ํน๋ณํ ํด์ค์ผ์ ์๋ค. ํด๊ฒฐ์ฌ๋ ./examples/siamese/mnist_siamese_solver.prototxt์ ์ ์๋์ด์๋ค.
๋ชจ๋ธ ํ์ต ๋ฐ ์คํ(Training and Testing the Model)
๋ชจ๋ธ ํ์ต์ ๋น์ ์ด ๋คํธ์ํฌ ์ ์ protobuf์ ํด๊ฒฐ์ฌ protobuf ํ์ผ๋ค์ ์์ฑํด์จ ํ๋ผ๋ฉด ๊ฐ๋จํ๋ค. ๊ฐ๋จํ ./examples/siamese/train_mnist_siamese.sh๋ฅผ ์คํ์์ผ์ฃผ๋ฉด ๋๋ค.
./examples/siamese/train_mnist_siamese.sh
๊ฒฐ๊ณผ ํ๋ผํ (Plotting the results)
์ฒ์์ ์ฐ๋ฆฌ๋ .prototxt ํ์ผ๋ค์ ์ ์๋ DAGs๋ฅผ ๊ทธ๋ ค์ฃผ๋ ๋ค์์ ๋ช ๋ น์ด๋ค์ ์คํ์์ผ์ค์ผ๋ก์จ ๋ชจ๋ธ๊ณผ ์ด ๋คํธ์ํฌ๋ฅผ ๊ทธ๋ฆด์์๋ค.
./python/draw_net.py \
./examples/siamese/mnist_siamese.prototxt \
./examples/siamese/mnist_siamese.png
./python/draw_net.py \
./examples/siamese/mnist_siamese_train_test.prototxt \
./examples/siamese/mnist_siamese_train_test.png
๋ค์์ผ๋ก, ์ฐ๋ฆฌ๋ ํ์ต๋ ๋ชจ๋ธ์ ๋ถ๋ฌ์ iPython ๋ ธํธ๋ถ์ ์ฌ์ฉํ ํน์ง๋ค์ ํ๋ผํ ํ ์ ์๋ค.
ipython notebook ./examples/siamese/mnist_siamese.ipynb