详解如何使用Keras实现Wassertein GAN

开发 开发工具
在本文中,作者将用自己的在 Keras 上的代码来向大家简要介绍一下WGAN。

在阅读论文 Wassertein GAN 时,作者发现理解它最好的办法就是用代码来实现其内容。于是在本文中,作者将用自己的在 Keras 上的代码来向大家简要介绍一下WGAN。

[[205674]]

何为 GAN?

GAN,亦称为生成对抗网络(Generative Adversarial Network),它是生成模型中的一类——即一种能够通过观察来自特定分布的训练数据,进而尝试对这个分布进行预测的模型。这个模型新获取的样本「看起来」会和最初的训练样本类似。有些生成模型只会去学习训练数据分布的参数,有一些模型则只能从训练数据分布中提取样本,而有一些则可以二者兼顾。

目前,已经存在了很多种类的生成模型:全可见信念网络(Fully Visible Belief Network)、变分自编码器(Variational Autoencoder)、玻尔兹曼机(Boltzmann Machine),生成随机网络(Generative Stochastic Network),像素循环神经网络(PixelRNN)等等。以上的模型都因其所表征或接近的训练数据密度而有所区别。一些模型会去精细的表征训练数据,另一些则会以某种方式去和训练数据进行互动——比如说生成模型。GAN 就是这里所说的后者。大部分生成模型的学习原则都可被概括为「最大化相似度预测」——即让模型的参数能够尽可能地与训练数据相似。

GAN 的工作方式可以看成一个由两部分构成的游戏:生成器(Generator/G)和判别器(Discriminator/D)(一般而言,这两者都由神经网络构成)。生成器随机将一个噪声作为自己的输入,然后尝试去生成一个样本,目的是让判别器无法判断这个样本是来自训练数据还是来自生成器的。在判别器这里,我们让它以监督学习方式来工作,具体而言就是让它观察真实样本和生成器生成的样本,并且同时用标签告诉它这些样本分别来自哪里。在某种意义上,判别器可以代替固定的损失函数,并且尝试学习与训练数据分布相关的模式。

何为 Wasserstein GAN?

就其本质而言,任何生成模型的目标都是让模型(习得地)的分布与真实数据之间的差异达到最小。然而,传统 GAN 中的判别器 D 并不会当模型与真实的分布重叠度不够时去提供足够的信息来估计这个差异度——这导致生成器得不到一个强有力的反馈信息(特别是在训练之初),此外生成器的稳定性也普遍不足。

Wasserstein GAN 在原来的基础之上添加了一些新的方法,让判别器 D 去拟合模型与真实分布之间的 Wasserstein 距离。Wassersterin 距离会大致估计出「调整一个分布去匹配另一个分布还需要多少工作」。此外,其定义的方式十分值得注意,它甚至可以适用于非重叠的分布。

为了让判别器 D 可以有效地拟合 Wasserstein 距离:

  • 其权重必须在紧致空间(compact space)之内。为了达到这个目的,其权重需要在每步训练之后,被调整到-0.01 到+0.01 的闭区间上。然而,论文作者承认,虽然这对于裁剪间距的选择并不是理想且高敏感的(highly sensitive),但是它在实践中却是有效的。更多信息可参见论文 6 到 7 页。
  • 由于判别器被训练到了更好的状态上,所以它可以为生成器提供一个有用的梯度。
  • 判别器顶层需要有线性激活。
  • 它需要一个本质上不会修改判别器输出的价值函数。
    1. K.mean(y_true * y_pred) 

以 keras 这段损失函数为例:

  • 这里采用 mean 来适应不同的批大小以及乘积。
  • 预测的值通过乘上 element(可使用的真值)来最大化输出结果(优化器通常会将损失函数的值最小化)。

