Home > MachineLearning > Main text

GAN


Tag: python, machine learning

目录


概述

生成对抗网络(Generative Adversarial Network, GAN),属于一种生成模型,通过生成样本,以训练模型,直到模型无法区分样本来源(真实 vs 生成)为止。


概念

1、网络结构


2、概念说明

组成部分:

  • 生成器(Generator):生成样本,特征大小和训练集一样。比如这里通过增加随机噪音(这个噪音就是输入),生成与数字图片大小相同的像素图片(这就是输出)。
  • 判别器(Discriminator):判别输入的来源。比如这里是根据输入的图片的像素特征(输入),判断其是来源于真实的数据集还是生成的数据集(输出)。

训练过程:

  • 1)初始化生成器(G)和判别器(D)的参数
  • 2)从真实样本中随机抽取n个样本,根据随机噪音使用生成器生成n个样本。保持生成器(参数)不变,训练判别器,使其能够很好的区分真实样本和生成样本。【现在判别器训练得很好了】
  • 3)更新生成器,使用较小的学习速率。训练生成器,使得生成样本和真实样本之间的误差尽可能的小(尽可能使判别器区分不出来)。
  • 4)多次迭代,最终判别器不能区分输入来源(输出真实或生成概率为50%)。

损失:

  • 交叉熵(cross entropy):在信息论中,基于相同事件测度的两个概率分布pq的交叉熵是指,当基于一个“非自然”(相对于“真实”分布p而言)的概率分布q进行编码时,在事件集合中唯一标识一个事件所需要的平均比特数(bit)[from wiki]。交叉熵越小,则两个分布越接近。cross_entropy.png
  • 相对熵衡量两个分布的相似性,这里在讨论autoencoder的时候也说了,两个分布越相近时,熵值越小。这里是固定其中的一个值p=0.2,当q取不同的值时,相对熵(KL散度)的大小,从图中可以看出,当q取和p相等的值(分布一样)时,熵值最小。

具体在GAN中,损失函数如下:GAN_loss.png


示例:MNIST手写识别

模型框架:

keras实现版本:

# https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200)

pytorch实现版本:

# https://github.com/BeierZhu/GAN-MNIST-Pytorch/blob/master/main.py

参考



If you link this blog, please refer to this page, thanks!
Post link:https://tsinghua-gongjing.github.io/posts/GAN.html

Previous: 深度学习调参技巧