Learning Adversarial Neural Networks (6): Generating Different Faces with BEGAN (TensorFlow Implementation)



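BEGAN (Boundary Equilibrium GAN) is a method proposed by David Berthelot et al. [1] in March 2017. A traditional GAN uses the discriminator to estimate whether the images produced by the generator follow the same distribution as the real images. BEGAN replaces this probability estimate: the authors argue that if the distributions of the auto-encoder reconstruction errors of real and generated images are close, the image distributions themselves can be considered close. The authors also modified the network architecture and obtained good experimental results.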
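To make "matching error distributions" concrete, here is a minimal NumPy sketch, not the author's code. The per-image auto-encoder loss is the mean absolute reconstruction error (the same L1 loss that `model_loss` uses later), and the BEGAN paper lower-bounds the Wasserstein distance between the real and fake loss distributions by the absolute difference of their means, |m1 - m2|. The function names and the random stand-in arrays are hypothetical.

```python
import numpy as np


# hypothetical helpers illustrating the idea; not part of the original code
def autoencoder_loss(images, reconstructions):
    """Per-image L1 reconstruction error, averaged over height, width and channels."""
    return np.mean(np.abs(images - reconstructions), axis=(1, 2, 3))


def wasserstein_lower_bound(loss_real, loss_fake):
    """|m1 - m2|: difference of the mean losses, a lower bound on W(mu1, mu2)."""
    return abs(float(np.mean(loss_real)) - float(np.mean(loss_fake)))


# Hypothetical stand-ins for a batch of images and the auto-encoder's outputs.
rng = np.random.default_rng(0)
real = rng.uniform(-1, 1, size=(16, 64, 64, 3))
fake = rng.uniform(-1, 1, size=(16, 64, 64, 3))
loss_real = autoencoder_loss(real, real + 0.1)  # reconstructed closely
loss_fake = autoencoder_loss(fake, fake + 0.3)  # reconstructed poorly
print(wasserstein_lower_bound(loss_real, loss_fake))
```

The discriminator tries to widen this gap (reconstruct real images well, generated ones badly), while the generator tries to close it.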










In this paper, we make the following contributions: 

• A GAN with a simple yet robust architecture, standard training procedure with fast and stable convergence. (A more robust GAN that converges quickly and stably.)

• An equilibrium concept that balances the power of the discriminator against the generator. (A notion of equilibrium between the discriminator and the generator.)

• A new way to control the trade-off between image diversity and visual quality. (Control over image diversity versus generation quality.)

• An approximate measure of convergence. To our knowledge the only other published measure is from Wasserstein GAN [1] (WGAN), which will be discussed in the next section. (An approximate measure of convergence; the authors drew their inspiration here from WGAN.)
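The equilibrium and the diversity/quality trade-off both come down to one variable, k_t, steered toward a target ratio γ = E[L(G(z))] / E[L(x)]. Below is a minimal sketch of one proportional-control step, assuming the paper's update rule and its constraint k_t ∈ [0, 1]; `update_k` and `diversity_ratio` are hypothetical helper names, and the defaults mirror `gamma = 0.5` and `lam = 1e-3` in this post's `train` function.

```python
# hypothetical helpers; the update rule follows the BEGAN paper
def update_k(k, loss_real, loss_fake, gamma=0.5, lambda_k=1e-3):
    """One step of k += lambda_k * (gamma * L(x) - L(G(z))), clipped to [0, 1]."""
    k = k + lambda_k * (gamma * loss_real - loss_fake)
    return min(max(k, 0.0), 1.0)


def diversity_ratio(loss_real_mean, loss_fake_mean):
    """gamma = E[L(G(z))] / E[L(x)], the ratio the controller tries to hold."""
    return loss_fake_mean / loss_real_mean


# Generated images are reconstructed too easily (0.05 < 0.5 * 0.2), so k grows
# and the discriminator puts more weight on pushing L(G(z)) up.
k = 0.0
for _ in range(3):
    k = update_k(k, loss_real=0.2, loss_fake=0.05)
    print(k)
```

According to the paper, a lower γ makes the discriminator concentrate more on auto-encoding real images, which trades diversity for quality.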


We use an auto-encoder as a discriminator as was first proposed in EBGAN [21]. While typical GANs try to match data distributions directly, our method aims to match auto-encoder loss distributions using a loss derived from the Wasserstein distance. This is done using a typical GAN objective with the addition of an equilibrium term to balance the discriminator and the generator. Our method has an easier training procedure and uses a simpler neural network architecture compared to typical GAN techniques.
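For reference, the BEGAN objective from the paper can be written as follows. In this post's code, L(x) and L(G(z)) correspond to `d_real` and `d_fake` in `model_loss` (an L1 loss averaged over pixels, i.e. η = 1), k_t is the `k_t` placeholder, and `gamma = 0.5` and `lam = 1e-3` in `train` play the roles of γ and λ_k. The code feeds the same noise batch to both updates, so z_D and z_G coincide there.

```latex
\begin{aligned}
\mathcal{L}(v) &= \lvert v - D(v) \rvert^{\eta}, \qquad \eta \in \{1, 2\} \\
\mathcal{L}_D &= \mathcal{L}(x) - k_t \, \mathcal{L}(G(z_D)) && \text{(updates } \theta_D\text{)} \\
\mathcal{L}_G &= \mathcal{L}(G(z_G)) && \text{(updates } \theta_G\text{)} \\
k_{t+1} &= k_t + \lambda_k \bigl( \gamma \, \mathcal{L}(x) - \mathcal{L}(G(z_G)) \bigr), \qquad k_t \in [0, 1] \\
\gamma &= \frac{\mathbb{E}[\mathcal{L}(G(z))]}{\mathbb{E}[\mathcal{L}(x)]} \\
\mathcal{M}_{\text{global}} &= \mathcal{L}(x) + \lvert \gamma \, \mathcal{L}(x) - \mathcal{L}(G(z_G)) \rvert
\end{aligned}
```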










-- main.py                          (main training script)
-- model.py                         (BEGAN model)
-- utils.py                         (helper functions)
-- data                             (training data folder)
    |------ img_align_celeba_png
                |------ image01.png
                |------ image02.png
                |------ ...
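`main.py` looks only for `data/img_align_celeba_png/*.png`, whereas the download script at the end of this post extracts into `data/img_align_celeba`. Assuming that folder holds JPEG files (the official Align&Cropped JPEG release uses `.jpg` names), a hypothetical helper like the one below can pick whichever layout is present; `get_image` opens files with PIL, so it reads either format.

```python
import glob
import os


def find_celeba_images(data_dir):
    """Hypothetical helper: sorted CelebA image paths, PNG layout first, JPEG layout as fallback."""
    png_files = glob.glob(os.path.join(data_dir, 'img_align_celeba_png', '*.png'))
    if png_files:
        return sorted(png_files)
    # fall back to the folder the download script extracts into
    return sorted(glob.glob(os.path.join(data_dir, 'img_align_celeba', '*.jpg')))
```

Usage would be `utils.Dataset(find_celeba_images('./data/'))`; a complete download should yield 202,599 paths.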




CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA has large diversities, large quantities, and rich annotations, including

  • 10,177 number of identities,

  • 202,599 number of face images, and

  • 5 landmark locations, 40 binary attributes annotations per image.

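In other words, the dataset covers 10,177 identities and contains 202,599 face images in total. We need to download the face images: on the official website, scroll down and choose Align&Cropped Images to download.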





# utils.py
import math
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Read an image from its path and crop it. Faces in CelebA sit roughly in the
# centre of each aligned image, so cropping the central region is enough.
def get_image(image_path, width, height, mode):
    """
    Read image from image_path
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image
    :return: Image data
    """
    image = Image.open(image_path)

    if image.size != (width, height):  # HACK - Check if image is from the CELEBA dataset
        # Remove most pixels that aren't part of a face
        face_width = face_height = 108
        j = (image.size[0] - face_width) // 2
        i = (image.size[1] - face_height) // 2
        image = image.crop([j, i, j + face_width, i + face_height])
        image = image.resize([width, height], Image.BILINEAR)

    return np.array(image.convert(mode))

# Load a list of image files into one batch array
def get_batch(image_files, width, height, mode):
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))

    return data_batch

# Dataset class; it relies on the two functions above
class Dataset(object):

    def __init__(self, data_files):
        """
        :param data_files: List of files in the database
        """
        IMAGE_WIDTH = 64
        IMAGE_HEIGHT = 64

        self.image_mode = 'RGB'
        image_channels = 3

        self.data_files = data_files
        self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels

    def get_batches(self, batch_size):
        """
        Generate batches
        :param batch_size: Batch Size
        :return: Batches of data
        """
        IMAGE_MAX_VALUE = 255

        current_index = 0
        while current_index + batch_size <= self.shape[0]:
            data_batch = get_batch(
                self.data_files[current_index:current_index + batch_size],
                *self.shape[1:3],
                self.image_mode)

            current_index += batch_size

            yield data_batch / IMAGE_MAX_VALUE - 0.5

# Tile the generated images into a single square grid image
def images_square_grid(images, mode):
    """
    Save images as a square grid
    :param images: Images to be used for the grid
    :param mode: The mode to use for images
    :return: Image of images in a square grid
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))

    # Scale to 0-255
    images = (((images - images.min()) * 255) /
              (images.max() - images.min())).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size * save_size],
        (save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    new_im = Image.new(
        mode, (images.shape[1] * save_size, images.shape[2] * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(
                im, (col_i * images.shape[1], image_i * images.shape[2]))

    return new_im

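# Plot the final result (an image, or one or more curves) and save it to out/
def save_plot(data, title, image_mode=None, isImage=False):
    """
    Save images or plot to file on the out folder
    Can also save stacked plots
    """
    if not os.path.exists('out/'):
        os.makedirs('out/')

    fig = plt.figure()

    if isImage:
        cmap = None if image_mode == 'RGB' else 'gray'
        plt.imshow(data, cmap=cmap)
    else:
        # A list is drawn as several curves on the same axes
        if isinstance(data, list):
            for i in data:
                plt.plot(i)
        else:
            plt.plot(data)

    fig.savefig('out/' + title)
    plt.close(fig)  # free the figure; this is called repeatedly during training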

# Render the generator's output for a noise batch (uses the two functions above)
def show_generator_output(sess, generator, input_z, example_z, out_channel_dim, image_mode, num):
    """
    Show example output for the generator
    :param sess: TensorFlow session
    :param generator: Generator function, e.g. BEGAN.generator
    :param input_z: Input Z Tensor
    :param example_z: Noise batch fed into input_z
    :param out_channel_dim: The number of channels in the output image
    :param image_mode: The mode to use for images ("RGB" or "L")
    :param num: Suffix for the saved file name
    """

    samples = sess.run(
        generator(input_z, out_channel_dim, False),
        feed_dict={input_z: example_z})

    images_grid = images_square_grid(samples, image_mode)
    save_plot(images_grid, '{}.png'.format(num), image_mode, True)

# Gaussian smoothing of a 1-D sequence (used for the convergence curve)
def smooth(list, degree=5):
    """
    By Scott W Harden from www.swharden.com
    """
    window = degree * 2 - 1
    weight = np.array([1.0] * window)
    weightGauss = []

    for i in range(window):
        i = i - degree + 1
        frac = i / float(window)
        gauss = 1 / (np.exp((4 * (frac)) ** 2))
        weightGauss.append(gauss)

    weight = np.array(weightGauss) * weight
    smoothed = [0.0] * (len(list) - window)

    for i in range(len(smoothed)):
        smoothed[i] = sum(np.array(list[i:i + window]) * weight) / sum(weight)

    return smoothed
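The pixel arithmetic in `get_image`, `get_batches` and `train` can be checked in isolation. The sketch below is a hypothetical restatement, not part of utils.py: `center_crop_box` reproduces the 108 x 108 central crop taken before resizing to 64 x 64, and `to_model_range` combines the `/ 255 - 0.5` in `get_batches` with the `batch_images *= 2` in `train`, so the networks see values in [-1, 1]. The 178 x 218 size is that of the aligned CelebA images.

```python
import numpy as np


# hypothetical helpers mirroring the arithmetic in get_image / get_batches / train
def center_crop_box(width, height, face_size=108):
    """(left, upper, right, lower) box of the central face_size x face_size crop."""
    left = (width - face_size) // 2
    upper = (height - face_size) // 2
    return left, upper, left + face_size, upper + face_size


def to_model_range(batch_uint8):
    """Apply get_batches' x / 255 - 0.5, then train's * 2, to uint8 pixels."""
    return (batch_uint8.astype(np.float32) / 255 - 0.5) * 2


print(center_crop_box(178, 218))  # (35, 55, 143, 163)
print(to_model_range(np.array([0, 128, 255], dtype=np.uint8)))
```

Because the decoder's last convolution has no activation, nothing forces the generator into [-1, 1]; the L1 reconstruction loss against targets in that range is what pulls it there.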



# model.py
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
import numpy as np

class BEGAN(object):
    def __init__(self, place_holder=''):
        self.place_holder = place_holder
        # pass

    def model_inputs(self, image_width, image_height, image_channels, z_dim):
        """Create the model inputs/tensors"""
        inputs_real = tf.placeholder(
            tf.float32, (None, image_width, image_height, image_channels), name='input_real')
        inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
        learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')
        k_t = tf.placeholder(tf.float32, name='k_t')

        return inputs_real, inputs_z, learning_rate, k_t

    # default alpha is 0.2, 0.01 works best for this example
    # Function from TensorFlow v1.4 for backwards compatibility
    def leaky_relu(self, features, alpha=0.01, name=None):
        with ops.name_scope(name, "LeakyRelu", [features, alpha]):
            features = ops.convert_to_tensor(features, name="features")
            alpha = ops.convert_to_tensor(alpha, name="alpha")

            return math_ops.maximum(alpha * features, features)

    def fully_connected(self, x, output_shape):
        # flatten and dense
        shape = x.get_shape().as_list()
        dim = np.prod(shape[1:])

        x = tf.reshape(x, [-1, dim])
        x = tf.layers.dense(x, output_shape, activation=None)

        return x

    def decoder(self, h, n, img_dim, channel_dim):
        """Reconstruction network"""
        h = tf.layers.dense(h, img_dim * img_dim * n, activation=None)
        h = tf.reshape(h, (-1, img_dim, img_dim, n))

        conv1 = tf.layers.conv2d(
            h, n, 3, padding="same", activation=self.leaky_relu)
        conv1 = tf.layers.conv2d(
            conv1, n, 3, padding="same", activation=self.leaky_relu)

        upsample1 = tf.image.resize_nearest_neighbor(
            conv1, size=(img_dim * 2, img_dim * 2))

        conv2 = tf.layers.conv2d(
            upsample1, n, 3, padding="same", activation=self.leaky_relu)
        conv2 = tf.layers.conv2d(
            conv2, n, 3, padding="same", activation=self.leaky_relu)

        upsample2 = tf.image.resize_nearest_neighbor(
            conv2, size=(img_dim * 4, img_dim * 4))

        conv3 = tf.layers.conv2d(
            upsample2, n, 3, padding="same", activation=self.leaky_relu)
        conv3 = tf.layers.conv2d(
            conv3, n, 3, padding="same", activation=self.leaky_relu)

        conv4 = tf.layers.conv2d(conv3, channel_dim, 3,
                                 padding="same", activation=None)

        return conv4

    def encoder(self, images, n, z_dim, channel_dim):
        """Feature extraction network"""
        conv1 = tf.layers.conv2d(
            images, n, 3, padding="same", activation=self.leaky_relu)

        conv2 = tf.layers.conv2d(
            conv1, n, 3, padding="same", activation=self.leaky_relu)
        conv2 = tf.layers.conv2d(
            conv2, n * 2, 3, padding="same", activation=self.leaky_relu)

        subsample1 = tf.layers.conv2d(
            conv2, n * 2, 3, strides=2, padding='same')

        conv3 = tf.layers.conv2d(subsample1, n * 2, 3,
                                 padding="same", activation=self.leaky_relu)
        conv3 = tf.layers.conv2d(
            conv3, n * 3, 3, padding="same", activation=self.leaky_relu)

        subsample2 = tf.layers.conv2d(
            conv3, n * 3, 3, strides=2, padding='same')

        conv4 = tf.layers.conv2d(subsample2, n * 3, 3,
                                 padding="same", activation=self.leaky_relu)
        conv4 = tf.layers.conv2d(
            conv4, n * 3, 3, padding="same", activation=self.leaky_relu)

        h = self.fully_connected(conv4, z_dim)

        return h

    def discriminator(self, images, z_dim, channel_dim, reuse=False):
        """Create the discriminator network: The autoencoder"""
        with tf.variable_scope('discriminator', reuse=reuse):
            x = self.encoder(images, 64, z_dim, channel_dim)
            x = self.decoder(x, 64, 64 // 4, channel_dim)

            return x

    def generator(self, z, channel_dim, is_train=True):
        """Create the generator network: only the decoder part"""
        reuse = False if is_train else True
        with tf.variable_scope('generator', reuse=reuse):
            x = self.decoder(z, 64, 64 // 4, channel_dim)

            return x

    def model_loss(self, input_real, input_z, channel_dim, z_dim, k_t):
        """Get the loss for the discriminator and generator"""
        g_model_fake = self.generator(input_z, channel_dim, is_train=True)
        d_model_real = self.discriminator(input_real, z_dim, channel_dim)
        d_model_fake = self.discriminator(
            g_model_fake, z_dim, channel_dim, reuse=True)

        # l1 loss
        d_real = tf.reduce_mean(tf.abs(input_real - d_model_real))
        d_fake = tf.reduce_mean(tf.abs(g_model_fake - d_model_fake))

        d_loss = d_real - k_t * d_fake
        g_loss = d_fake

        return d_loss, g_loss, d_real, d_fake

    def model_opt(self, d_loss, g_loss, learning_rate, beta1, beta2=0.999):
        """Get optimization operations"""
        # Get variables
        g_vars = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, "generator")
        d_vars = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, "discriminator")

        # Optimize
        d_train_opt = tf.train.AdamOptimizer(
            learning_rate, beta1=beta1, beta2=beta2).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(
            learning_rate, beta1=beta1, beta2=beta2).minimize(g_loss, var_list=g_vars)

        return d_train_opt, g_train_opt
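The L1 loss in `model_loss` compares the discriminator's output with its input pixel by pixel, so both must be 64 x 64. The sketch below traces the spatial sizes, assuming the defaults used in `discriminator` and `generator` (n = 64, 64 x 64 input, `img_dim = 64 // 4`) and TensorFlow's 'same' padding rule, output size = ceil(input / stride). The helper names are hypothetical and no TensorFlow is needed.

```python
import math


# hypothetical helpers that only do the shape arithmetic of encoder()/decoder()
def same_conv_size(size, stride=1):
    """Output size of a 'same'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)


def encoder_sizes(image_size=64):
    """Spatial sizes in encoder(): input, after subsample1, after subsample2 (stride 2 each)."""
    after_first = same_conv_size(image_size, stride=2)
    after_second = same_conv_size(after_first, stride=2)
    return [image_size, after_first, after_second]


def encoder_flat_dim(n=64, image_size=64):
    """Values flattened by fully_connected(): last spatial size squared times 3 * n channels."""
    last = encoder_sizes(image_size)[-1]
    return last * last * 3 * n


def decoder_sizes(img_dim=64 // 4):
    """Spatial sizes in decoder(): reshape to img_dim, then two 2x nearest-neighbour upsamplings."""
    return [img_dim, img_dim * 2, img_dim * 4]


print(encoder_sizes())     # [64, 32, 16]
print(encoder_flat_dim())  # 49152
print(decoder_sizes())     # [16, 32, 64]
```

So the encoder flattens a 16 x 16 x 192 tensor into the `z_dim`-dimensional code, and the decoder's dense layer maps the code back to 16 x 16 x 64 before upsampling to 64 x 64.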



# main.py
from model import BEGAN  # the model is defined in model.py
import tensorflow as tf
from glob import glob
import numpy as np
import utils
import math
import os

def train(model, epoch_count, batch_size, z_dim, start_learning_rate, beta1, beta2, get_batches, data_shape, image_mode):

    input_real, input_z, lrate, k_t = model.model_inputs(*(data_shape[1:]), z_dim)

    d_loss, g_loss, d_real, d_fake = model.model_loss(
        input_real, input_z, data_shape[3], z_dim, k_t)

    d_opt, g_opt = model.model_opt(d_loss, g_loss, lrate, beta1, beta2)

    losses = []
    iter = 0

    epoch_drop = 3

    lam = 1e-3    # lambda_k: proportional gain for k_t
    gamma = 0.5   # diversity ratio
    k_curr = 0.0  # k_0 = 0

    # fixed noise, so the 'test-' images show how the same faces evolve
    test_z = np.random.uniform(-1, 1, size=(16, z_dim))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch_i in range(epoch_count):

            # step decay: multiply the learning rate by 0.2 every epoch_drop epochs
            learning_rate = start_learning_rate * \
                math.pow(0.2, math.floor((epoch_i + 1) / epoch_drop))
            for batch_images in get_batches(batch_size):
                iter += 1
                batch_images *= 2  # [-0.5, 0.5] -> [-1, 1]

                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                _, d_real_curr = sess.run([d_opt, d_real], feed_dict={
                                          input_z: batch_z, input_real: batch_images, lrate: learning_rate, k_t: k_curr})

                _, d_fake_curr = sess.run([g_opt, d_fake], feed_dict={
                                          input_z: batch_z, input_real: batch_images, lrate: learning_rate, k_t: k_curr})

                # proportional control of k_t, kept in [0, 1] as in the paper
                k_curr = float(np.clip(k_curr + lam * (gamma * d_real_curr - d_fake_curr), 0.0, 1.0))

                # save convergence measure
                if iter % 100 == 0:
                    measure = d_real_curr + \
                        np.abs(gamma * d_real_curr - d_fake_curr)
                    losses.append(measure)

                    print("Epoch {}/{}, batch {}...".format(epoch_i + 1, epoch_count, iter),
                          'Convergence measure: {:.4}'.format(measure))

                # save test and batch images
                if iter % 700 == 0:
                    utils.show_generator_output(
                        sess, model.generator, input_z, batch_z, data_shape[3], image_mode, 'batch-' + str(iter))
                    utils.show_generator_output(
                        sess, model.generator, input_z, test_z, data_shape[3], image_mode, 'test-' + str(iter))

        print('Training steps: ', iter)

        losses = np.array(losses)

        # the original file-name argument was cut off; this name is a placeholder
        utils.save_plot([losses, utils.smooth(losses)],
                        'convergence_measure.png')

if __name__ == '__main__':
    batch_size = 16
    z_dim = 64  
    learning_rate = 0.0001
    beta1 = 0.5
    beta2 = 0.999
    epochs = 20

    data_dir = './data/'

    model = BEGAN()

    celeba_dataset = utils.Dataset(glob(os.path.join(data_dir, 'img_align_celeba_png/*.png')))

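    with tf.Graph().as_default():
        train(model, epochs, batch_size, z_dim, learning_rate, beta1, beta2,
              celeba_dataset.get_batches, celeba_dataset.shape, celeba_dataset.image_mode)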
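With these settings it helps to know how long an epoch is and how fast the learning rate falls. The sketch below restates the step decay from `train` and the batch count of `get_batches` as hypothetical helpers; the defaults (1e-4 start, factor 0.2 every 3 epochs, batch size 16, 202,599 images) are taken from this post.

```python
import math


# hypothetical helpers restating the schedule in train() and the loop in get_batches()
def learning_rate_at(epoch_i, start_learning_rate=1e-4, drop=0.2, epoch_drop=3):
    """Step decay: multiply by `drop` once every `epoch_drop` epochs (epoch_i is 0-based)."""
    return start_learning_rate * math.pow(drop, math.floor((epoch_i + 1) / epoch_drop))


def batches_per_epoch(num_images=202599, batch_size=16):
    """get_batches() skips the final incomplete batch, hence floor division."""
    return num_images // batch_size


for epoch_i in (0, 1, 2, 5, 19):
    print(epoch_i + 1, learning_rate_at(epoch_i))
print(batches_per_epoch())  # 12662
```

Because of the `+ 1`, the first drop (to 2e-5) already happens in the third epoch, and by the twentieth epoch the rate is about 6.4e-9. One epoch is 12,662 batches, so roughly 250,000 batches is just under 20 epochs.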






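Of course, I have not finished training yet, but you can already see that after roughly 8,000 batches the generated results are very good. However, as the quality of the generated images keeps improving, the noise also becomes more and more noticeable.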



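Generated faces (top row: faces generated from random noise; bottom row: faces generated from the fixed noise vector; snapshots at roughly 100,000, 150,000, 200,000 and 250,000 batches):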





import os
import hashlib
from urllib.request import urlretrieve
import zipfile
import shutil
from tqdm import tqdm

def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    :param data_path: Directory to download into and extract to
    """
    url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
    hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
    extract_path = os.path.join(data_path, 'img_align_celeba')
    save_path = os.path.join(data_path, 'celeba.zip')
    extract_fn = _unzip

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(url, save_path, pbar.hook)

    # Verify the archive before extracting it
    with open(save_path, 'rb') as f:
        assert hashlib.md5(f.read()).hexdigest() == hash_code, \
            '{} file is corrupted.  Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        # Remove extraction folder if there is an error
        shutil.rmtree(extract_path)
        raise err

    # Remove compressed data
    os.remove(save_path)

def _unzip(save_path, _, database_name, data_path):
    """
    Unzip wrapper with the same interface as _ungzip
    :param save_path: The path of the zip file
    :param database_name: Name of database
    :param data_path: Path to extract to
    :param _: HACK - Used to have the same interface as _ungzip
    """
    print('Extracting {}...'.format(database_name))
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(data_path)

class DLProgress(tqdm):
    """
    Handle Progress Bar while Downloading
    """
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        A hook function that will be called once on establishment of the network connection and
        once after each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
                            a file size in response to a retrieval request.
        """
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

if __name__ == '__main__':
    data_dir = './data/'  # the same folder main.py reads from
    download_extract('celeba', data_dir)
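The integrity check in `download_extract` reads the whole archive into memory before hashing it. A possible alternative, sketched here as a hypothetical helper `md5_of_file`, hashes the file in fixed-size chunks with the standard `hashlib` API, so memory use stays constant regardless of archive size.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Hypothetical helper: MD5 hex digest of a file, read in chunks (1 MiB by default)."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        # iter() with a sentinel keeps calling f.read until it returns b'' at end of file
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
```

Inside `download_extract`, the assertion could then compare `md5_of_file(save_path)` with `hash_code`; the digest is identical to hashing the full contents at once.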