GAN - newlife-js/Wiki GitHub Wiki
GAN
1. Deep Learning ๊ฐ์
-
์ธ๊ณต์ง๋ฅ (Artificial Intelligence) ์ง์(์ ์ ) : ์ด๋ค ๊ฒ์ ๋ํ ์ธ์, ์ดํด
์ง๋ฅ(๋์ ) : ์ง์์ ์์๋ด๊ณ ์ํฉ์ ๋์ํ๋ ๋ฅ๋ ฅ(ํ์ต๊ณผ ์ถ๋ก )
์ง์ฑ(๋ชฉ์ /๋ชฉํ์งํฅ) : ์ง๋ฅ์ ํ์ฉํ๋ ์ฑํฅ/์ฑ์ง
์ธ๊ณต์ง๋ฅ : ๊ธฐ๊ณ์ ์ผ๋ก ์ง์์ ์์๋ด๊ณ , ๋์ํ๋ ๋ฅ๋ ฅ -
๊ธฐ๊ณํ์ต : ๋ฐ์ดํฐ๋ก๋ถํฐ ์ค์ค๋ก ํ์ต
๋ฐ์ดํฐ์์ ๋ด์ฌ๋ ํจํด, ๊ท์น ๋ฑ์์ ํ์ต์ ํตํ์ฌ ํน์ง์ ๊ตฌ๋ถํ๊ณ ๊ธฐ์ตํ๋ ๊ฒ.
ํ์ต: ์ฃผ์ด์ง ๋ฐ์ดํฐ์ ํ๊ฐ๋ฅผ ํตํด ํน์ง์ ์ฐพ๊ณ , ํน์ง๋ค์ ์กฐํฉ์ผ๋ก ๊ณ ์ฐจ์์ ํน์ง์ ์ฐพ๋ ๊ฒ -
Deep Neural Network : ์ธ๊ณต์ ๊ฒฝ๋ง ๊ณ์ธตํ MLP -> Deep NN
-
๊ธฐ๊ณํ์ต ๋ถ๋ฅ
Supervised Learning - classification / regression
Unsupervised Learning - clustering / generative model
Reinforcement Learning -
Regression : ์ด๋ค ์ ๋ ฅ ๋ฐ์ดํฐ์ ๋ํด ์ถ๋ ฅ์ ์ํฅ์ ์ฃผ๋ ์กฐ๊ฑด์ ๊ณ ๋ คํ ํ๊ท ๊ฐ ๊ตฌํ๊ธฐ
Linear regression : ํน์ง์ ์ ํ๊ฒฐํฉ์ผ๋ก value์ถ์
- y(์ข ์๋ณ์) ๊ฐฏ์ -> univariate / multivariate regression
Logistic regression : ๋ถ๋ฅ(-โพ๏ธ ~ +โพ๏ธ ๋ฅผ 0~1์ ํ๋ฅ ๋ก ๋ณํ)
- category ๊ฐฏ์ -> binomial / multinomial logistic regression
๊ฐ๊ฐ์ ๋ด๋ฐ node๊ฐ regression model๋ก ๊ตฌ์ฑ๋๋ค๊ณ ์๊ฐํ๋ฉด ๋จ.
activation function์ ์ข
๋ฅ๊ฐ linear / nonlinear / logistic ๋ง๋ฆ
multinomial logistic regression
- ๊ฐ ํด๋์ค์ ํด๋นํ ํ๋ฅ ์ ํฉ์ด 1์ด ๋๋๋ก(softmax) y ๋ฒกํฐ๋ฅผ ์ถ๋ ฅ
ํ์ต ์ฉ์ด
Epoch: ๋ชจ๋ data๋ฅผ ํ ๋ฒ ํ์ต
Batch size: ํ์ต ํ ๋ฒ์ ์ฌ์ฉ๋๋ dataset ํฌ๊ธฐ
Iteration: batch_size๋ก ํ์ต ์ํํ๋ ๊ฒ
1 Epoch = batch_size * #of_iteration(= #of_data / batch_size)
-
Cost function / Loss function
Cost function : ํ์ต์ ํตํด ์ต์ํํ๋ ค๋ function Loss function : cost function์ ๊ตฌ์ฑํ๋ subset -
Optimization Algorithm
Loss function์ ๊ฐ์ ์ต์ํํ๋ ์๊ณ ๋ฆฌ์ฆ
์) SGD, momentom, NAG, Adagrad ๋ฑ -
Learning rate
ํ์ต ์๋์ ์ฑ๋ฅ์ ์ํฅ์ ์ฃผ๋ hyper parameter -
Overfitting
Training Data์ ์ต์ ํ๋์ด ์ฑ๋ฅ์ด ์ ํ๋๋ ๋ฌธ์
๋ชจ๋ธ์ผ ใ ฃ๋ณต์ก๋๊ฐ ์ปค์ง์๋ก Training loss๋ ์ง์์ ์ผ๋ก ๊ฐ์ / Validation loss๋ ๋ค์ ์ฆ๊ฐ
2. Keras
๋ชจ๋ธ ์ค๊ณ
- Sequential model
model = keras.models.Sequential()
model.add(keras.~)
model.add(layers.~)
~~~
- Functional API
input_x = keras.Input(shape=(28,28))
x0 = layers.Flatten()(input_x)
x1 = layers.Dense(~)(x0)
output_x = layer.Dense(~)(x1)
model = keras.Model(inputs=input_x, outputs=output_x)
- Subclassing API
class MyModel(keras.Model):
~~~~~
- visualize
model.summary()
keras.utils.plot_model(model, 'model.png', hsow_shapes=True)
Compile method
- ๋ด์ฅ ํจ์์ 'name'์ผ๋ก ์ง์
model.compile(loss='categorical_crossentropy',
optimizer='Adam',
metrics=['accuracy', 'mse'])
- ๋ด์ฅํจ์๋ฅผ ์ง์
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.01, rho=0.9),
loss=keras.losses.CategoriclCrossentropy(),
metrics=[keras.metrics.CategoricalAccuracy()])
- ๋ด์ฅํจ์์ ์ธ์คํด์ค๋ก ์ง์
opt = keras.optimizers.Adam(learning_rate=0.01(
loss = keras.losses.SparseCategoricalCrossentropy()
metric = keras.metrics.CategoricalAccuracy()
model.compile(loss=loss, optimizer=opt, metrics=[metric]
๋ชจ๋ธ ์ ์ฅ
model.save('save_model') // SavedModel ํฌ๋งท ์ ์ฅ
keras.models.save_model(model, 'save_model') // SavedModel ํฌ๋งท ์ ์ฅ
model.save("model_save.h5") // HDF5 ํ์ผ๋ก ์ ์ฅ
๋ชจ๋ธ ๋ณต์
model = keras.models.load_model('save_model')
layers API
- Input
- Flatten
- Dense
- Activation(sigmoid, relu ๋ฑ)
- Dropout
- Batch Normalization : ๋ฐฐ์น ๋จ์๋ก ํต๊ณ์ ํน์ฑ์ด ๋ค๋ฅด๊ธฐ ๋๋ฌธ์, ๋ ์ด์ด ์ถ๋ ฅ์ ํต๊ณ์ ํน์ฑ์ด ํ๋ค๋ฆผ -> ํ์ต์๋ ๋ํ => Batch ๋จ์๋ก normalizationํ์ฌ ํ์ต์๋๋ฅผ ํฅ์์ํด
CNN ๊ด๋ จ layer
- Conv2d : CNN Convolution layer
- MaxPooling2D : downsampling(์ฐ์ฐ โฌ๏ธ)
- Conv2DTranspose : upsampling(์ฑ๋ฅ ํฅ์)
3. Generative Model
- ํ์ต ๋ฐ์ดํฐ ๋ถํฌ์ ์ ์ฌํ ๋ถํฌ๋ฅผ ๊ฐ๋ ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ ๋ชจ๋ธ
๋ถ๋ฅ
- Explicit density : ํ์ต๋ฐ์ดํฐ๋ก๋ถํฐ ์์ฑ
- PixelRNN, PixelCNN, VAE
- Implicity density : random๊ฐ์ผ๋ก๋ถํฐ ์์ฑ
- GAN, GSN
VAE(Variational AutoEncoder)
AutoEncoder: ์์ถ๋ ํํ์ ์ฐพ๊ธฐ ์ํด ๋ฐ์ดํฐ๋ฅผ ์์ถํ๊ณ ์ ๋ ฅ์ ์ฌ๊ตฌ์ฑํ๋ ๋น์ง๋ ํ์ต (์ ๋ ฅ๋ฐ์ดํฐ๋ฅผ fixed vector์ mapping) VAE: ์์ถ๋ ํํ์ ๋ํํ๋ latent vector๋ฅผ ์ฐพ๊ณ , sampling๋ latent vector์์ ์ถ๋ ฅ์ ์์ฑ (์ ๋ ฅ๋ฐ์ดํฐ๋ฅผ distribution์ mapping)
GAN(Generative Adversarial Network)
discriminator๋ฅผ ํตํ ๊ฐ์ ์ ๊ต์ก์ ๋ฐ๋ generator๋ก ๊ตฌ์ฑ
NN์ผ๋ก ๊ตฌ์ฑ๋ ๋ ๊ฐ์ ๋ชจ๋ธ์ด ๊ฒฝ์์ ํ์ต์ ํตํด ์ฑ๋ฅ ๊ฐ์
generator์ discriminator๋ฅผ ์์ด๋๋ก ํ์ต(์ข์ ๋ถํฌ๋ฅผ ํ์ต)
discriminator๋ generator์ ๊ฒฐ๊ณผ๋ฅผ ํ๋ณํ๋๋ก ํ์ต(์ข์ ๊ฒฝ๊ณ๋ฅผ ํ์ต)
- Image-to-Image Translation (pixel2pixel) : ๋ค๋ฅธ ์ฌ์ง์ ํน์ง์ ๊ฐ์ ธ์์ ์ ํ์ค.
- Semantic-Image-to-Photo-Translation : semantic segmentation ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ง๊ณ ์ค์ฌํ ์ด๋ฏธ์ง๋ฅผ ์์ฑ
- Super Resolution : ํด์๋ ๋์ด๊ธฐ (Conditional GAN ์ ์ฉ)
- Photo Inpainting : ์ผ๋ถ๊ฐ ์ง์์ง ์ฌ์ง ๋ถ๋ถ์ ์ฑ์๋ฃ๋ ๊ธฐ์
4. AutoEncoder / Denoise AutoEncoder
Encoder: ์ ๋ ฅ ๊ฐ์ ๊ตฌ์กฐํ๋ ๊ฐ(์ ์ฌ๊ณต๊ฐใ )์ ๋งตํํ๋ ํจ์
Decoder:์ ์ฌ ๊ณต๊ฐ์ ๊ฐ์ ๋ค๋ฅธ ๋๋ฉ์ธ์ผ๋ก ๋งตํํ๋ ํจ์
Code: ์ ์ฌ ๊ณต๊ฐ์ ๋ฒกํฐ๋ฅผ ๋งํจ.
AE๋ ์ํ์ ์ผ๋ก PCA์ ์ ์ฌํ์ง๋ง ์ค์ค๋ก ์ต์ ํ
Latent space: ์์ถ๋ ์ ์ฐจ์ ๊ณต๊ฐ
Latent variables: ์ ์ฅ๋ ๋ณ์
์ ์ฉ ๋ถ์ผ: Denoising, Super-resolution, Semantic Segmentation
- ํต์ฌ ๊ธฐ๋ฅ:
๊ณ ์ฐจ์ -> ์ ์ฐจ์ ํน์ง ๋ฐ๊ฒฌ
ํต์ฌ ์์ฑ ๋ณด์กด(์์๋ ์ด๋ฏธ์ง ๋ณต๊ตฌ)
์ฃผ์ ๋ณ๋ ์์ธ์ ์๊ฐํ
๋น์ ํ ์ฐจ์ ์ถ์(ํน๋ณํ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ์ ๊ฐ๋ ฅํ ๋๊ตฌ)
- ํ๊ณ:
Decoding ๊ฒฐ๊ณผ์ quality ๋ฎ์(latent attribute๊ฐ discreteํ๊ฒ ํํ๋๋ฉด overfitting โฌ๏ธ)
-> latent attribute๋ฅผ ๋ถํฌ๋ก ํํํ์
Latent space์ ๋น๋์นญ mapping(decoding range ๋ถ๊ท ํ)
-> encoder๊ฐ ๋ง๋๋ ๋ถํฌ๋ฅผ ์ ๊ท๋ถํฌ๋ก ์ ํ(encoder์ ์ถ๋ ฅ์ด ์ ๊ท๋ถํฌ์์ ๋ฒ์ด๋๋ฉด loss๊ฐ ์ปค์ง๋๋ก ์ค๊ณ)
๊ฐ ๊ธ์ ๋ถํฌ์ ๋ถ๊ท ํ, ๋์ ๋ถํฌ์ ์ข์ ๋ถํฌ๊ฐ ํผ์ฌ
VAE(Variational Auto Encoder)
Latent variables๊ฐ ๊ฐ๊ณ ์๋ ํน์ง์ ์ ํํํ๋ decoder์ data-latent variable ๊ฐ์ mapping์ ์ ํ๋ encoder์ ๊ฒฐํฉ
์ข์ latent variables(P(z))๋ฅผ ๋ฝ๋ ๊ฒ์ด ์ค์, ํ์ง๋ง ๋๋ฌด ์ด๋ ค์ฐ๋ฏ๋ก ์ข์ encoder q(z|x)๋ฅผ ์ฐพ์.. -> Variational Inference ์ฌ์ฉ
Variational Inference
decoder๋ p(x|z)๋ฅผ ํ์ตํด์ผ ํ๋๋ฐ, prior P(z)๋ฅผ ์ ์ ์์ผ๋ฏ๋ก, ํ์ต์ด ๋ถ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ q(z|x)๋ก P(z)๋ฅผ ๊ทผ์ฌํจ
๋ณต์กํ distribution ์ ๋ ๊ฐ๋จํ encoder distribution q(z|x)์ ์ด์ฉํด์ ๊ทผ์ฌํ๋ ๊ฒ..
KL divergence๋ฅผ ์ด์ฉ(p(z)์ q(z|x) ์ฌ์ด์ KL Divergence๋ฅผ ๊ณ์ฐํ๊ณ , D_KL์ด ์ค์ด๋๋ ์ชฝ์ผ๋ก q_ํ์ด(z|x)์ ํ์ด๋ฅผ ์กฐ๊ธ์ฉ ์
๋ฐ์ดํธํด์ ์ต์ ์ P(z)์ ์ ์ฌํ ๋ถํฌ๋ฅผ ์ป์
-> p(x|z)๋ฅผ maximize ํ๋๋ก ํ์ตํ๋ ๊ฒ์ q(z|x)๋ฅผ ํ์ตํ๋ ๊ฒ์ผ๋ก ๋์ฒดํจ.
VAE_loss = decoder_loss + encoder_loss
-
Entropy: ์ ๋ณด๋์ ๊ธฐ๋๊ฐ(ํ๊ท ์ ๋ณด๋)
์ ๋ณด๋์ ๋ฐ์ ํ๋ฅ ๊ณผ ๋ฐ๋น๋ก(1/p) -> -log(p)
๊ธฐ๋๊ฐ -> -plog(p) ์ ๋ณด๋์ ํฉ -> -plog(p)์ ํฉ -
KL divergence: ์ ๋ณด ์์ค๋์ ๊ธฐ๋๊ฐ
์ ๋ณด์ ์์ค๋: ํ๋ฅ ๋ถํฌ p์ q ์ฌ์ด์ ์ ๋ณด๋์ ์ฐจ์ด -> -log(q) + log(p)
๊ธฐ๋๊ฐ -> -plog(q) + plog(p)์ ํฉ (= D_KL(p||q) )
D_KL์ด ์ต์๊ฐ ๋๋๋ก q๋ฅผ ์์ -
Cross entropy
D_KL์ ๋ทํญ์ q์ ๋ฌด๊ดํ๋ฏ๋ก, ์ ํญ(p*log(q)์ ํฉ)๋ง ์ต์ํ -
Maximum Likelihood Estimation
likelihood: ๊ด์ฐฐ๋ก๋ถํฐ ๋ชจ์๋ฅผ ์์ธกํ๋ ๊ฒ
ํ์ต: ํ๋ฅ ๊ด์ ์์ ๋ณด๋ฉด Maximum Likelihood ์ฐพ๋ ๊ฒ
5. GAN(Generative Adversarial Network)
์ด๋ ํ ๋ถํฌ์ ๋ฐ์ดํฐ๋ ๋ชจ๋ฐฉ / ์์ฑ ๋ชจ๋ธ๊ณผ ํ๋ณ ๋ชจ๋ธ์ด ๊ฒฝ์ํ๋ ๊ตฌ์กฐ
์์ฑ ๋ชจ๋ธ์ data class์ ๋ถํฌ๋ฅผ ๋ชจ๋ธ๋ง
ํ๋ณ ๋ณด๋ธ์ data class์ ๊ฒฝ๊ณ๋ฅผ ๋ชจ๋ธ๋ง
- noise z๋ก๋ถํฐ Generator๊ฐ G(z)๋ผ๋ fake ๋ฐ์ดํฐ๋ฅผ ์์ฑ
- Discriminator๊ฐ real data์ธ p(x)์ G(z)๋ฅผ ๋น๊ตํ์ฌ Real์ผ ํ๋ฅ (D(x))์ ์ถ๋ ฅ
Discriminator๋ gradient ascent: max(log(D(x))
Real์ ๋ํด ํ๋ฅ ์ด 1 -> ๊ธฐ๋๊ฐ 0(์ต๋๊ฐ) Fake์ ๋ํด ํ๋ฅ ์ด 0 -> ๊ธฐ๋๊ฐ 0(์ต๋๊ฐ)
Generator๋ gradient descent: min(1-log(D(G(z))) -> max(log(D(Gz))
Fake์ ๋ํด D์ ํ๋ฅ ์ด 1 -> ๊ธฐ๋๊ฐ -โพ๏ธ
๊ฐ ๋ชจ๋ธ์ loss function์ ๋ฐ๋ก ๋์ด์ ๊ฐ์ ํ์ต
loss๋ฅผ ์ต์ํํ๋ ๋ฐ์๋ D_JS(Jesen-Shannon Divergence)๋ฅผ ์ฌ์ฉ(๋์นญ์ ์ธ D_KL)
GAN ํ์ต์ด ์ด๋ ค์ด ์ด์
-
๋ถ๊ดด(์ถ์): Mode collapsing
๋ชจ๋ธ์ด multi-modal(์๋ด) ๋ฐ์ดํฐ ๋ถํฌ๋ฅผ ๋ชจ๋ ์ปค๋ฒํ์ง ๋ชปํ๊ณ ๋ค์์ฑ์ ์์ด๋ฒ๋ฆผ
loss๋ง์ ์ค์ด๋ ค๊ณ ํ๊ธฐ ๋๋ฌธ์ ํ์ชฝ ๋ด์ผ๋ก๋ง bias๋จ -
Oscillation ์๋ด์ ํ์ชฝ ๋ด์ผ๋ก bias๋ ํํ๋ฅผ ์๋ค๋ฆฌ๊ฐ๋ค๋ฆฌ ํจ
์๋ก์ ๋ฐ๋๋ฐฉํฅ์ผ๋ก ํ์ต์ด ์งํ๋์ด ์คํจ๋ฅผ ๋ฐ๋ณต
ํด๊ฒฐ์ฑ : Lossํจ์ ๊ฐ์
- Wasserstein GAN
- LS_GAN
CGAN(Conditinal GAN)
DCGAN์ ์ํด ์์ฑ๋ ์ด๋ฏธ์ง๋ ๋๋ค์ด๋ฏ๋ก, ํน์ ์ด๋ฏธ์ง๋ฅผ ์ ์ดํ ์ ์๋๋ก condition์ ๋ถ์ฌํจ
Condition: label์ one-hot code
ํ๋ณ๊ธฐ: condition์ ์ด๋ฏธ์ง์ ๊ฐ์ ํํ๋ก ๋ณํํ์ฌ ์ด๋ฏธ์ง์ concatenateํ์ฌ input์ผ๋ก ๊ณต๊ธ
์์ฑ๊ธฐ: latent vector์ label์ ๊ฒฐํฉํ์ฌ input์ผ๋ก ๊ณต๊ธ
์์ฉ: text๋ฅผ condition์ผ๋ก ๋ณํํ์ฌ, text๋ฅผ ์ด๋ฏธ์ง๋ก ๋ณํํ๋๋ก ํ ์๋ ์์..
ACGAN(Auxiliary classifier GAN)
์์ฑ๊ธฐ๋ ๋์ผํ๋, ํ๋ณ๊ธฐ๋ฅผ 2๊ฐ์ ๋ชจ๋ธ๋ก ๊ตฌ์ฑ
์ฐธ/๊ฑฐ์ง ๊ตฌ๋ถ(binary) + ์ด๋ฏธ์ง ๋ผ๋ฒจ ํ๋จ(categorical)
label์ ์ด๋ฏธ์ง์ concatenateํ์ง ์๊ณ , ์ด๋ฏธ์ง์์์ ์ถ๋ ฅ์ด sigmoid๋ก ๋ค์ด๊ฐ๊ธฐ ์ ์ ๋ฐ๋ก ๋ถ๊ธฐํ์ฌ softmax๋ก ์ถ๋ ฅํ์ฌ reak label๊ณผ ๋น๊ต
- CGAN, ACGAN์ผ๋ก๋ ์ํ๋ ์ ๋๋ก ๊ธฐ์ธ์ด์ง๊ณ , ๊ตต์ด์ง ์ซ์๋ฅผ ์์ฑํ ์๋ ์์
์ ์ฌ๊ณต๊ฐ์ ์ ๋ณด๋ค์ด ์ฝํ์๊ธฐ ๋๋ฌธ์...
InfoGAN
์ ์ฌ๊ณต๊ฐ์ ์ฝ๋๋ฅผ ํ์ด์ ์ ๋ฆฌํด ํด์ ๊ฐ๋ฅํ z-vector๋ฅผ ์ถ๊ฐ ๊ตฌ์ฑ
Z=(z,c) z: noise vector(์ฝํ ์ฝ๋), c: latent code(ํด์ ๊ฐ๋ฅ)
์์ฑ๊ธฐ ์
์ฅ์์๋ z์ c๋ฅผ ๊ตฌ๋ถํ์ง ์์
loss ํจ์์ ์ํธ์ ๋ณด๋ term ์ถ๊ฐ
์ํธ์ ๋ณด๋(mutual information): ๋ ํ๋ฅ ๋ณ์์ ์์กด์ฑ, ๊ณต์ entropy, I(X;Y) = D_KL(p(x,y)||p(x)p(y))
์ฌ๊ธฐ์๋ I(c'; G(z,c))๋ฅผ ์ฌ์ฉํจ, z,c๋ก๋ถํฐ ์์ฑ๋ ์ด๋ฏธ์ง์ ํ๋ณ๊ธฐ ์ ์ฌ์ฝ๋ c'์ ์ํธ์ ๋ณด๋
Pix2Pix
Image-to-Image Translation with CGAN
input ์ด๋ฏธ์ง๋ฅผ ์๋ก์ด domain์ผ๋ก translation(์ค์ผ์น -> real object ๋ฑ)