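Implementing a Conditional GAN with TFGAN
TFGAN, the principles behind GANs, and the unconditional GAN were all covered in an earlier article, "TFGAN: A Simple, Easy-to-Use Lightweight Library for Generative Adversarial Networks". This article focuses on implementing a Conditional GAN model with TFGAN.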
Environment
- Python 3.6
- Tensorflow-gpu 1.8.0
Conditional GAN
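A classic unconditional GAN generates the data we want by sampling from a noise distribution, but it gives us no control over which class the generated data belongs to. The conditional GAN (CGAN) was designed to solve exactly this problem.
The "condition" in CGAN means that the generated output must not only look realistic but also satisfy a given condition. As the figure below shows, the Generator receives a one-hot encoding of the requested class in addition to random noise, and the Discriminator receives the same encoding alongside the image it judges. This is what lets us make the generator produce data of the class we specify.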

Figure: CGAN
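Implementation
The CGAN and the unconditional GAN (UGAN) share essentially the same network structure. The main difference is that a one-hot class encoding is added to the inputs.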
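To make the one-hot class encoding concrete, here is a minimal pure-Python sketch. The helper name one_hot is hypothetical and only for illustration; the full listing at the end of the article uses tf.one_hot for the same purpose. The list comprehension for classes mirrors the one used in the test function of that listing.

```python
def one_hot(label, num_classes=10):
    """Return a list of length num_classes with 1.0 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes)")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Ten copies of each digit class in row order: 0,0,...,0,1,1,...,9.
# With 10 images per row, each row of the output grid shows one digit.
classes = [i for i in range(10) for _ in range(10)]
conditions = [one_hot(c) for c in classes]
```

Each entry of conditions plays the role of one row of one_hot_labels, so a batch of 100 conditions asks the generator for ten samples of every digit.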
Generator
As shown below, the model's input, inputs, is a tuple (noise, one_hot_labels). The function tfgan.features.condition_tensor_from_onehot combines the two, and the result is fed into the generator.
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers


def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to build a
            conditional generator.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    # Condition the noise on the requested class.
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
    with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu,
            normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm],
                is_training=is_training,
                zero_debias_moving_mean=True):
            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as the
            # real images, i.e. [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None,
                                activation_fn=tf.tanh)
            return net
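The generator turns a 7x7 feature map into a 28x28 MNIST image. The following sketch traces the spatial sizes in plain Python. It assumes the tf.contrib.layers default of 'SAME' padding, under which a transposed convolution multiplies the size by the stride and a convolution divides it (rounding up). The helper names are hypothetical.

```python
import math


def same_conv_out(size, stride):
    """Spatial size after a 'SAME'-padded convolution."""
    return math.ceil(size / stride)


def same_deconv_out(size, stride):
    """Spatial size after a 'SAME'-padded transposed convolution."""
    return size * stride


def generator_shapes():
    """Return (layer name, (height, width, channels)) for each stage."""
    size = 7
    shapes = [("reshape", (size, size, 128))]
    for name, channels in [("conv2d_transpose_1", 64),
                           ("conv2d_transpose_2", 32)]:
        size = same_deconv_out(size, stride=2)
        shapes.append((name, (size, size, channels)))
    size = same_conv_out(size, stride=1)  # final conv keeps the size
    shapes.append(("conv2d", (size, size, 1)))
    return shapes


for layer_name, shape in generator_shapes():
    print(layer_name, shape)
```

Under these assumptions the sizes go 7 -> 14 -> 28 -> 28, which matches the 28x28x1 images that provide_data feeds to the discriminator.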
Discriminator
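As shown below, in addition to img the model also takes the corresponding one_hot_labels. The function tfgan.features.condition_tensor_from_onehot combines the flattened image features with the class encoding before the final real/fake decision.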
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers


def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                              is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    # Only the labels are used; the noise part of the tuple is ignored.
    _, one_hot_labels = conditioning
    with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu,
            normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        # Condition the flattened image features on the class label.
        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
        with tf.contrib.framework.arg_scope([layers.batch_norm],
                                            is_training=is_training):
            net = layers.fully_connected(net, 1024,
                                         normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)
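As far as I understand the TFGAN source, condition_tensor_from_onehot does not literally concatenate the label onto the tensor. It looks up a learned embedding for the class, projects it linearly to the size of the tensor, and adds it, so the output keeps the input's shape. The pure-Python sketch below mimics that idea with random weights. Treat it as an illustration under that assumption, not as the library's code: the helper names and the small sizes are hypothetical, and the bias term of the linear layer is omitted.

```python
import random


def linear(vec, weights):
    """Multiply a length-n vector by an n x m weight matrix (no bias)."""
    n_out = len(weights[0])
    return [sum(v * row[j] for v, row in zip(vec, weights))
            for j in range(n_out)]


def condition_features(features, one_hot_label, embedding, projection):
    """Add a label-dependent offset; the output has the same length."""
    embedded = linear(one_hot_label, embedding)  # picks this class's row
    offset = linear(embedded, projection)        # project to len(features)
    return [f + o for f, o in zip(features, offset)]


rng = random.Random(0)
num_classes, embedding_size, num_features = 10, 4, 6
embedding = [[rng.gauss(0, 1) for _ in range(embedding_size)]
             for _ in range(num_classes)]
projection = [[rng.gauss(0, 1) for _ in range(num_features)]
              for _ in range(embedding_size)]

features = [0.0] * num_features
label_3 = [1.0 if i == 3 else 0.0 for i in range(num_classes)]
label_7 = [1.0 if i == 7 else 0.0 for i in range(num_classes)]
conditioned_3 = condition_features(features, label_3, embedding, projection)
conditioned_7 = condition_features(features, label_7, embedding, projection)
```

The same features end up shifted differently for class 3 and class 7, which gives the later layers a signal about which class they are supposed to be looking at.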
Experiments
The CGAN model graph is shown below:

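Figure: CGAN model graph
The results are shown below. After 20,000 training steps, the model can largely generate data of the class we specify.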
Loss:

Figure: loss
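The training code in the full listing uses tfgan.losses.wasserstein_generator_loss and tfgan.losses.wasserstein_discriminator_loss. As a rough sketch of what those two terms compute (leaving out the gradient penalty that the listing also enables), the following pure-Python functions take lists of discriminator logits. The function names are hypothetical stand-ins, not the TFGAN implementation.

```python
def mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


def sketch_generator_loss(fake_logits):
    """Generator wants the discriminator to score fakes highly."""
    return -mean(fake_logits)


def sketch_discriminator_loss(real_logits, fake_logits):
    """Discriminator wants real scores high and fake scores low."""
    return mean(fake_logits) - mean(real_logits)


real_scores = [2.0, 4.0]
fake_scores = [1.0, 3.0]
g_loss = sketch_generator_loss(fake_scores)
d_loss = sketch_discriminator_loss(real_scores, fake_scores)
```

Because these losses are differences of unbounded scores rather than log-probabilities, they can be negative, which is worth keeping in mind when reading loss curves.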
Generated samples:

Epochs:6000

Epochs:10000

Epochs:20000
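I also tried a generator learning rate of 1e-4. It produced the correct classes within about 1,000 steps, but the generated images were not sharp enough, and continued training led to mode collapse. A schedule worth considering is 1e-4 for the first 1,000 steps followed by 1e-5, which may converge better.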
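The two-stage learning rate suggested above can be written as a simple piecewise-constant schedule. This is a minimal sketch of the idea only: the function name and the exact boundary handling are assumptions, and the schedule was not part of the experiments reported here. In TensorFlow 1.x, tf.train.piecewise_constant on the global step could likely express something similar as a tensor passed to the optimizer.

```python
def generator_learning_rate(step, boundary=1000, early_lr=1e-4, late_lr=1e-5):
    """Return early_lr before `boundary` steps and late_lr afterwards."""
    if step < 0:
        raise ValueError("step must be non-negative")
    return early_lr if step < boundary else late_lr


# Learning rate at a few representative steps.
schedule = {step: generator_learning_rate(step)
            for step in (0, 999, 1000, 20000)}
```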
The complete CGAN code is as follows:
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data


def float_image_to_uint8(image):
    """Convert float image in [-1, 1) to [0, 255] uint8.

    Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.

    Args:
        image: An image tensor. Values should be in [-1, 1).

    Returns:
        Input image cast to uint8 and with integer values in [0, 255].
    """
    image = (image * 128.0) + 128.0
    return tf.cast(image, tf.uint8)


def provide_data(batch_size, num_threads=1):
    file = "MNIST"
    # Pixel values are loaded in the range 0~1.
    mnist = input_data.read_data_sets(file, one_hot=True)
    train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
    train_labels = mnist.train.labels
    # Rescale to the range -1~1.
    train_data = (tf.to_float(train_data) - 128.0) / 128.0
    # Creates a QueueRunner for the pre-fetching operation.
    input_queue = tf.train.slice_input_producer([train_data, train_labels],
                                                shuffle=True)
    images, labels = tf.train.batch(
        input_queue,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size)
    return images, labels


def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to build a
            conditional generator.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
    with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu,
            normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm],
                is_training=is_training,
                zero_debias_moving_mean=True):
            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as the
            # real images, i.e. [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None,
                                activation_fn=tf.tanh)
            return net


def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                              is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    _, one_hot_labels = conditioning
    with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu,
            normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
        with tf.contrib.framework.arg_scope([layers.batch_norm],
                                            is_training=is_training):
            net = layers.fully_connected(net, 1024,
                                         normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)


def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
    tf.reset_default_graph()
    if not tf.gfile.Exists(train_log_dir):
        tf.gfile.MakeDirs(train_log_dir)

    # Set up the input.
    images, one_hot_labels = provide_data(batch_size)
    noise = tf.random_normal([batch_size, 64])

    with tf.name_scope('model'):
        # Build the generator and discriminator.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,  # you define
            discriminator_fn=conditional_discriminator,  # you define
            real_data=images,
            generator_inputs=(noise, one_hot_labels))

    with tf.name_scope('loss'):
        # Build the GAN loss.
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            gradient_penalty_weight=1.0,
            add_summaries=True)

    with tf.name_scope('train'):
        # Create the train ops, which calculate gradients and apply updates
        # to weights.
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            check_for_unused_update_ops=False,
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the train ops in the alternating training scheme.
    tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
        logdir=train_log_dir,
        save_summaries_steps=10)


def test(eval_dir, checkpoint_dir):
    tf.reset_default_graph()
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    # 100 noise vectors, ten for each of the ten digit classes.
    noises = tf.random_normal([100, 64])
    c = [i for i in range(10) for j in range(10)]
    conditions = tf.one_hot(c, 10)
    random_inputs = (noises, conditions)

    # Reuse the variable scope name that TFGAN gives the generator.
    with tf.variable_scope('Generator'):
        images = conditional_generator(random_inputs, is_training=False)

    # Arrange the samples in a 10-column grid and write it out as a PNG.
    reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
    uint8_images = float_image_to_uint8(reshaped_images)
    image_write_ops = tf.write_file(
        '%s/%s' % (eval_dir, 'conditional_gan.png'),
        tf.image.encode_png(uint8_images[0]))

    tf.contrib.training.evaluate_repeatedly(
        checkpoint_dir,
        eval_ops=image_write_ops,
        hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
        max_number_of_evaluations=1)


if __name__ == '__main__':
    train(14, 10000, 1e-5, 1e-4, 'cg_logs/')
    test('cg_eval/', 'cg_logs/')
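The listing rescales pixels twice: provide_data maps [0, 255] into the model's range, and float_image_to_uint8 maps generated images back. The pure-Python sketch below walks through that arithmetic on single values. The helper names are hypothetical, and the modulo 256 step is only an emulation of the uint8 wraparound that the docstring of float_image_to_uint8 describes for an input of exactly 1.

```python
def to_model_range(pixel):
    """Map a pixel in [0, 255] to [-1, 1), as provide_data does."""
    return (pixel - 128.0) / 128.0


def to_uint8(value):
    """Map a model-range value back to an integer pixel.

    The modulo emulates uint8 wraparound: 1.0 becomes 256, which wraps to 0.
    """
    return int(value * 128.0 + 128.0) % 256


# Every integer pixel value survives the round trip unchanged.
round_trip_ok = all(to_uint8(to_model_range(p)) == p for p in range(256))

# The largest real pixel maps to just under 1.0.
brightest = to_model_range(255)
```

Since tanh approaches but does not normally reach 1, the wraparound case should be rare for generated images, but it explains why the docstring specifies the half-open range [-1, 1).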