1. 程式人生 > >DCGANs原始碼解析(二)

DCGANs原始碼解析(二)

model.py

DCGANs大部分都在一個叫做 DCGAN 的 Python 類(class)中(model.py)。像這樣把所有東西都放在一個類中非常有用,因為訓練後中間狀態可以被儲存起來,以便後面使用。

首先讓我們定義生成器和鑑別器(上一篇已經介紹過了)。
linear, conv2d_transpose, conv2d, 和 lrelu 函式都是在 ops.py 中定義的。

1.初始化DCGAN類

我們初始化DCGAN類時,就用generator和discriminator這些函式創造了模型。

我們需要兩種版本的鑑別器,他們共享同樣的引數。一個用於來自真實資料分佈的小批影象,另一個用於來自生成器的小批影象。下面self.D等是來自真實圖片資料的判別器,self.D_等是來自生成器圖片的判別器。

self.G = self.generator(self.z)
self.D, self.D_logits = self.discriminator(self.images)
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

2.定義損失函式

接著,我們將定義損失函式。在這裡不用求和(sums),我們用D的預測和我想讓它更好地工作而對它的期望之間的交叉熵( cross entropy (https://en.wikipedia.org/wiki/Cross_entropy))。

鑑別器想讓來自真實資料的預測都為1,而來自生成器的假造資料都為0。生成器想讓鑑別器的所有預測都為1.下面是根據這個預期定義的損失函式

#d_loss_real是真實圖片輸入到判別器中的結果和預期的為1的結果之間的交叉熵
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits,tf.ones_like(self.D)))  
#d_loss_fake是生成器生成的圖片輸入到判別器中的結果和預期為0的結果之間的交叉熵
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_,tf.zeros_like(self.D_)))
#判別器的損失函式d_loss是d_loss_fake和d_loss_real之和
self.d_loss = self.d_loss_real + self.d_loss_fake

#生成器的損失函式d_loss是生成器生成的圖片輸入到判別器中的結果和預期為1的結果之間的交叉熵
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D_logits_,tf.ones_like(self.D_)))

3.收集變數

分別從每個模型中收集變數,讓它們可以被分開訓練。

t_vars = tf.trainable_variables()
self.d_vars = [var for var in t_vars if 'd_' in var.name]
self.g_vars = [var for var in t_vars if 'g_' in var.name]

4.定義優化器

現在我們準備好優化引數了,我們要用的是 ADAM (https://arxiv.org/abs/1412.6980),這是一種適應的非凸優化方法,通常用於現代深度學習中。ADAM 經常會與 SGD 競爭,而且通常不需要手動調節學習速率,動量,及其他超引數(hyper-parameter)。

d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                  .minimize(self.d_loss, var_list=self.d_vars)
g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                  .minimize(self.g_loss, var_list=self.g_vars)

這裡優化器選擇ADAM ,最終目標是要最小化d_loss和g_loss。

5.訓練

我們準備好遍歷資料了。在每一個時期,我們在一個小批圖片中取樣,執行優化器升級網路。有趣的是,如果 G 只更新了一次,鑑別器的損耗就不會為零。而且,我認為最後對 d_loss_fake 和 d_loss_real 函式的額外呼叫引發了一點不必要的計算,而且是多餘的,因為這些值已經作為 d_optim 和 g_optim 的一部分計算過了。作為 TensorFlow 中的一項練習,你可以試著用這個部分去優化,並給原始 repo 傳送一個 PR 。

for epoch in xrange(config.epoch):
    ...
    for idx in xrange(0, batch_idxs):
        batch_images = ...
        batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]).astype(np.float32)

        # Update D network
        #更新一個 D 網路
        _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.images: batch_images, self.z: batch_z })

        # Update G network
        #更新一個 G 網路
        _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })

        # Run g_optim twice to make sure that d_loss does not go to zero*
        # (different from paper)
       #執行兩次*g_optim 以確保 d_loss 不會變成0
       #(與論文裡不一樣)
        _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })

        errD_fake = self.d_loss_fake.eval({self.z: batch_z})
        errD_real = self.d_loss_real.eval({self.images: batch_images})
        errG = self.g_loss.eval({self.z: batch_z})

這裡 self.sess.run()函式是執行一個會話,第一個引數是圖的輸出節點,第二個引數圖的輸入節點。如

self.sess.run([d_optim, self.d_sum], feed_dict={ self.images: batch_images, self.z: batch_z }),

上面的會話會根據輸出節點d_optim, self.d_sum在圖中找到最初的輸入節點。

d_optim———>d_loss——->D_logits, D_logits_。