论文作者表示,与 vanlillaGAN 相比,WGAN 有一下优点:

  • 有意义的损失指标。判别器 D 的损失可以与生成样本(这些样本使得可以更少地监控训练过程)的质量很好地关联起来。
  • 稳定性得到改进。当判别器 D 的训练达到了最佳,它便可以为生成器 G 的训练提供一个有用的损失。这意味着,对判别器 D 和生成器 G 的训练不必在样本数量上保持平衡(相反,在 Vanilla GAN 方法中而这是平衡的)。此外,作者也表示,在实验中,他们的 WGAN 模型没有发生过一次崩溃的情况。

开始编程!

我们会在 Keras 上实现 ACGAN 的 Wasserstein variety。在 ACGAN 这种生成对抗网络中,其判别器 D 不仅可以预测样本的真实与否,同时还可以将其进行归类。

下方代码附有部分解释。

[1] 导入库文件:

  1. import os 
  2.  
  3. import matplotlib.pyplot as plt 
  4. %matplotlib inline 
  5. %config InlineBackend.figure_format = 'retina' # enable hi-res output 
  6.  
  7. import numpy as np 
  8. import tensorflow as tf 
  9.  
  10. import keras.backend as K 
  11. from keras.datasets import mnist 
  12. from keras.layers import * 
  13. from keras.models import * 
  14. from keras.optimizers import * 
  15. from keras.initializers import * 
  16. from keras.callbacks import * 
  17. from keras.utils.generic_utils import Progbar 

[2].Runtime 配置

  1. # random seed 
  2. RND = 777 
  3.  
  4. # output settings 
  5. RUN = 'B' 
  6. OUT_DIR = 'out/' + RUN 
  7. TENSORBOARD_DIR = '/tensorboard/wgans/' + RUN 
  8. SAVE_SAMPLE_IMAGES = False 
  9.  
  10. # GPU # to run on 
  11. GPU = "0" 
  12.  
  13. BATCH_SIZE = 100 
  14. ITERATIONS = 20000 
  15.  
  16. # size of the random vector used to initialize G 
  17. Z_SIZE = 100 

[3]生成器 G 每进行一次迭代,判别器 D 都需要进行 D_ITERS 次迭代。

  • 由于在 WGAN 中让判别器质量能够优化这件事更加重要,所以判别器 D 与生成器 G 在训练次数上呈非对称比例。
  • 在论文的 v2 版本中,判别器 D 在生成器 G 每 1000 次迭代的前 25 次都会训练 100 次,此外,判别器也会当生成器每进行了 500 次迭代以后训练 100 次。
    1. D_ITERS = 5 

[4]其它准备:

  1. # create output dirif not os.path.isdir(OUT_DIR): os.makedirs(OUT_DIR) 
  2.  
  3. # make only specific GPU to be utilized 
  4. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"os.environ["CUDA_VISIBLE_DEVICES"] = GPU 
  5.  
  6. # seed random generator for repeatability 
  7. np.random.seed(RND) 
  8.  
  9. # force Keras to use last dimension for image channels 
  10. K.set_image_dim_ordering('tf') 

[5]判别器的损失函数:

  • 由于判别器 D 一方面力图在训练数据分布与生成器 G 生成的数据之间习得一个 Wasserstein 距离的拟合,另一方面判别器 D 又是线性激活的,所以这里我们不需要去修改它的输出结果。
  • 由于已经使用了损失函数 Mean,所以我们可以在不同的批大小之间比较输出结果。
  • 预测结果等于真值(true value)与元素的点乘(element-wise multiplication),为了让判别器 D 的输出能够最大化(通常,优化器都力图去让损失函数的值达到最小),真值需要取-1。
    1. def d_loss(y_true, y_pred):    return K.mean(y_true * y_pred) 

[6].创建判别器 D

