DeepLearning chainer DCGAN - eiichiromomma/CVMLAB GitHub Wiki
(DeepLearning) chainer-DCGAN
データセット
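imagesフォルダを作成し、その中に学習用の画像を入れておく。ただし、入力画像のサイズ（ピクセル数）が96x96に決め打ちされているので、ImageMagickをインストールしておき、
mogrify -resize 96x96! *.jpg
で一括変換できる。"!"がポイントで、これを付けると元の縦横比を無視して強制的に96x96にリサイズされる（付けないと縦横比が保たれるため、正方形でない画像は96x96にならない）。なお、mogrifyは元ファイルを上書きするので、必要なら事前にバックアップを取っておくこと。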
DCGAN.py
元のファイルが古いので、3.0用に簡易的に修正した。また、表示部分をpylabからmatplotlib.pyplotに置き換え、学習中に生成画像を常時更新表示できるようにした（画面表示はimage_show_intervalごと、PNG保存はimage_save_intervalごとに行う）。
--- DCGAN_org.py 2017-06-26 08:34:04.944146151 +0900
+++ DCGAN.py 2017-06-26 12:39:15.849476620 +0900
@@ -1,10 +1,10 @@
+import pickle
import numpy as np
from PIL import Image
import os
-from io import StringIO
import io
import math
-import pylab
+import matplotlib.pyplot as plt
import chainer
@@ -33,18 +33,19 @@
n_epoch=10000
n_train=200000
image_save_interval = 5000
+image_show_interval = 1000
# read all images
fs = os.listdir(image_dir)
-print( len(fs))
+print (len(fs))
dataset = []
for fn in fs:
f = open('%s/%s'%(image_dir,fn), 'rb')
img_bin = f.read()
dataset.append(img_bin)
f.close()
-print( len(dataset))
+print (len(dataset))
class ELU(function.Function):
@@ -160,6 +161,13 @@
o_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))
zvis = (xp.random.uniform(-1, 1, (100, nz), dtype=np.float32))
+ fig = plt.figure()
+ imw = 96
+ tmp = np.zeros((imw*10, imw*10, 3), dtype=np.float32)
+ ax = fig.add_axes([0, 0, 1, 1])
+ im = ax.imshow(tmp)
+ plt.axis('off')
+ plt.pause(0.01)
for epoch in range(epoch0,n_epoch):
perm = np.random.permutation(n_train)
@@ -171,14 +179,14 @@
# 0: from dataset
# 1: from noise
- print( "load image start ", i)
+ print ("load image start ", i)
x2 = np.zeros((batchsize, 3, 96, 96), dtype=np.float32)
for j in range(batchsize):
try:
rnd = np.random.randint(len(dataset))
rnd2 = np.random.randint(2)
- img = np.asarray(Image.open(io.BytesIO(dataset[rnd]))).astype(np.float32).transpose(2,0,1)
+ img = np.asarray(Image.open(io.BytesIO(dataset[rnd]))).astype(np.float32).transpose(2, 0, 1)
if rnd2==0:
x2[j,:,:,:] = (img[:,:,::-1]-128.0)/128.0
else:
@@ -200,7 +208,7 @@
yl2 = dis(x2)
L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))
- print( "forward done")
+ print ("forward done")
o_gen.zero_grads()
L_gen.backward()
@@ -214,28 +222,33 @@
sum_l_dis += L_dis.data.get()
print ("backward done")
+ z = zvis
+ z[50:,:] = (xp.random.uniform(-1, 1, (50, nz), dtype=np.float32))
+ z = Variable(z)
+ x = gen(z, test=True)
+ x = x.data.get()
- if i%image_save_interval==0:
- pylab.rcParams['figure.figsize'] = (16.0,16.0)
- pylab.clf()
- vissize = 100
+ if i%image_show_interval==0:
z = zvis
z[50:,:] = (xp.random.uniform(-1, 1, (50, nz), dtype=np.float32))
z = Variable(z)
x = gen(z, test=True)
x = x.data.get()
for i_ in range(100):
- tmp = ((np.vectorize(clip_img)(x[i_,:,:,:])+1)/2).transpose(1,2,0)
- pylab.subplot(10,10,i_+1)
- pylab.imshow(tmp)
- pylab.axis('off')
- pylab.savefig('%s/vis_%d_%d.png'%(out_image_dir, epoch,i))
+ rr = int(i_/10)*imw
+ cc = (i_%10)*imw
+ tmp[rr:rr+imw ,cc:cc+imw, :] = ((np.vectorize(clip_img)(x[i_,:,:,:])+1)/2).transpose(1,2,0)
+ im.set_array(tmp)
+ plt.pause(0.001)
+ if i%image_save_interval == 0:
+ plt.savefig('%s/vis_%d_%d.png'%(out_image_dir, epoch,i))
+ print('image saved')
serializers.save_hdf5("%s/dcgan_model_dis_%d.h5"%(out_model_dir, epoch),dis)
serializers.save_hdf5("%s/dcgan_model_gen_%d.h5"%(out_model_dir, epoch),gen)
serializers.save_hdf5("%s/dcgan_state_dis_%d.h5"%(out_model_dir, epoch),o_dis)
serializers.save_hdf5("%s/dcgan_state_gen_%d.h5"%(out_model_dir, epoch),o_gen)
- print( 'epoch end', epoch, sum_l_gen/n_train, sum_l_dis/n_train)
+ print ('epoch end', epoch, sum_l_gen/n_train, sum_l_dis/n_train)
学習初期の生成結果（まだただのノイズ）
9エポックほど学習した後の生成結果