DeepLearning chainer DCGAN - eiichiromomma/CVMLAB GitHub Wiki

(DeepLearning) chainer-DCGAN

データセット

imagesフォルダを作って画像を入れておく。但し、画像サイズ（ピクセル数）が96x96に決め打ちされているので、ImageMagickをインストールしておき

mogrify -resize 96x96! *.jpg

で一括変換できる（mogrifyは元ファイルを上書きするので、必要ならバックアップを取っておくこと）。"!"を付けるとアスペクト比を無視して強制的に96x96にリサイズされるのがポイント。

DCGAN.py

元のファイルが古いので3.0用に雑に修正した。また、表示部分をpylabからmatplotlib.pyplotに置き換え、学習中に生成画像を常時更新表示できるようにした（image_show_intervalごとに表示を更新し、image_save_intervalごとにPNGとして保存する）。

--- DCGAN_org.py	2017-06-26 08:34:04.944146151 +0900
+++ DCGAN.py	2017-06-26 12:39:15.849476620 +0900
@@ -1,10 +1,10 @@
+import pickle
 import numpy as np
 from PIL import Image
 import os
-from io import StringIO
 import io
 import math
-import pylab
+import matplotlib.pyplot as plt
 
 
 import chainer
@@ -33,18 +33,19 @@
 n_epoch=10000
 n_train=200000
 image_save_interval = 5000
+image_show_interval = 1000
 
 # read all images
 
 fs = os.listdir(image_dir)
-print( len(fs))
+print (len(fs))
 dataset = []
 for fn in fs:
     f = open('%s/%s'%(image_dir,fn), 'rb')
     img_bin = f.read()
     dataset.append(img_bin)
     f.close()
-print( len(dataset))
+print (len(dataset))
 
 class ELU(function.Function):
 
@@ -160,6 +161,13 @@
     o_dis.add_hook(chainer.optimizer.WeightDecay(0.00001))
 
     zvis = (xp.random.uniform(-1, 1, (100, nz), dtype=np.float32))
+    fig = plt.figure()
+    imw = 96
+    tmp = np.zeros((imw*10, imw*10, 3), dtype=np.float32)
+    ax = fig.add_axes([0, 0, 1, 1])
+    im = ax.imshow(tmp)
+    plt.axis('off')
+    plt.pause(0.01)
     
     for epoch in range(epoch0,n_epoch):
         perm = np.random.permutation(n_train)
@@ -171,14 +179,14 @@
             # 0: from dataset
             # 1: from noise
 
-            print( "load image start ", i)
+            print ("load image start ", i)
             x2 = np.zeros((batchsize, 3, 96, 96), dtype=np.float32)
             for j in range(batchsize):
                 try:
                     rnd = np.random.randint(len(dataset))
                     rnd2 = np.random.randint(2)
 
-                    img = np.asarray(Image.open(io.BytesIO(dataset[rnd]))).astype(np.float32).transpose(2,0,1)
+                    img = np.asarray(Image.open(io.BytesIO(dataset[rnd]))).astype(np.float32).transpose(2, 0, 1)
                     if rnd2==0:
                         x2[j,:,:,:] = (img[:,:,::-1]-128.0)/128.0
                     else:
@@ -200,7 +208,7 @@
             yl2 = dis(x2)
             L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))
             
-            print( "forward done")
+            print ("forward done")
 
             o_gen.zero_grads()
             L_gen.backward()
@@ -214,28 +222,30 @@
             sum_l_dis += L_dis.data.get()
             
             print ("backward done")
 
-            if i%image_save_interval==0:
-                pylab.rcParams['figure.figsize'] = (16.0,16.0)
-                pylab.clf()
-                vissize = 100
+            if i%image_show_interval==0:
                 z = zvis
                 z[50:,:] = (xp.random.uniform(-1, 1, (50, nz), dtype=np.float32))
                 z = Variable(z)
                 x = gen(z, test=True)
                 x = x.data.get()
                 for i_ in range(100):
-                    tmp = ((np.vectorize(clip_img)(x[i_,:,:,:])+1)/2).transpose(1,2,0)
-                    pylab.subplot(10,10,i_+1)
-                    pylab.imshow(tmp)
-                    pylab.axis('off')
-                pylab.savefig('%s/vis_%d_%d.png'%(out_image_dir, epoch,i))
+                    rr = (i_//10)*imw
+                    cc = (i_%10)*imw
+                    tmp[rr:rr+imw, cc:cc+imw, :] = ((np.vectorize(clip_img)(x[i_,:,:,:])+1)/2).transpose(1,2,0)
+                im.set_array(tmp)
+                plt.pause(0.001)
+            # savefig writes whatever the preview currently shows, so keep
+            # image_save_interval a multiple of image_show_interval.
+            if i%image_save_interval == 0:
+                plt.savefig('%s/vis_%d_%d.png'%(out_image_dir, epoch,i))
+                print('image saved')
                 
         serializers.save_hdf5("%s/dcgan_model_dis_%d.h5"%(out_model_dir, epoch),dis)
         serializers.save_hdf5("%s/dcgan_model_gen_%d.h5"%(out_model_dir, epoch),gen)
         serializers.save_hdf5("%s/dcgan_state_dis_%d.h5"%(out_model_dir, epoch),o_dis)
         serializers.save_hdf5("%s/dcgan_state_gen_%d.h5"%(out_model_dir, epoch),o_gen)
-        print( 'epoch end', epoch, sum_l_gen/n_train, sum_l_dis/n_train)
+        print ('epoch end', epoch, sum_l_gen/n_train, sum_l_dis/n_train)

学習初期の生成画像はただのノイズ

9エポックほど学習を回した後の生成画像