其中D_logits的輸入是self.images, D_logits_的輸入是self.z。因此這裡run的第二個引數應該為{self.images,self.z}。

但是self.images,self.z只是個用placeholder定義的佔位符,因此需要指定實際的輸入。所以,這裡用feed_dict指定了個字典,key值為self.images的佔位符對應的值為batch_images,即載入的真實圖片資料。key值為self.z的佔位符對應的值為batch_z,即噪音資料。

這裡看一下self.images,self.z的定義,均是用placeholder生成的佔位符。

self.images = tf.placeholder(tf.float32, [self.batch_size] + [self.output_size, self.output_size, self.c_dim],
                                name='real_images')

self.z = tf.placeholder(tf.float32, [None, self.z_dim],name='z') 

介紹tensorflow

張量(Tensor)

名字就是TensorFlow,直觀來看,就是張量的流動。張量(tensor),即任意維度的資料,一維、二維、三維、四維等資料統稱為張量。而張量的流動則是指保持計算節點不變,讓資料進行流動。

這樣的設計是針對連線式的機器學習演算法,比如邏輯斯底迴歸,神經網路等。連線式的機器學習演算法可以把演算法表達成一張圖,張量在圖中從前到後走一遍就完成了前向運算;而殘差從後往前走一遍,就完成了後向傳播。

運算元(operation)

在TF的實現中,機器學習演算法被表達成圖,圖中的節點是運算元(operation),節點會有0到多個輸出,下圖是TF實現的一些運算元。

每個運算元都會有屬性,所有的屬性都在建立圖的時候被確定下來,比如,最常用的屬性是為了支援多型,比如加法運算元既能支援float32,又能支援int32計算。

邊(edge)

TF的圖中的邊分為兩種:

正常邊,正常邊上可以流動資料,即正常邊就是tensor

特殊邊,又稱作控制依賴,(control dependencies)

  1. 沒有資料從特殊邊上流動,但是特殊邊卻可以控制節點之間的依賴關係,在特殊邊的起始節點完成運算之前,特殊邊的結束節點不會被執行。
  2. 也不僅僅非得有依賴關係才可以用特殊邊,還可以有其他用法,比如為了控制記憶體的時候,可以讓兩個實際上並沒有前後依賴關係的運算分開執行。
  3. 特殊邊可以在client端被直接使用

會話(Session)

客戶端使用會話來和TF系統互動,一般的模式是,建立會話,此時會生成一張空圖;在會話中新增節點和邊,形成一張圖,然後執行。

下圖有一個TF的會話樣例和所對應的圖示。

這裡寫圖片描述

這裡寫圖片描述
變數(Variables)

機器學習演算法都會有引數,而引數的狀態是需要儲存的。而引數是在圖中有其固定的位置的,不能像普通資料那樣正常流動。因而,TF中將Variables實現為一個特殊的運算元,該運算元會返回它所儲存的可變tensor的控制代碼。

相關推薦

DCGANs原始碼解析

model.py DCGANs大部分都在一個叫做 DCGAN 的 Python 類(class)中(model.py)。像這樣把所有東西都放在一個類中非常有用,因為訓練後中間狀態可以被儲存起來,以便後面使用。 首先讓我們定義生成器和鑑別器(上一篇已經介紹過了

Spring原始碼解析——元件註冊2

    import com.ken.service.BookService; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.

認真的 Netty 原始碼解析

Channel 的 register 操作 經過前面的鋪墊,我們已經具備一定的基礎了,我們開始來把前面學到的內容揉在一起。這節,我們會介紹 register 操作,這一步其實是非常關鍵的,對於我們原始碼分析非常重要。 register 我們從 EchoClient 中的 connect() 方法出發,或者 E

jquery 1.7.2原始碼解析構造jquery物件

構造jquery物件 jQuery物件是一個類陣列物件。 一)建構函式jQuery() 建構函式的7種用法:   1.jQuery(selector [, context ]) 傳入字串引數:檢查該字串是選擇器表示式還是HTML程式碼。如果是選擇器表示式,則遍歷文件查詢匹配的DOM元

java集合原始碼解析--AbstractCollection

今天帶來的是java單列頂層介面的第一個輕量級實現:AbstractCollection 我們直接進入正題,先來看看它的宣告: package java.util; //可以從名字上同樣看到 AbstractCollection 是一個抽象類,所以並不能例項化, //這個類只是作

EventBus原始碼解析—釋出事件和登出流程

1.EventBus原始碼解析(一)—訂閱過程 2.EventBus原始碼解析(二)—釋出事件和登出流程 前言 上一篇部落格已經比較詳細的講解了EventBus的註冊過程,有了上一篇部落格的基礎,其實關於EventBus的原始碼中的其他流程就非常好理解了,尤其是我

