1. 程式人生 > >『TensorFlow』以GAN為例的神經網絡類範式

『TensorFlow』以GAN為例的神經網絡類範式

default 方法 paper ear 類屬性 lin 簡單 貪婪 base

1、導入包:

import os
import time
import math
from glob import glob
from PIL import Image
import tensorflow as tf
import numpy as np

import ops                    # 層函數封裝包
import utils                  # 其他輔助函數

2、簡單的臨時輔助函數:

def conv_out_size_same(size, stride):
    """Return ``ceil(size / stride)`` as an ``int``.

    Both arguments are converted to ``float`` before dividing, so the
    quotient is rounded up to the smallest integer not less than it.
    """
    quotient = float(size) / float(stride)
    rounded_up = math.ceil(quotient)
    return int(rounded_up)

3、聲明類&初始化類:

示例沒有使用到,實際上一般類屬性也會用到

類屬性&__init__初始化:用於接收參數生成低層次的屬性值,數據讀取或者數據名列表一般也會放在__init__中

class DCGAN():

    def __init__(self, sess,
                 input_height=108, input_width=108,
                 crop=True, batch_size=64, sample_num=64,
                 output_height=64, output_width=64,
                 z_dim=100, gf_dim=64, df_dim=64,
                 gfc_dim=1024, dfc_dim=1024, c_dim=3,
                 dataset_name='default', input_fname_pattern='*.jpg',
                 checkpoint_dir=None, sample_dir=None):
        """Store hyper-parameters, create batch-norm layers and index the dataset.

        Args:
          sess: TensorFlow session
          batch_size: The size of batch. Should be specified before training.
          z_dim: (optional) Dimension of dim for Z. [100]
          gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
          df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
          gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
          dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
          c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
            NOTE: this argument is overridden below by the channel count
            detected from the first image of the dataset.

        Raises:
          IndexError: if no file matches
            ./data/<dataset_name>/<input_fname_pattern>.
        """
        self.sess = sess

        self.batch_size = batch_size
        self.sample_num = sample_num

        # Crop / input / output sizes:
        #   crop=True  -> the network input uses the output size
        #   crop=False -> the network input uses the raw input size
        self.crop = crop
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.z_dim = z_dim

        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.dfc_dim = dfc_dim
        self.gfc_dim = gfc_dim

        # Batch-norm layers are created once here so that generator() and
        # sampler() apply the very same layer objects.
        self.g_bn0 = ops.batch_norm(name='g_bn0')
        self.g_bn1 = ops.batch_norm(name='g_bn1')
        self.g_bn2 = ops.batch_norm(name='g_bn2')
        self.g_bn3 = ops.batch_norm(name='g_bn3')

        self.d_bn1 = ops.batch_norm(name='d_bn1')
        self.d_bn2 = ops.batch_norm(name='d_bn2')
        self.d_bn3 = ops.batch_norm(name='d_bn3')

        # Dataset: only the list of matching file names is kept in memory.
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir
        self.data = glob(os.path.join('./data', self.dataset_name,
                                      self.input_fname_pattern))
        if not self.data:
            # Same exception type the bare self.data[0] lookup would raise,
            # but with a message that names the actual problem.
            raise IndexError(
                "No images found matching %s" % os.path.join(
                    './data', self.dataset_name, self.input_fname_pattern))

        # Read one image to determine the number of colour channels.
        imreadImg = np.asarray(Image.open(self.data[0]))
        if len(imreadImg.shape) >= 3:
            self.c_dim = imreadImg.shape[-1]
        else:
            self.c_dim = 1

        self.grayscale = (self.c_dim == 1)

4、網絡結構生成:

由於GAN的特殊性,被拆分了build_model(self)作為主幹,discriminator(self,image,reuse=False)和generator(self,z)作為模組,這一過程包含了由數據進入網絡到loss函數計算的整個流程

    def build_model(self):
        """Build the graph: placeholders, G/D networks, losses, summaries, saver.

        Sets (among others) ``self.input_layer``, ``self.z``, ``self.G``,
        ``self.D``/``self.D_``, ``self.d_loss``, ``self.g_loss``,
        ``self.d_vars``, ``self.g_vars`` and ``self.saver``.

        NOTE: ``self.sampler`` is rebound from the method to the tensor it
        returns, so this method can only be called once per instance.
        """
        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        # Input placeholders.
        # list.extend() returns None, so the shape must be built with `+`;
        # otherwise the placeholder would be created with shape=None.
        self.input_layer = tf.placeholder(
            tf.float32, [self.batch_size] + image_dims, name='input_layer')
        inputs = self.input_layer

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram("z", self.z)

        # Main compute nodes.
        self.G                  = self.generator(self.z)
        self.D, self.D_logits   = self.discriminator(inputs, reuse=False)
        self.sampler            = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

        # Summaries for the network outputs.
        self.G_sum = tf.summary.image("G", self.G)
        self.D_sum = tf.summary.histogram("D", self.D)
        self.D__sum = tf.summary.histogram("D_", self.D_)

        # Losses. TF 1.x only accepts keyword arguments for this op, and
        # naming them also rules out swapping labels and logits.
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake

        # Loss summaries (tf.summary is the module; tf.Summary is a protobuf).
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        # Split trainable variables by the 'd_' / 'g_' naming convention of
        # the layers so each optimizer only updates its own network.
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # Saver used by save() and load().
        self.saver = tf.train.Saver()

    def discriminator(self,image,reuse=False):
        """Build the discriminator on ``image``.

        Args:
          image: batch of images; it is reshaped to ``[self.batch_size, -1]``
            before the final linear layer, so the batch must hold exactly
            ``self.batch_size`` items.
          reuse: passed to ``tf.variable_scope``; set True for the second
            call so both calls share one set of weights.

        Returns:
          ``(sigmoid(h4), h4)``: the probabilities and the raw logits.
        """
        # Scope and layer names must be string literals; the 'd_' prefix is
        # what build_model() uses to collect the discriminator variables.
        with tf.variable_scope("discriminator", reuse=reuse):
            h0 = ops.lrelu(ops.conv2d(image, self.df_dim, name='d_h0_conv'))
            h1 = ops.lrelu(self.d_bn1(ops.conv2d(h0, self.df_dim * 2, name='d_h1_conv')))
            h2 = ops.lrelu(self.d_bn2(ops.conv2d(h1, self.df_dim * 4, name='d_h2_conv')))
            h3 = ops.lrelu(self.d_bn3(ops.conv2d(h2, self.df_dim * 8, name='d_h3_conv')))
            h4 = ops.linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')

        return tf.nn.sigmoid(h4), h4

    def generator(self,z):
        """Build the generator: one linear layer followed by four deconvolutions.

        Args:
          z: noise tensor fed to the first linear layer.

        Returns:
          ``tanh`` of the last deconvolution, with requested shape
          ``[batch_size, output_height, output_width, c_dim]``.
        """
        # Scope and layer names must be string literals; the 'g_' prefix is
        # what build_model() uses to collect the generator variables, and
        # sampler() re-enters the same "generator" scope to share them.
        with tf.variable_scope("generator"):
            # Target image size and the sizes after each halving.
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            # batch_size stays fixed; h and w double at every layer while the
            # channel count halves.

            # Linear projection, reshaped into the smallest feature map.
            self.z_, self.h0_w, self.h0_b = ops.linear(
                z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
            self.h0 = tf.reshape(self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            # Transposed-convolution stack.
            self.h1, self.h1_w, self.h1_b = ops.deconv2d(
                h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = ops.deconv2d(
                h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = ops.deconv2d(
                h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = ops.deconv2d(
                h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

        return tf.nn.tanh(h4)

5、預測部分:

一般網絡用於predict標簽的部分,對應到GAN就是生成仿真圖片的位置,這裏是不參與訓練的

    def sampler(self,z):
        """Build an inference-mode copy of the generator.

        Same structure as ``generator`` and the same variables (the
        "generator" scope is re-entered with reuse enabled); the only
        difference is that every batch-norm layer is called with
        ``train=False``.

        Args:
          z: noise tensor.

        Returns:
          ``tanh`` of the last deconvolution, mirroring ``generator``.
        """
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            h0 = tf.reshape(
                ops.linear(z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin'),
                [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(h0, train=False))

            h1 = ops.deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1')
            h1 = tf.nn.relu(self.g_bn1(h1, train=False))

            h2 = ops.deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2')
            h2 = tf.nn.relu(self.g_bn2(h2, train=False))

            h3 = ops.deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3')
            h3 = tf.nn.relu(self.g_bn3(h3, train=False))

            h4 = ops.deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

        # Without this return build_model() would store None in self.sampler
        # and the periodic sampling in train() could never produce images.
        return tf.nn.tanh(h4)

6、訓練部分:

超級麻煩的部分,

  • 構建優化器
  • 載入上次訓練的結果
  • 叠代訓練
    • 讀取batch_size數據
    • feed進網絡訓練
    • 輸出中間參量輔助查看
    • 保存模型
    def train(self,config):
        """Run the adversarial training loop.

        Steps: build the two Adam optimizers, initialise variables, try to
        resume from ``self.checkpoint_dir``, then for each epoch/batch
        update D once and G twice, log summaries, periodically save sample
        images and checkpoints.

        Args:
          config: object exposing ``learning_rate``, ``beta1``, ``epoch``,
            ``dataset``, ``train_size``, ``batch_size``, ``sample_dir`` and
            ``checkpoint_dir`` as attributes (e.g. the parsed FLAGS).
        """
        # Discriminator optimizer (updates only the d_* variables).
        d_optim = (tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1)
                   .minimize(self.d_loss, var_list=self.d_vars))
        # Generator optimizer (updates only the g_* variables).
        g_optim = (tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1)
                   .minimize(self.g_loss, var_list=self.g_vars))

        # Use the session owned by the model rather than whatever happens to
        # be the default session, consistent with the sess.run calls below.
        self.sess.run(tf.global_variables_initializer())

        # Merged summaries. The discriminator histogram is self.D_sum; the
        # merged op must not reference itself.
        self.g_sum = tf.summary.merge([self.z_sum, self.D__sum, self.G_sum,
                                       self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.z_sum, self.D_sum,
                                       self.d_loss_real_sum, self.d_loss_sum])

        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        # Fixed set of sample_num images and noise vectors used for the
        # periodic sample dumps.
        sample_files = self.data[0:self.sample_num]
        sample = [utils.get_image(sample_file,
                                  input_height=self.input_height,
                                  input_width=self.input_width,
                                  resize_height=self.output_height,
                                  resize_width=self.output_width,
                                  crop=self.crop) for sample_file in sample_files]
        sample_inputs = np.array(sample).astype(np.float32)
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)

        # Resume the step counter from the checkpoint when one was restored.
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            self.data = glob(os.path.join(
                "./data", config.dataset, self.input_fname_pattern))
            batch_idxs = min(len(self.data), config.train_size) // config.batch_size
            for idx in range(0, batch_idxs):

                # Load one batch of real images.
                batch_files = self.data[idx * config.batch_size:(idx + 1) * config.batch_size]
                batch = [
                    utils.get_image(batch_file,
                                    input_height=self.input_height,
                                    input_width=self.input_width,
                                    resize_height=self.output_height,
                                    resize_width=self.output_width,
                                    crop=self.crop) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)

                # Sample the noise batch z.
                batch_z = (np.random.uniform(-1, 1, [config.batch_size, self.z_dim])
                           .astype(np.float32))

                # Update D network
                _, summary_str = self.sess.run(
                    [d_optim, self.d_sum],
                    feed_dict={self.input_layer: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run(
                    [g_optim, self.g_sum],
                    feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Update G network
                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _, summary_str = self.sess.run(
                    [g_optim, self.g_sum],
                    feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Evaluate the current loss values for the progress line.
                errD_fake = self.sess.run(self.d_loss_fake, feed_dict={self.z: batch_z})
                errD_real = self.sess.run(self.d_loss_real,
                                          feed_dict={self.input_layer: batch_images})
                errG = self.sess.run(self.g_loss, feed_dict={self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))

                # Every 100 steps dump a grid of generated samples.
                if np.mod(counter, 100) == 1:
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.input_layer: sample_inputs,
                            },
                        )
                        utils.save_images(
                            samples, utils.image_manifold_size(samples.shape[0]),
                            './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
                    except Exception as err:
                        # Sampling stays best-effort, but Ctrl-C is no longer
                        # swallowed and the actual failure is reported.
                        print("one pic error!... %s" % err)

                # Every 500 steps write a checkpoint.
                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)

保存&載入模型的一個demo

個人感覺功能有點臃腫,不過還是很值得借鑒的,

比如使用裝飾器把函數隱藏成屬性這個我就感覺很沒必要,畢竟都是自家內部調用... ...

檢查文件夾時的固定搭配這個就很不錯:

# Create checkpoint_dir (including missing parents) only when it does not exist yet.
if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

作者為了跑不同的數據集在文件名歸類上下了一番功夫,所以load模塊比較復雜,所以適當的多給了一些註釋

    '''模型保存&載入'''

    # checkpoint_dir/datasetname_batchsize_outputheight_outputwidth/模型
    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.dataset_name,self.batch_size,
            self.output_height,self.output_width)

    def save(self,checkpoint_dir,step):
        """Write a checkpoint of the session under ``checkpoint_dir/model_dir``.

        The target directory is created when missing; ``step`` is passed to
        the saver as ``global_step``.
        """
        target_dir = os.path.join(checkpoint_dir, self.model_dir)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        checkpoint_prefix = os.path.join(target_dir, "DCGAN.model")
        self.saver.save(self.sess, checkpoint_prefix, global_step=step)

    def load(self,checkpoint_dir):
        """Restore the latest checkpoint found under ``checkpoint_dir/model_dir``.

        Args:
          checkpoint_dir: root checkpoint directory; ``self.model_dir`` is
            appended to it.

        Returns:
          ``(True, counter)`` when a checkpoint was restored, where
          ``counter`` is the first run of digits in the checkpoint file name
          (0 when the name has no digits); ``(False, 0)`` otherwise.
        """
        import re
        print(" [*] Reading checkpoints...")
        # Join the checkpoint root with the per-configuration sub-directory.
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
        # Checkpoint state of that directory (points at the latest model file).
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return False, 0

        # Strip any directory part so the path is always rebuilt relative to
        # checkpoint_dir.
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # Raw string: "\d" in a plain literal is an invalid escape sequence.
        # re.search also avoids StopIteration on a name without digits.
        step_match = re.search(r"(\d+)", ckpt_name)
        counter = int(step_match.group(0)) if step_match else 0
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter

附:腳本調用

import os
import pprint
import numpy as np
import tensorflow as tf

from model import DCGAN


# 接收命令行參數分三步

# Command-line flags: get the flags module, declare each flag, then read the
# parsed values through FLAGS.
flags = tf.app.flags

flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
# NOTE(review): np.inf is a float used as the default of an integer flag;
# confirm the installed flags implementation accepts it.
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
# The help text used to be a copy of the one for --train; describe what the
# flag actually controls (the model sizes its input from it).
flags.DEFINE_boolean("crop", False, "True to crop inputs to the output size, False to feed input-sized images directly [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")

FLAGS = flags.FLAGS


# 必須帶參數,否則:‘TypeError: main() takes no arguments (1 given)‘;
# main的參數名隨意定義,無要求
def main(_):
    """Script entry point called by ``tf.app.run()``.

    Prints the parsed flags, fills in the width defaults, creates the
    output directories, then builds the model and either trains it or
    checks that a trained checkpoint can be loaded.

    Raises:
      Exception: in test mode when no checkpoint can be loaded.
    """
    # pprint gives a more readable dump of the flag values.
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of TensorFlow's default of
    # claiming nearly all of it up front.
    run_config.gpu_options.allow_growth = True
    # Alternative: cap the process at a fixed fraction of GPU memory.
    # run_config.gpu_options.per_process_gpu_memory_fraction=0.333

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

        # __init__ only stores attributes; the graph (losses, saver, ...)
        # that train() and load() rely on is created by build_model().
        dcgan.build_model()

        # Training / loading must happen while the session is still open,
        # i.e. inside the `with` block.
        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

# Run only when executed as a script; tf.app.run() parses the flags and then
# calls main(). '__main__' must be a string literal, not a bare name.
if __name__ == '__main__':
    tf.app.run()

預測部分沒寫好,所以沒加上來,但是這不妨礙理解思路

值得一提的是dcgan.train(FLAGS),這裏直接傳入了FLAGS,對應內部train函數接收參數config,{config.參數名}這樣的調用方法十分方便,這也有助於理解腳本化TF程序的便利之處(參見『TensorFlow』腳本化使用方法)。

『TensorFlow』以GAN為例的神經網絡類範式