判别器将图像作为输入,然后给出两个输出:

  • 用线性激活来评价生成图像的「虚假度」(最大化以用于生成图像)。
  • 用 softmax 激活来对图像种类进行预测。
  • 由于权重是从标准差为 0.02 的正态分布中初始化出来的,所以最初的剪裁不会去掉所有的权重。
    1. def create_D(): 
    2.   
    3.     # weights are initlaized from normal distribution with below params 
    4.     weight_init = RandomNormal(mean=0., stddev=0.02) 
    5.  
    6.     input_image = Input(shape=(28, 28, 1), name='input_image'
    7.  
    8.     x = Conv2D(        32, (3, 3), 
    9.         padding='same'
    10.         name='conv_1'
    11.         kernel_initializer=weight_init)(input_image) 
    12.     x = LeakyReLU()(x) 
    13.     x = MaxPool2D(pool_size=2)(x) 
    14.     x = Dropout(0.3)(x) 
    15.  
    16.     x = Conv2D(        64, (3, 3), 
    17.         padding='same'
    18.         name='conv_2'
    19.         kernel_initializer=weight_init)(x) 
    20.     x = MaxPool2D(pool_size=1)(x) 
    21.     x = LeakyReLU()(x) 
    22.     x = Dropout(0.3)(x) 
    23.  
    24.     x = Conv2D(        128, (3, 3), 
    25.         padding='same'
    26.         name='conv_3'
    27.         kernel_initializer=weight_init)(x) 
    28.     x = MaxPool2D(pool_size=2)(x) 
    29.     x = LeakyReLU()(x) 
    30.     x = Dropout(0.3)(x) 
    31.  
    32.     x = Conv2D(        256, (3, 3), 
    33.         padding='same'
    34.         name='coonv_4'
    35.         kernel_initializer=weight_init)(x) 
    36.     x = MaxPool2D(pool_size=1)(x) 
    37.     x = LeakyReLU()(x) 
    38.     x = Dropout(0.3)(x) 
    39.  
    40.     features = Flatten()(x) 
    41.  
    42.     output_is_fake = Dense(        1, activation='linear'name='output_is_fake')(features) 
    43.  
    44.     output_class = Dense(        10, activation='softmax'name='output_class')(features)    return Model( 
    45.         inputs=[input_image], outputs=[output_is_fake,  

[7].创建生成器

生成器有两个输入:

  • 一个尺寸为Z_SIZE的潜在随机变量。
  • 我们希望生成的数字类型(integer o 到 9)。

为了加入这些输入(input),integer 类型会在内部转换成一个1 x DICT_LEN(在本例中DICT_LEN = 10)的稀疏向量,然后乘上嵌入的维度为 DICT_LEN x Z_SIZE的矩阵,结果得到一个维度为1 x Z_SIZE的密集向量。然后该向量乘上(点 乘)可能的输入(input),经过多个上菜样和卷积层,最后其维度就可以和训练图像的维度匹配了。

  1. def create_G(Z_SIZEZ_SIZE=Z_SIZE): 
  2.     DICT_LEN = 10 
  3.     EMBEDDING_LEN = Z_SIZE 
  4.  
  5.     # weights are initialized from normal distribution with below params 
  6.     weight_init = RandomNormal(mean=0., stddev=0.02) 
  7.  
  8.     # class#    input_class = Input(shape=(1, ), dtype='int32'name='input_class'
  9.     # encode class# to the same size as Z to use hadamard multiplication later on 
  10.     e = Embedding
  11.         DICT_LEN, EMBEDDING_LEN, 
  12.         embeddings_initializer='glorot_uniform')(input_class) 
  13.     embedded_class = Flatten(name='embedded_class')(e) 
  14.  
  15.     # latent var 
  16.     input_z = Input(shape=(Z_SIZE, ), name='input_z'
  17.  
  18.     # hadamard product 
  19.     h = multiply([input_z, embedded_class], name='h'
  20.  
  21.     # cnn part 
  22.     x = Dense(1024)(h) 
  23.     x = LeakyReLU()(x) 
  24.  
  25.     x = Dense(128 * 7 * 7)(x) 
  26.     x = LeakyReLU()(x) 
  27.     x = Reshape((7, 7, 128))(x) 
  28.  
  29.     x = UpSampling2D(size=(2, 2))(x) 
  30.     x = Conv2D(256, (5, 5), padding='same'kernel_initializer=weight_init)(x) 
  31.     x = LeakyReLU()(x) 
  32.  
  33.     x = UpSampling2D(size=(2, 2))(x) 
  34.     x = Conv2D(128, (5, 5), padding='same'kernel_initializer=weight_init)(x) 
  35.     x = LeakyReLU()(x) 
  36.  
  37.     x = Conv2D(        1, (2, 2), 
  38.         padding='same'
  39.         activation='tanh'
  40.         name='output_generated_image'
  41.         kernel_initializer=weight_init)(x)    return Mode 

[8].将判别器 D 和生成器 G 整合到一个模型中:

  1. D = create_D() 
  2.  
  3. D.compile( 
  4.     optimizer=RMSprop(lr=0.00005), 
  5.     loss=[d_loss, 'sparse_categorical_crossentropy']) 
  6.  
  7. input_z = Input(shape=(Z_SIZE, ), name='input_z_'
  8. input_class = Input(shape=(1, ),name='input_class_'dtype='int32'
  9.  
  10. G = create_G() 
  11.  
  12. # create combined D(G) model 
  13. output_is_fake, output_class = D(G(inputs=[input_z, input_class])) 
  14. DG = Model(inputs=[input_z, input_class], outputs=[output_is_fake, output_class]) 
  15.  
  16. DG.compile( 
  17.     optimizer=RMSprop(lr=0.00005), 
  18.     loss=[d_loss, 'sparse_categorical_crossentropy'] 

[9].加载 MNIST 数据集:

  1. # load mnist data 
  2. (X_train, y_train), (X_test, y_test) = mnist.load_data() 
  3.  
  4. # use all available 70k samples from both train and test sets 
  5. X_train = np.concatenate((X_train, X_test)) 
  6. y_train = np.concatenate((y_train, y_test)) 
  7.  
  8. # convert to -1..1 range, reshape to (sample_i, 28, 28, 1) 
  9. X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3

[10].生成样本以及将指标和图像发送到 TensorBorad 的实用工具:

  1. # save 10x10 sample of generated images 
  2. def generate_samples(n=0save=True): 
  3.  
  4.     zz = np.random.normal(0., 1., (100, Z_SIZE)) 
  5.     generated_classes = np.array(list(range(0, 10)) * 10) 
  6.     generated_images = G.predict([zz, generated_classes.reshape(-1, 1)]) 
  7.  
  8.     rr = []    for c in range(10): 
  9.         rr.append( 
  10.             np.concatenate(generated_images[c * 10:(1 + c) * 10]).reshape(                280, 28)) 
  11.     img = np.hstack(rr)    if save: 
  12.         plt.imsave(OUT_DIR + '/samples_%07d.png' % n, img, cmap=plt.cm.gray)    return img 
  13.  
  14. # write tensorboard summaries 
  15. sw = tf.summary.FileWriter(TENSORBOARD_DIR) 
  16. def update_tb_summary(step, write_sample_images=True): 
  17.  
  18.     s = tf.Summary() 
  19.  
  20.     # losses as is    for names, vals in zip((('D_real_is_fake', 'D_real_class'), 
  21.                             ('D_fake_is_fake', 'D_fake_class'), ('DG_is_fake',                                                                 'DG_class')), 
  22.                            (D_true_losses, D_fake_losses, DG_losses)): 
  23.  
  24.         v = s.value.add() 
  25.         v.simple_value = vals[-1][1] 
  26.         v.tag = names[0] 
  27.  
  28.         v = s.value.add() 
  29.         v.simple_value = vals[-1][2] 
  30.         v.tag = names[1] 
  31.  
  32.     # D loss: -1*D_true_is_fake - D_fake_is_fake 
  33.     v = s.value.add() 
  34.     v.simple_value = -D_true_losses[-1][1] - D_fake_losses[-1][1] 
  35.     v.tag = 'D loss (-1*D_real_is_fake - D_fake_is_fake)' 
  36.  
  37.     # generated image    if write_sample_images: 
  38.         img = generate_samples(step, save=True
  39.         s.MergeFromString(tf.Session().run( 
  40.             tf.summary.image('samples_%07d' % step, 
  41.                              img.reshape([1, *img.shape, 1])))) 
  42.  
  43.     sw.add_summary(s, step) 
  44.     sw.flush() 

[11].训练

训练过程包含了以下步骤:

  1. 解除对判别器 D 权重的控制,让它们变得可学习。
  2. 调整判别器的权重(调整到-0.01 到+0.01 闭区间上)。
  3. 向判别器 D 提供真实的样本,通过在损失函数中将其乘上-1 来尽可能最大化它的输出,最小化它的值。
  4. 向判别器 D 提供假的样本试图最小化其输出。
  5. 按照上文讲述的判别器迭代训练方法重复步骤 3 和 4。
  6. 固定判别器 D 的权重。
  7. 训练一对判别器和生成器,尽力去最小化其输出。由于这种手段优化了生成器 G 的权重,所以前面已经训练好了的权重固定的判别器才会将生成的假样本判断为真图像。
    1. progress_bar = Progbar(target=ITERATIONS
    2.  
    3. DG_losses = [] 
    4. D_true_losses = [] 
    5. D_fake_losses = []for it in range(ITERATIONS):    if len(D_true_losses) > 0: 
    6.         progress_bar.update( 
    7.             it, 
    8.             values=[ # avg of 5 most recent 
    9.                     ('D_real_is_fake', np.mean(D_true_losses[-5:], axis=0)[1]), 
    10.                     ('D_real_class', np.mean(D_true_losses[-5:], axis=0)[2]), 
    11.                     ('D_fake_is_fake', np.mean(D_fake_losses[-5:], axis=0)[1]), 
    12.                     ('D_fake_class', np.mean(D_fake_losses[-5:], axis=0)[2]), 
    13.                     ('D(G)_is_fake', np.mean(DG_losses[-5:],axis=0)[1]), 
    14.                     ('D(G)_class', np.mean(DG_losses[-5:],axis=0)[2]) 
    15.             ] 
    16.         )         
    17.     else: 
    18.         progress_bar.update(it) 
    19.  
    20.     # 1: train D on real+generated images    if (it % 1000) < 25 or it % 500 == 0: # 25 times in 1000, every 500th 
    21.         d_iters = 100 
    22.     else: 
    23.         d_iters = D_ITERS    for d_it in range(d_iters): 
    24.  
    25.         # unfreeze D 
    26.         D.trainable = True        for l in D.layers: l.trainable = True 
    27.  
    28.         # clip D weights        for l in D.layers: 
    29.             weights = l.get_weights() 
    30.             weights = [np.clip(w, -0.01, 0.01) for w in weights] 
    31.             l.set_weights(weights) 
    32.  
    33.         # 1.1: maximize D output on reals === minimize -1*(D(real)) 
    34.  
    35.         # draw random samples from real images 
    36.         index = np.random.choice(len(X_train), BATCH_SIZE, replace=False
    37.         real_images = X_train[index] 
    38.         real_images_classes = y_train[index] 
    39.  
    40.         DD_loss = D.train_on_batch(real_images, [-np.ones(BATCH_SIZE),  
    41.           real_images_classes]) 
    42.         D_true_losses.append(D_loss) 
    43.  
    44.         # 1.2: minimize D output on fakes  
    45.  
    46.         zz = np.random.normal(0., 1., (BATCH_SIZE, Z_SIZE)) 
    47.         generated_classes = np.random.randint(0, 10, BATCH_SIZE) 
    48.         generated_images = G.predict([zz, generated_classes.reshape(-1, 1)]) 
    49.  
    50.         DD_loss = D.train_on_batch(generated_images, [np.ones(BATCH_SIZE), 
    51.           generated_classes]) 
    52.         D_fake_losses.append(D_loss) 
    53.  
    54.     # 2: train D(G) (D is frozen) 
    55.     # minimize D output while supplying it with fakes,  
    56.     # telling it that they are reals (-1) 
    57.  
    58.     # freeze D 
    59.     D.trainable = False    for l in D.layers: l.trainable = False 
    60.  
    61.     zz = np.random.normal(0., 1., (BATCH_SIZE, Z_SIZE))  
    62.     generated_classes = np.random.randint(0, 10, BATCH_SIZE) 
    63.  
    64.     DGDG_loss = DG.train_on_batch( 
    65.         [zz, generated_classes.reshape((-1, 1))], 
    66.         [-np.ones(BATCH_SIZE), generated_classes]) 
    67.  
    68.     DG_losses.append(DG_loss)    if it % 10 == 0: 
    69.         update_tb_summary(it, write_sample_images=(it  

结论

视频的每一秒都是 250 次训练迭代。使用 Wasserstein GAN 的一个好处就是它有着损失与样本质量之间的关系。

附论文地址:https://arxiv.org/pdf/1701.07875.pdf

参考文献

1. Wasserstein GAN paper (https://arxiv.org/pdf/1701.07875.pdf) – Martin Arjovsky, Soumith Chintala, Léon Bottou

2. NIPS 2016 Tutorial: Generative Adversarial Networks (https://arxiv.org/pdf/1701.00160.pdf) – Ian Goodfellow

3. Original PyTorch code for the Wasserstein GAN paper (https://github.com/martinarjovsky/WassersteinGAN)

4. Conditional Image Synthesis with Auxiliary Classifier GANs (https://arxiv.org/pdf/1610.09585v3.pdf) – Augustus Odena, Christopher Olah, Jonathon Shlens

5. Keras ACGAN implementation (https://github.com/lukedeo/keras-acgan) – Luke de Oliveira

6. Code for the article (https://gist.github.com/myurasov/6ecf449b32eb263e7d9a7f6e9aed5dc2)

原文:https://myurasov.github.io/2017/09/24/wasserstein-gan-keras.html?r

【本文是51CTO专栏机构“机器之心”的原创译文,微信公众号“机器之心( id: almosthuman2014)”】

 

戳这里,看该作者更多好文

责任编辑:赵宁宁 来源: 51CTO专栏
相关推荐

2021-08-25 17:03:09

模型人工智能PyTorch

2011-08-15 14:27:51

CocoaRunLoop

2021-11-08 22:59:04

机器学习

2020-09-25 08:49:42

HashMap

2023-04-18 08:00:35

DexKubernetes身份验证

2010-02-26 11:22:16

LitwareHR使用

2010-02-01 09:19:32

WF 4.0

2011-08-24 16:41:38

lua调试器

2017-04-26 09:30:53

卷积神经网络实战

2020-03-04 10:51:35

Python算法脚本语言

2009-11-03 17:08:38

Oracle修改用户权

2022-06-29 09:00:00

前端图像分类模型SQL

2010-12-12 21:01:00

Android控件

2011-08-23 09:56:52

UnicodeLua

2011-03-16 09:05:29

iptablesNAT

2009-11-23 10:31:25

PHP使用JSON

2021-04-09 20:04:34

区块链Go加密

2024-03-22 12:10:39

Redis消息队列数据库

2011-08-25 10:13:32

对leveldb的访问LLServer编译安

2019-01-29 10:27:27

量子计算机芯片超算
点赞
收藏

51CTO技术栈公众号