Spring原始碼解析:obtainFreshBeanFactory

spring的ApplicationContext容器的初始化流程主要由AbstractApplicationContext類中的refresh方法實現。 而refresh()方法中獲取新工廠的主要是由obtainFreshBeanFactory()實現的,後續的操作均是beanFactoty的進一步處理。

Redis5.0原始碼解析----------連結串列

基於Redis5.0 連結串列提供了高效的節點重排能力, 以及順序性的節點訪問方式, 並且可以通過增刪節點來靈活地調整連結串列的長度 每個連結串列節點使用一個 adlist.h/listNode 結構來表示: //adlist.h - A generic do

ThreadPoolExecutor原始碼解析

1.ThreadPoolExcuter執行例項 首先我們先看如何新建一個ThreadPoolExecutor去執行執行緒。然後深入到原始碼中去看ThreadPoolExecutor裡面使如何運作的。 public class Test { public

redis原始碼解析動態字串sds基本功能函式

1. 簡介   本文繼上文基礎上,分析動態字串的功能函式,位於sds.c。由於函式較多,本篇介紹實現動態變化的基本增刪新建釋放函式。 2. 原始碼分析   sdsHdrSize()函式用於返回sdshdr的大小,主要使用sizeof()函式實現。 /*返回sds

OKHttp 3.10原始碼解析:攔截器鏈

本篇文章我們主要來講解OKhttp的攔截器鏈,攔截器是OKhttp的主要特色之一,通過攔截器鏈,我們可以對request或response資料進行相關處理,我們也可以自定義攔截器interceptor。 上一篇文章中我們講到,不管是OKhttp的同步請求還是非同步請求,都會呼叫RealCal

OkHttp原始碼解析

上一篇講到OkHttp的整體流程,但是裡面有一個很重要的方法getResponseWithInterceptorChain還沒有講到,這個方法裡面包括了整個okHttp最重要的攔截器鏈,所以我們今天來講解一下。 Response getResponseWithI

Java容器——HashMapJava8原始碼解析

在前文中介紹了HashMap中的重要元素,現在萬事俱備,需要刨根問底看看實現了。HashMap的增刪改查,都離不開元素查詢。查詢分兩部分,一是確定元素在table中的下標,二是確定以指定下標元素為首的具體位置。可以抽象理解為二維陣列,第一個通過雜湊函式得來,第二個下標則是連結串列或紅黑樹來得到,下面

Yolov2原始碼解析

一、資料集製作 首先是從官網上下載VOC2012資料集,這裡我個人得到是訓練集檔案:VOCtrainval_11-May-2012,為了減輕訓練開銷,我將驗證集作為測試集,通過將Main資料夾下的val.txt改名為test.txt檔案,將資料集製作成hdf5檔案的形式。 import os

RxJava2 原始碼解析

概述 知道源頭(Observable)是如何將資料傳送出去的。 知道終點(Observer)是如何接收到資料的。 何時將源頭和終點關聯起來的 知道執行緒排程是怎麼實現的 知道操作符是怎麼實現的 本篇計劃講解一下4,5. RxJava最強大的莫過

Spark2.3.2原始碼解析: 8. RDD 原始碼解析 textFile 返回的RDD例項是什麼

  本文主要目標是分析RDD的例項物件,到底放了什麼。 從程式碼val textFile = sc.textFile(args(0)) 開始: 直接看textFile 原始碼: 你會發現呼叫的是hadoop的api,通過 hadoopFile 讀取資料,返回一個hadoop

mybatis通用mapper原始碼解析

1.javabean的屬性值生成sql /** * 獲取所有查詢列,如id,name,code... * * @param entityClass * @return */ public static String getAllColumns(C

python原始碼解析

 一:PyObject  首先,先來看PyObject在object.h中的定義。typedef struct _object { _PyObject_HEAD_EXTRA Py_ssize_t ob_refcnt; struct _typeobjec

antd原始碼解析button控制元件的解析

第一節我們看了antd的button原始碼,現在我們用class的常用寫法改造下: import React,{ Component } from "React"; var _classnames2 = require('classnames');

【Servicemesh系列】【Envoy原始碼解析】一個Http請求到響應的全鏈路

目錄 1. http連線建立 當有新連線過來的時候,會呼叫上一章節所提及的被註冊到libevent裡面的回撥函式。我們回顧一下,上一章節提及了,會有多個worker註冊所有的listener,當有一個連線過來的時候,系統核心會排程一個執行緒出來交付