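CartoonGAN is a model proposed by Yang Chen et al. in February 2018. It takes the study of cartoon-style image generation further by proposing a new GAN architecture and two loss functions. Compared with earlier GAN models for cartoon-style generation, CartoonGAN produces cartoon-style images of noticeably higher quality.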







However, existing methods do not produce satisfactory results for cartoonization, due to the fact that (1) cartoon styles have unique characteristics with high level simplification and abstraction, and (2) cartoon images tend to have clear edges, smooth color shading and relatively simple textures, which exhibit significant challenges for texture-descriptor-based loss functions used in existing methods. 




The main contributions of this paper are: 

(1) We propose a dedicated GAN-based approach that effectively learns the mapping from real-world photos to cartoon images using unpaired image sets for training. Our method is able to generate high-quality stylized cartoons, which are substantially better than state-of-the-art methods. When cartoon images from individual artists are used for training, our method is able to reproduce their styles. (In short: a new GAN architecture that uses unpaired data to translate real-world images into cartoon images.)

(2) We propose two simple yet effective loss functions in GAN-based architecture. In the generative network, to cope with substantial style variation between photos and cartoons, we introduce a semantic loss defined as an ℓ1 sparse regularization in the high-level feature maps of the VGG network. In the discriminator network, we propose an edge-promoting adversarial loss for preserving clear edges. (In short: two loss functions, a semantic loss in the generator and an edge-promoting adversarial loss in the discriminator.)

(3) We further introduce an initialization phase to improve the convergence of the network to the target manifold. Our method is much more efficient to train than existing methods. (In short: an initialization phase improves the network's convergence to the target manifold.)
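The two losses in contribution (2) can be illustrated numerically. The NumPy sketch below is only an illustration under assumptions, not the paper's code: `feat_photo` and `feat_fake` stand in for VGG feature maps, the `d_*` arrays stand in for raw discriminator logits, and every function name is hypothetical. The adversarial part uses the plain sigmoid cross-entropy form.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def content_loss(feat_photo, feat_fake):
    """L1 distance between high-level feature maps (stand-ins for VGG features)."""
    return np.mean(np.abs(feat_photo - feat_fake))

def edge_promoting_d_loss(d_cartoon, d_fake, d_smoothed, eps=1e-12):
    """Discriminator loss on logits: cartoons count as real, while both
    generated images and edge-smoothed cartoons count as fake."""
    real_term = -np.mean(np.log(sigmoid(d_cartoon) + eps))
    fake_term = -np.mean(np.log(1.0 - sigmoid(d_fake) + eps))
    smooth_term = -np.mean(np.log(1.0 - sigmoid(d_smoothed) + eps))
    return real_term + fake_term + smooth_term

rng = np.random.default_rng(0)
feat = rng.normal(size=(1, 8, 8, 4))
same_feature_loss = content_loss(feat, feat)

# A discriminator that rejects edge-smoothed cartoons is penalized less
# than one that accepts them as real.
confident = edge_promoting_d_loss(np.array([5.0]), np.array([-5.0]), np.array([-5.0]))
fooled = edge_promoting_d_loss(np.array([5.0]), np.array([-5.0]), np.array([5.0]))
print(same_feature_loss, confident, fooled)
```

The third term is what makes the loss "edge-promoting": a cartoon whose edges have been blurred must be classified as fake, so the generator cannot satisfy the discriminator with soft edges.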



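The authors also point out that, because a GAN is highly nonlinear, optimization from a random initialization is easily trapped in a sub-optimal local minimum. Since the generator's output should preserve the semantic content of its input, the generator is pre-trained with the content loss only.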
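As a rough sketch of this two-phase schedule (the function name and the loss labels are hypothetical, chosen only for illustration), the set of active losses per epoch could be written as:

```python
def losses_for_epoch(epoch, init_epochs):
    """Return which losses are optimized in a given epoch.

    During the initialization phase only the generator's content loss is
    used; afterwards the adversarial loss is added and the discriminator
    is trained as well.
    """
    if epoch < init_epochs:
        return {"generator": ["content"], "discriminator": []}
    return {"generator": ["content", "adversarial"],
            "discriminator": ["edge_promoting_adversarial"]}

for epoch in range(3):
    print(epoch, losses_for_epoch(epoch, init_epochs=1))
```

In the training loop of cartoonGAN.py this switch corresponds to the `epoch < self.init_epoch` test.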









1. File structure


-- utils.py
-- layer.py
-- vgg19.py
-- vgg19.npy                            # download this file yourself (covered later)
-- edge_smooth.py
-- cartoonGAN.py
-- main.py
-- dataset                              # dataset folder
    |------ trainA                      # real-world photos (generator input)
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ trainB                      # cartoon images
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ trainB_smooth               # generated with edge_smooth.py
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ testA                       # images used for the final test
                |------ image1.png
                |------ image2.png
                |------ ......
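Before training, it can help to confirm that the dataset folders exist. The helper below is a minimal sketch; `EXPECTED_DIRS` and `missing_dataset_dirs` are hypothetical names, and the demo runs against a temporary directory rather than a real dataset.

```python
import os
import tempfile

EXPECTED_DIRS = ["trainA", "trainB", "trainB_smooth", "testA"]

def missing_dataset_dirs(root):
    """Return the expected sub-folders that do not exist under `root`."""
    return [d for d in EXPECTED_DIRS if not os.path.isdir(os.path.join(root, d))]

# Demo on a throwaway directory so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as root:
    for name in ["trainA", "trainB", "testA"]:
        os.makedirs(os.path.join(root, name))
    demo_missing = missing_dataset_dirs(root)
    print(demo_missing)
```

For the layout above one would call `missing_dataset_dirs('./dataset')`; a missing `trainB_smooth` simply means edge_smooth.py has not been run yet.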

2. Data preparation












3. Data-loading utilities: utils.py


import tensorflow as tf
from tensorflow.contrib import slim
from scipy import misc
import os
import numpy as np

class ImageData:

    def __init__(self, load_size, channels):
        self.load_size = load_size
        self.channels = channels

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        return img

def load_test_data(image_path, size=256):
    img = misc.imread(image_path, mode='RGB')
    img = misc.imresize(img, [size, size])
    img = np.expand_dims(img, axis=0)
    img = preprocessing(img)

    return img

def preprocessing(x):
    x = x/127.5 - 1                     # -1 ~ 1
    return x

def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)

def inverse_transform(images):
    return (images+1.) / 2

def imsave(images, size, path):
    return misc.imsave(path, merge(images, size))

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h*j:h*(j+1), w*i:w*(i+1), :] = image

    return img

def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

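def check_folder(log_dir):
    # Create the directory if it does not exist yet, then return its path.
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir

def str2bool(x):
    # Accept only the literal string 'true' (case-insensitive).
    return x.lower() == 'true'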
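To make the value ranges and the grid layout concrete, here is a NumPy-only sketch that re-implements `preprocessing`, `inverse_transform` and `merge` from the listing above without TensorFlow or SciPy. The sample pixel values and the 2x2 batch are made up for illustration.

```python
import numpy as np

def preprocessing(x):
    return x / 127.5 - 1.0           # [0, 255] -> [-1, 1]

def inverse_transform(images):
    return (images + 1.0) / 2.0      # [-1, 1] -> [0, 1]

def merge(images, size):
    """Tile a batch [n, h, w, 3] into one image; size = [rows, cols]."""
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        col = idx % size[1]
        row = idx // size[1]
        grid[h * row:h * (row + 1), w * col:w * (col + 1), :] = image
    return grid

pixels = np.array([0.0, 127.5, 255.0])
scaled = preprocessing(pixels)
restored = inverse_transform(scaled)

# Four constant 2x2 images laid out row by row in a 2x2 grid.
batch = np.stack([np.full((2, 2, 3), v) for v in (0.1, 0.2, 0.3, 0.4)])
grid = merge(batch, [2, 2])
print(scaled, restored, grid.shape)
```

Note that `inverse_transform` returns values in [0, 1], not [0, 255]; the listing relies on the image writer to rescale them when saving.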

4. Layer definitions: layer.py


import tensorflow as tf
import tensorflow.contrib as tf_contrib
from vgg19 import Vgg19

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None

# Layer

def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if (kernel - stride) % 2 == 0:
            pad_top = pad
            pad_bottom = pad
            pad_left = pad
            pad_right = pad

        else :
            pad_top = pad
            pad_bottom = kernel - stride - pad_top
            pad_left = pad
            pad_right = kernel - stride - pad_left

        if pad_type == 'zero':
            x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
        if pad_type == 'reflect':
            x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            # Spectrally normalized convolution: build the kernel by hand.
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 strides=stride, use_bias=use_bias)

        return x

def deconv(x, channels, kernel=4, stride=2, use_bias=True, sn=False, scope='deconv_0'):
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()
        output_shape = [x_shape[0], x_shape[1]*stride, x_shape[2]*stride, channels]
        if sn:
            # Spectrally normalized transposed convolution.
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding='SAME')

            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           strides=stride, padding='SAME', use_bias=use_bias)

        return x

# Residual-block

def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)

        return x + x_init

# Activation function

def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)

def relu(x):
    return tf.nn.relu(x)

def tanh(x):
    return tf.tanh(x)

# Normalization function

def instance_norm(x, scope='instance_norm'):
    return tf_contrib.layers.instance_norm(x,
                                           center=True, scale=True,
                                           scope=scope)

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # power iteration
        # Usually iteration = 1 will be enough
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm

def l2_norm(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

# Loss function

def L1_loss(x, y):
    return tf.reduce_mean(tf.abs(x - y))

def discriminator_loss(loss_func, real, fake, real_blur):
    real_loss = 0
    fake_loss = 0
    real_blur_loss = 0

    if loss_func == 'wgan-gp' or loss_func == 'wgan-lp':
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
        real_blur_loss = tf.reduce_mean(real_blur)

    if loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.square(real - 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
        real_blur_loss = tf.reduce_mean(tf.square(real_blur))

    if loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
        real_blur_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real_blur),
                                                                                logits=real_blur))

    if loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))
        real_blur_loss = tf.reduce_mean(relu(1.0 + real_blur))

    loss = real_loss + fake_loss + real_blur_loss

    return loss

def generator_loss(loss_func, fake):
    fake_loss = 0

    if loss_func == 'wgan-gp' or loss_func == 'wgan-lp':
        fake_loss = -tf.reduce_mean(fake)

    if loss_func == 'lsgan' :
        fake_loss = tf.reduce_mean(tf.square(fake - 1.0))

    if loss_func == 'gan' or loss_func == 'dragan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        fake_loss = -tf.reduce_mean(fake)

    loss = fake_loss

    return loss

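def vgg_loss(real, fake):
    vgg = Vgg19('vgg19.npy')

    # Build the VGG graph on the real image and read its conv4_4 features.
    vgg.build(real)
    real_feature_map = vgg.conv4_4_no_activation

    # Rebuild on the generated image; the attribute now holds the fake features.
    vgg.build(fake)
    fake_feature_map = vgg.conv4_4_no_activation

    loss = L1_loss(real_feature_map, fake_feature_map)

    return loss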
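The power iteration inside `spectral_norm` can be checked outside TensorFlow. The following NumPy sketch mirrors the same steps (flatten the kernel to `[-1, out_channels]`, alternate the two normalized products, then form sigma) and compares the estimate with the largest singular value from an SVD. The function names, the random kernel and the iteration count of 50 are assumptions made for illustration; the listing itself defaults to a single iteration per call.

```python
import numpy as np

def l2_normalize(v, eps=1e-12):
    return v / (np.sqrt(np.sum(v ** 2)) + eps)

def spectral_norm_estimate(w, iterations=1, seed=0):
    """Estimate the largest singular value of a kernel by power iteration."""
    w2d = w.reshape(-1, w.shape[-1])          # [k*k*in, out]
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(1, w2d.shape[1]))    # plays the role of the variable "u"
    for _ in range(iterations):
        v = l2_normalize(u @ w2d.T)
        u = l2_normalize(v @ w2d)
    return (v @ w2d @ u.T).item()

kernel = np.random.default_rng(1).normal(size=(3, 3, 4, 8))
sigma_est = spectral_norm_estimate(kernel, iterations=50)
sigma_true = np.linalg.svd(kernel.reshape(-1, 8), compute_uv=False)[0]
print(sigma_est, sigma_true)
```

In training, one iteration per step is usually enough because `u` is stored between steps, so the estimate keeps improving as training proceeds.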

5. VGG19 model: vgg19.py


import tensorflow as tf
import numpy as np
import time

VGG_MEAN = [103.939, 116.779, 123.68]

class Vgg19:

    def __init__(self, vgg19_npy_path=None):
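        # The .npy file stores a pickled dict, so allow_pickle must be enabled on recent NumPy versions.
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()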
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG
        input format: rgb image with shape [batch_size, h, w, 3]
        scale: (-1, 1)
        """

        start_time = time.time()
        rgb_scaled = ((rgb + 1) / 2) * 255.0        # [-1, 1] ~ [0, 255]

        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0],
                                        green - VGG_MEAN[1],
                                        red - VGG_MEAN[2]])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
        self.pool3 = self.max_pool(self.conv3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")

        self.conv4_4_no_activation = self.no_activation_conv_layer(self.conv4_3, "conv4_4")

        self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
        self.pool4 = self.max_pool(self.conv4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
        self.pool5 = self.max_pool(self.conv5_4, 'pool5')

        print(("Finished building vgg19: %ds" % (time.time() - start_time)))

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def no_activation_conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            x = tf.nn.bias_add(conv, conv_biases)

            return x

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")
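The input conversion at the top of `build` can be reproduced in NumPy to see what VGG actually receives. This is a sketch of that one step only; `to_vgg_input` is a hypothetical name and the all-white test image is made up.

```python
import numpy as np

VGG_MEAN = [103.939, 116.779, 123.68]   # BGR channel means, as in vgg19.py

def to_vgg_input(rgb):
    """Map an RGB batch in [-1, 1] to mean-subtracted BGR on the [0, 255] scale."""
    rgb_scaled = (rgb + 1.0) / 2.0 * 255.0
    red, green, blue = np.split(rgb_scaled, 3, axis=3)
    return np.concatenate([blue - VGG_MEAN[0],
                           green - VGG_MEAN[1],
                           red - VGG_MEAN[2]], axis=3)

white = np.ones((1, 2, 2, 3))            # pure white in the [-1, 1] convention
vgg_in = to_vgg_input(white)
print(vgg_in[0, 0, 0])
```

The channel order is reversed because the pretrained weights expect BGR input with per-channel means removed.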

6. Edge smoothing (Gaussian blur): edge_smooth.py


from utils import check_folder
import numpy as np
import cv2, os, argparse
from glob import glob
from tqdm import tqdm

def parse_args():
    desc = "Edge smoothed"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--dataset', type=str, default='img2anime', help='dataset_name')
    parser.add_argument('--img_size', type=int, default=256, help='The size of image')

    return parser.parse_args()

def make_edge_smooth(dataset_name, img_size) :

    file_list = glob('./dataset/trainB/*.*')
    # Save next to trainB, where cartoonGAN.py looks for trainB_smooth; create the folder if needed.
    save_dir = check_folder('./dataset/trainB_smooth')

    kernel_size = 5
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    gauss = cv2.getGaussianKernel(kernel_size, 0)
    gauss = gauss * gauss.transpose(1, 0)

    for f in tqdm(file_list):
        file_name = os.path.basename(f)

        bgr_img = cv2.imread(f)
        gray_img = cv2.imread(f, 0)

        bgr_img = cv2.resize(bgr_img, (img_size, img_size))
        gray_img = cv2.resize(gray_img, (img_size, img_size))

        # detect edges, then dilate them
        edges = cv2.Canny(gray_img, 100, 200)
        dilation = cv2.dilate(edges, kernel)

        h, w = edges.shape

        # apply Gaussian blur
        gauss_img = np.copy(bgr_img)
        for i in range(kernel_size // 2, h - kernel_size // 2):
            for j in range(kernel_size // 2, w - kernel_size // 2):
                if dilation[i, j] != 0:  # gaussian blur to only edge
                    gauss_img[i, j, 0] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 0], gauss))
                    gauss_img[i, j, 1] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 1], gauss))
                    gauss_img[i, j, 2] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 2], gauss))

        cv2.imwrite(os.path.join(save_dir, file_name), gauss_img)

def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    make_edge_smooth(args.dataset, args.img_size)

if __name__ == '__main__':
    main()
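The core idea of edge_smooth.py, blurring only the pixels near detected edges, can be shown without OpenCV. In the NumPy sketch below the edge mask is set by hand instead of coming from Canny plus dilation, and the kernel's `sigma=1.1` is an illustrative choice rather than the exact kernel `cv2.getGaussianKernel` returns; `gaussian_kernel` and `blur_where` are hypothetical names.

```python
import numpy as np

def gaussian_kernel(size=5, sigma=1.1):
    """Normalized 2-D Gaussian kernel (sigma chosen here for illustration)."""
    ax = np.arange(size) - size // 2
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    k = np.outer(g, g)
    return k / k.sum()

def blur_where(img, mask, kernel):
    """Blur `img` [h, w, c] only at pixels where `mask` [h, w] is non-zero."""
    r = kernel.shape[0] // 2
    out = img.astype(np.float64).copy()
    h, w = mask.shape
    for i in range(r, h - r):
        for j in range(r, w - r):
            if mask[i, j] != 0:
                patch = img[i - r:i + r + 1, j - r:j + r + 1, :].astype(np.float64)
                out[i, j, :] = np.sum(patch * kernel[:, :, None], axis=(0, 1))
    return out

img = np.zeros((9, 9, 3))
img[:, 5:, :] = 255.0                    # a vertical step edge
mask = np.zeros((9, 9), dtype=np.uint8)
mask[4, 4] = 1                           # pretend the edge detector fired here
smoothed = blur_where(img, mask, gaussian_kernel())
print(smoothed[4, 4, 0], smoothed[4, 5, 0], smoothed[0, 0, 0])
```

Only the masked pixel changes; everything else keeps its original value, which is why the smoothed images still look like cartoons apart from their softened edges.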

7. Model: cartoonGAN.py


from layer import *
from utils import *
from glob import glob
import time
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
import numpy as np

class CartoonGAN(object):
    def __init__(self, sess, args):
        self.model_name = 'CartoonGAN'
        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset

        self.epoch = args.epoch
        self.init_epoch = args.init_epoch       # args.epoch // 20
        self.iteration = args.iteration
        self.decay_flag = args.decay_flag
        self.decay_epoch = args.decay_epoch

        self.gan_type = args.gan_type

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.init_lr = args.lr
        self.ch = args.ch

        """ Weight """
        self.adv_weight = args.adv_weight
        self.vgg_weight = args.vgg_weight
        self.ld = args.ld

        """ Generator """
        self.n_res = args.n_res

        """ Discriminator """
        self.n_dis = args.n_dis
        self.n_critic = args.n_critic
        self.sn = args.sn

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)   # sample images are written here during training

        self.trainA_dataset = glob('./dataset/trainA/*.*')
        self.trainB_dataset = glob('./dataset/trainB/*.*')
        self.trainB_smooth_dataset = glob('./dataset/trainB_smooth/*.*')

        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))


        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# init_epoch : ", self.init_epoch)
        print("# iteration per epoch : ", self.iteration)


        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)


        print("##### Discriminator #####")
        print("# the number of discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)


    # Generator

    def generator(self, x_init, reuse=False, scope="generator"):
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, use_bias=True, scope='conv_s2_'+str(i))
                x = conv(x, channel*2, kernel=3, stride=1, pad=1, use_bias=False, scope='conv_s1_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)

                channel = channel * 2

            # Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i))

            # Up-Sampling
            for i in range(2):
                x = deconv(x, channel//2, kernel=3, stride=2, use_bias=True, scope='deconv_'+str(i))
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, use_bias=False, scope='up_conv_'+str(i))
                x = instance_norm(x, scope='up_ins_norm_'+str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', 
                     use_bias=True, scope='G_logit')
            x = tanh(x)

            return x

    # Discriminator

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        channel = self.ch // 2
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=True, sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis):
                x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=True, sn=self.sn, scope='conv_s2_' + str(i))
                x = lrelu(x, 0.2)

                x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_s1_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='last_conv')
            x = instance_norm(x, scope='last_ins_norm')
            x = lrelu(x, 0.2)

            x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=True, sn=self.sn, scope='D_logit')

            return x

    # Model
    def gradient_panalty(self, real, fake, scope="discriminator"):
        if self.gan_type.__contains__('dragan'):
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True, scope=scope)

        grad = tf.gradients(logit, interpolated)[0]             # gradient of D(interpolated)
        grad_norm = tf.norm(tf.layers.flatten(grad), axis=1)    # l2 norm

        GP = 0
        # WGAN - LP
        if self.gan_type.__contains__('lp'):
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif self.gan_type.__contains__('gp') or self.gan_type == 'dragan' :
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB_smooth = trainB_smooth.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(
            Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
            drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()
        trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator()

        self.real_A = trainA_iterator.get_next()
        self.real_B = trainB_iterator.get_next()
        self.real_B_smooth = trainB_smooth_iterator.get_next()

        self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_real_A')

        """ Define Generator, Discriminator """
        self.fake_B = self.generator(self.real_A)

        real_B_logit = self.discriminator(self.real_B)
        fake_B_logit = self.discriminator(self.fake_B, reuse=True)
        real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)

        """ Define Loss """
        if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') :
            GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
        else :
            GP = 0.0

        v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
        g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
        d_loss = self.adv_weight * discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

        self.Vgg_loss = v_loss
        self.Generator_loss = g_loss + v_loss
        self.Discriminator_loss = d_loss

        """ Result Image """
        self.test_fake_B = self.generator(self.test_real_A, reuse=True)

        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.init_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        """" Summary """
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", v_loss)

        self.V_loss_merge = tf.summary.merge([self.G_vgg])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
        self.D_loss_merge = tf.summary.merge([self.D_loss])

    def train(self):
        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        lr = self.init_lr
        for epoch in range(start_epoch, self.epoch):
            # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
            if self.decay_flag :
                lr = self.init_lr * pow(0.5, epoch // self.decay_epoch)

            for idx in range(start_batch_id, self.iteration):

                train_feed_dict = {self.lr: lr}

                if epoch < self.init_epoch:
                    # Init G: pre-train the generator with the VGG content loss only
                    real_A_images, fake_B_images, _, v_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                             self.init_optim,
                                                                             self.Vgg_loss, self.V_loss_merge], feed_dict = train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f v_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, v_loss))

                else:
                    # Update D
                    _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge], feed_dict = train_feed_dict)
                    self.writer.add_summary(summary_str, counter)

                    # Update G (once every n_critic discriminator updates)
                    g_loss = None
                    if (counter - 1) % self.n_critic == 0 :
                        real_A_images, fake_B_images, _, g_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                                              self.G_optim,
                                                                                              self.Generator_loss, self.G_loss_merge], feed_dict = train_feed_dict)
                        self.writer.add_summary(summary_str, counter)
                        past_g_loss = g_loss

                    if g_loss is None:
                        g_loss = past_g_loss
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # display training status
                counter += 1

                if np.mod(idx+1, self.print_freq) == 0 :
                    save_images(real_A_images, [self.batch_size, 1],
                                './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
                    save_images(fake_B_images, [self.batch_size, 1],
                                './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model for final step
            self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'
        return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name,
                                                         self.gan_type, n_res, n_dis,
                                                         self.n_critic, self.sn,
                                                         int(self.adv_weight), int(self.vgg_weight))

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def test(self):
        test_A_files = glob('./dataset/testA/*.*')

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)

        if could_load:
            print(" [*] Load SUCCESS")
            print(" [!] Load failed...")

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        index = open(index_path, 'w')
        index.write("<html><body><table>")
        index.write("<tr><th>name</th><th>input</th><th>output</th></tr>")

        for sample_file in test_A_files:                           # A -> B
            print('Processing A image: ' + sample_file)
            sample_image = np.asarray(load_test_data(sample_file))
            image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

            fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_real_A: sample_image})
            save_images(fake_img, [1, 1], image_path)

            # one table row per test image: file name, input photo, generated cartoon
            index.write("<tr><td>%s</td>" % os.path.basename(image_path))

            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(
                sample_file) else ('../..' + os.path.sep + sample_file), self.img_size, self.img_size))
            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(
                image_path) else ('../..' + os.path.sep + image_path), self.img_size, self.img_size))
            index.write("</tr>")

        # close the table and flush the file to disk
        index.write("</table></body></html>")
        index.close()
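
The naming logic shared by `model_dir`, `save` and `load` can be tried on its own, without TensorFlow. The following is a minimal sketch rather than part of the original code: `checkpoint_folder_name` and `step_from_checkpoint` are hypothetical helper names, and `'CartoonGAN'` as the model name is an assumption, because `self.model_name` is set outside this excerpt. It relies on `tf.train.Saver.save` appending `-<global_step>` to the file prefix when `global_step` is passed, which is what `load` undoes with `split('-')`.

```python
import os


def checkpoint_folder_name(model_name, dataset_name, gan_type, n_res, n_dis,
                           n_critic, sn, adv_weight, vgg_weight):
    """Mirror the string built by model_dir (hypothetical free-function form)."""
    return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
        model_name, dataset_name, gan_type,
        str(n_res) + 'resblock', str(n_dis) + 'dis',
        n_critic, sn, int(adv_weight), int(vgg_weight))


def step_from_checkpoint(checkpoint_path):
    """Recover the global step from a name such as 'CartoonGAN.model-1500'."""
    ckpt_name = os.path.basename(checkpoint_path)
    return int(ckpt_name.split('-')[-1])


# The values are the defaults from main.py; 'CartoonGAN' is an assumed model name.
folder = checkpoint_folder_name('CartoonGAN', 'img2anime', 'gan', 8, 3, 1, False, 1.0, 10.0)
step = step_from_checkpoint(os.path.join('checkpoint', folder, 'CartoonGAN.model-1500'))
print(folder)
print(step)
```

Because the folder name encodes the hyperparameters, changing any of them (for example `--n_res`) makes `load` look in a different folder, so an earlier checkpoint is not found and the counter restarts at 0.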


8. Main file main.py


from CartoonGAN import CartoonGAN
import argparse
import os
import tensorflow as tf
from utils import *

"""parsing and configuration"""

def parse_args():
    desc = "Tensorflow implementation of CartoonGAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')
    parser.add_argument('--dataset', type=str, default='img2anime', help='dataset_name')

    parser.add_argument('--epoch', type=int, default=500, help='The number of epochs to run')
    parser.add_argument('--init_epoch', type=int, default=1, help='The number of epochs for weight initialization')
    parser.add_argument('--iteration', type=int, default=500, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=1, help='The batch size')
    parser.add_argument('--print_freq', type=int, default=100, help='The number of image_print_freq')
    parser.add_argument('--save_freq', type=int, default=100, help='The number of ckpt_save_freq')
    parser.add_argument('--decay_flag', type=str2bool, default=False, help='The decay_flag')
    parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch')

    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--adv_weight', type=float, default=1.0, help='Weight about GAN')
    parser.add_argument('--vgg_weight', type=float, default=10.0, help='Weight about VGG19')
    parser.add_argument('--gan_type', type=str, default='gan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')

    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
    parser.add_argument('--n_res', type=int, default=8, help='The number of resblock')

    parser.add_argument('--n_dis', type=int, default=3, help='The number of discriminator layer')
    parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
    parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm')

    parser.add_argument('--img_size', type=int, default=256, help='The size of image')
    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
    # parser.add_argument('--augment_flag', type=str2bool, default=False, help='Image augmentation use or not')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())

"""checking arguments"""
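def check_args(args):
    # --checkpoint_dir
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # --result_dir
    os.makedirs(args.result_dir, exist_ok=True)

    # --log_dir
    os.makedirs(args.log_dir, exist_ok=True)

    # --sample_dir
    os.makedirs(args.sample_dir, exist_ok=True)

    # --epoch
    try:
        assert args.epoch >= 1
    except AssertionError:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except AssertionError:
        print('batch size must be larger than or equal to one')
    return args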

def main():
    # parse arguments
    args = parse_args()


    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = CartoonGAN(sess, args)

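        # build graph (build_model is assumed to be the graph-building method of CartoonGAN)
        gan.build_model()

        # show network architecture (show_all_variables is assumed to be provided by utils.py)
        show_all_variables()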

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")

if __name__ == '__main__':
    main()


1. For the file structure, see Section 3.

2. For cartoon-style image generation, CartoonGAN produces noticeably better results than CycleGAN and similar models.