1. 程式人生 > >手把手教你用GAN實現半監督學習

手把手教你用GAN實現半監督學習

引言

本文主要介紹如何在tensorflow上僅使用200個帶標籤的mnist影象,實現在一萬張測試圖片上99%的測試精度,原理在於使用GAN做半監督學習。前文主要介紹一些原理部分,後文詳細介紹程式碼及其實現原理。前文介紹比較簡單,有基礎的同學請掠過直接看第二部分,文章末尾給出了程式碼GitHub連結。對GAN不瞭解的同學可以檢視微信公眾號:機器學習演算法全棧工程師 的GAN入門文章。

本部落格中的程式碼最終以GitHub中的程式碼為準,GitHub連結在文章底部,另外,本文已投稿至微信公眾號:機器學習演算法全棧工程師,歡迎關注此公眾號

1.監督,無監督,半監督學習介紹

在正式介紹實現半監督學習之前,我在這裡首先介紹一下監督學習(supervised learning),半監督學習(semi-supervised learning)和無監督學習(unsupervised learning)的區別。監督學習是指在訓練集中包含訓練資料的標籤(label),比如類別標籤,位置標籤等等。最普遍使用標籤學習的是分類任務,對於分類任務,輸入給網路訓練樣本(samples)的一些特徵(feature)以及此樣本對應的標籤(label),通過神經網路擬合的方法,神經網路可以在特徵和標籤之間找到一個合適的對映關係(mapping),這樣當訓練完成後,輸入給網路沒有label的樣本,神經網路可以通過這一個對映關係猜出它屬於哪一類。典型機器學習的監督學習的例子是KNN和SVM。目前機器視覺領域的急速發展離不開監督學習。
而無監督學習的訓練事先沒有訓練標籤,直接輸入給演算法一些資料,演算法會努力學習資料的共同點,尋找樣本之間的規律性。無監督學習是很典型的學習,人的學習有時候就是基於無監督的,比如我並不懂音樂,但是我聽了上百首歌曲後,我可以根據我聽的結果將音樂分為搖滾樂(記為0類)、民謠(記為1類)、純音樂(記為2類)等等,事實上,我並不知道具體是哪一類,所以將它們記為0,1,2三類。典型的無監督學習方法是聚類演算法,比如k-means。
東方快車電影裡面大偵探有過一個臺詞,人們的話只有對與錯,沒有中間地帶,最後經過一系列事件後他找到了對與錯之間的betweeness。在監督學習和無監督學習之間,同樣存在著中間地帶-半監督學習。半監督學習簡單來說就是將無監督學習和監督學習相結合,一部分包含了監督學習一部分包含了無監督學習,比如給一個分類任務,此分類任務的訓練集中有精確標籤的資料非常少,但是包含了大量的沒有標註的資料,如果直接用監督學習的方法去做的話,效果不一定很好,有標註的訓練資料太少很容易導致過擬合,而且大量的無標註的資料都沒有充分的利用,最常見的例子是在醫學影象的分析檢測任務中,醫學影象本身就不容易獲得,要獲得精標註的影象就需要有經驗的醫生去一個一個標註,顯然他們並沒有那麼多的時間。這時候就是半監督學習的用武之地了,半監督學習很適合用在標籤資料少,訓練資料又比較多的情況。
常見的半監督學習方法主要有:
1.Self training
2.Generative model
3.S3VMs
4.Graph-Based AIgorithems
5.Multiview AIgorithems
接下來我會結合Improved Techniques for Training GANs這篇論文詳細介紹如何使用目前最火的生成模型GAN去實現半監督學習,也即是半監督學習的第二種方法,並給出詳細的程式碼解釋,對理論不是很熟悉的同學可以直接看程式碼。另外註明:我只復現了論文半監督學習的部分,之前也有人復現了此部分,但是我感覺他對原文有很大的曲解,他使用了所有的標籤去幫助生成,並不在分類上,不太符合半監督學習的本質,而且程式碼很複雜,感興趣的可以去GitHub上搜ssgan,希望能幫助你。

2. Improved Techniques for Training GANs

GAN是無監督學習的代表,它可以不斷學習模擬資料的分佈進而生成和訓練資料相似分佈的樣本,在訓練過程不需要標籤,GAN在無監督學習領域,生成領域,半監督學習領域以及強化學習領域都有廣泛的應用。但是GAN存在很多的訓練不穩定等等的問題,作者good fellow在2016年放出了Improved Techniques for Training GANs,對GAN訓練不穩定的問題做了一些解釋和經驗上的解決方案,並給出了和半監督學習結合的方法。
從平衡點角度解釋GAN的不穩定性來說,GAN的納什均衡點是一個鞍點,並不是一個區域性最小值點,基於梯度的方法主要是尋找高維空間中的極小值點,因此使用梯度訓練的方法很難使GAN收斂到平衡點。為此,為了一部分緩解這個問題,goodfellow聯合提出了一些改進方案,
主要有:
Feature matching,
Minibatch discrimination
weight Historical averaging (相當於一個正則化的方式)
One-sided label smoothing
Virtual batch normalization
後來發現Feature matching在半監督學習上表現良好,mini-batch discrimination表現很差。

3. semi-supervised GAN

對於一個普通的分類器來說,假設對MNIST分類,一共有10類資料,分別是0-9,分類器模型以資料x作為輸入,輸出一個K=10維的向量,經過soft max後計算出分類概率最大的那個類別。在監督學習領域,往往是通過最小化類別標籤 和預測分佈 的交叉熵來實現最好的結果。
但是將GAN用在半監督學習領域的時候需要做一些改變,生成器不做改變,仍然負責從輸入噪聲資料中生成影象,判別器D不在是一個簡單的真假分類(二分類)器,假設輸入資料有K類,D就是K+1的分類器,多出的那一類是判別輸入是否是生成器G生成的影象。網路的流程圖見下圖:
GAN半監督學習流程圖
網路結構確定了之後就是損失函式的設計部分,藉助GAN我們就可以從無標籤資料中學習,只要知道輸入資料是真實資料,那就可以通過最大化 這裡寫圖片描述

來實現,上述式子可解釋為不管輸入的是哪一類真的圖片(不是生成器G生成的假圖片),只要最大化輸出它是真影象的概率就可以了,不需要具體分出是哪一類。由於GAN的生成器的參與,訓練資料中有一半都是生成的假資料。
下面給出判別器D的損失函式設計,D損失函式包括兩個部分,一個是監督學習損失,一個是半監督學習損失,具體公式如下:
這裡寫圖片描述
其中
這裡寫圖片描述

對於無監督學習來說,只需要輸出真假就可以了,不需要確定是哪一類,因此我們令
這裡寫圖片描述

其中 Pmodel 表示判別是假影象的概率,那麼D(x)就代表了輸出是真影象的概率,那麼無監督學習的損失函式就可以表示為
這裡寫圖片描述

這不就是GAN的損失函式嘛!好了,到這裡得出結論,在半監督學習中,判別器的分類要多分一類,多出的這一類表示的是生成器生成的假影象這一類,另外判別器的損失函式不僅包括了監督損失而且還有無監督的損失函式,在訓練過程中同時最小化這兩者。損失函式介紹完畢,接下來介紹程式碼實現部分。

4.程式碼實現及解讀

注:完整程式碼的GitHub連線在文章底部。這裡只擷取關鍵部分做介紹
在程式碼中,我使用feature matching,one side label smoothing方式,並沒有使用論文中介紹的Historical averaging,而是隻對判別器D使用了簡單的l2正則化,防止過擬合,另外論文中介紹的Minibatch discrimination, Virtual batch normalization等等都沒有使用,主要是這兩者在半監督學習中表現不是很好,但是如果想獲得好的生成結果還是很有用的。
首先介紹網路結構部分,因為是在mnist資料集比較簡單,所以隨便搭了一個判別器和生成器,具體如下:

def discriminator(self, name, inputs, reuse):
        l = tf.shape(inputs)[0]
        inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
        with tf.variable_scope(name,reuse=reuse):
            out = []
            output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
            output1 = lrelu(self.bn('d_bn1',output))
            out.append(output1)
            # output1 = tf.contrib.keras.layers.GaussianNoise
            output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
            output2 = lrelu(self.bn('d_bn2', output))
            out.append(output2)
            output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
            output3 = lrelu(self.bn('d_bn3', output))
            out.append(output3)
            output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
            output4 = lrelu(self.bn('d_bn4', output))
            out.append(output4)
            output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
            output = fc('d_fc', output, self.num_class)
            # output = tf.nn.softmax(output)
            return output, out

其中conv2d()是卷積操作,引數依次是,層的名字,輸入tensor,卷積核大小,輸出通道數,步長,padding。判別器中每一層都加了歸一化層,這裡使用最簡單的歸一化,函式如下所示,另外每一層的啟用函式使用leakrelu。判別器D最終返回兩個值,第一個是計算的logits,另外一個是一個列表,列表的每一個元素代表判別器每一層的輸出,為接下來實現feature matching做準備。

生成器結構如下所示:其最後一層啟用函式使用tanh

    def generator(self,name, noise, reuse):
        with tf.variable_scope(name,reuse=reuse):
            l = self.batch_size
            output = fc('g_dc', noise, 2*2*64)
            output = tf.reshape(output, [-1, 2, 2, 64])
            output = tf.nn.relu(self.bn('g_bn1',output))
            output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
            output = tf.nn.relu(self.bn('g_bn2',output))

            output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
            output = tf.nn.relu(self.bn('g_bn3', output))

            output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
            output = tf.nn.relu(self.bn('g_bn4', output))

            output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
            output = tf.image.resize_images(output, (28, 28))
            # output = tf.nn.relu(self.bn('g_bn4', output))
            return tf.nn.tanh(output)

網路結構是根據DCGAN的結構改的,所以網路簡要介紹到這裡。

接下來介紹網路初始化方面:
首先在train.py裡建立一個Train的類,並做一些初始化

#coding:utf-8
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
import scipy.misc as scm
from vlib.layers import *
import tensorflow as tf
import numpy as np
from vlib.load_data import *
import os
import vlib.plot as plot
import vlib.my_extract as dataload
import vlib.save_images as save_img
import time
from tensorflow.examples.tutorials.mnist import input_data #as mnist_data
mnist = input_data.read_data_sets('data/', one_hot=True)
# temp = 0.89
class Train(object):
    def __init__(self, sess, args):
        #sess=tf.Session()
        self.sess = sess
        self.img_size = 28   # the size of image
        self.trainable = True
        self.batch_size = 50  # must be even number
        self.lr = 0.0002
        self.mm = 0.5      # momentum term for adam
        self.z_dim = 128   # the dimension of noise z
        self.EPOCH = 50    # the number of max epoch
        self.LAMBDA = 0.1  # parameter of WGAN-GP
        self.model = args.model  # 'DCGAN' or 'WGAN'
        self.dim = 1       # RGB is different with gray pic
        self.num_class = 11
        self.load_model = args.load_model
        self.build_model()  # initializer

    def build_model(self):
        # build  placeholders
        self.x=tf.placeholder(tf.float32,shape=[self.batch_size,self.img_size*self.img_size*self.dim],name='real_img')
        self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim], name='noise')
        self.label = tf.placeholder(tf.float32, shape=[self.batch_size, self.num_class - 1], name='label')
        self.flag = tf.placeholder(tf.float32, shape=[], name='flag')
        self.flag2 = tf.placeholder(tf.float32, shape=[], name='flag2')

        # define the network
        self.G_img = self.generator('gen', self.z, reuse=False)
        d_logits_r, layer_out_r = self.discriminator('dis', self.x, reuse=False)
        d_logits_f, layer_out_f = self.discriminator('dis', self.G_img, reuse=True)

        d_regular = tf.add_n(tf.get_collection('regularizer', 'dis'), 'loss')  # D regular loss
        # caculate the unsupervised loss
        un_label_r = tf.concat([tf.ones_like(self.label), tf.zeros(shape=(self.batch_size, 1))], axis=1)
        un_label_f = tf.concat([tf.zeros_like(self.label), tf.ones(shape=(self.batch_size, 1))], axis=1)
        logits_r, logits_f = tf.nn.softmax(d_logits_r), tf.nn.softmax(d_logits_f)
        d_loss_r = -tf.log(tf.reduce_sum(logits_r[:, :-1])/tf.reduce_sum(logits_r[:,:]))
        d_loss_f = -tf.log(tf.reduce_sum(logits_f[:, -1])/tf.reduce_sum(logits_f[:,:]))
        # d_loss_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_r*0.9, logits=d_logits_r))
        # d_loss_f = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_f*0.9, logits=d_logits_f))
        # feature match
        f_match = tf.constant(0., dtype=tf.float32)
        for i in range(4):
            f_match += tf.reduce_mean(tf.multiply(layer_out_f[i]-layer_out_r[i], layer_out_f[i]-layer_out_r[i]))

        # caculate the supervised loss
        s_label = tf.concat([self.label, tf.zeros(shape=(self.batch_size,1))], axis=1)
        s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label*0.9, logits=d_logits_r))
        s_l_f = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_f*0.9, logits=d_logits_f))  # same as d_loss_f
        self.d_l_1, self.d_l_2 = d_loss_r + d_loss_f, s_l_r
        self.d_loss = d_loss_r + d_loss_f + s_l_r*self.flag*10 + d_regular
        self.g_loss = d_loss_f + 0.01*f_match

        all_vars = tf.global_variables()
        g_vars = [v for v in all_vars if 'gen' in v.name]
        d_vars = [v for v in all_vars if 'dis' in v.name]
        for v in all_vars:
            print v
        if self.model == 'DCGAN':
            self.opt_d = tf.train.AdamOptimizer(self.lr, beta1=self.mm).minimize(self.d_loss, var_list=d_vars)
            self.opt_g = tf.train.AdamOptimizer(self.lr, beta1=self.mm).minimize(self.g_loss, var_list=g_vars)
        elif self.model == 'WGAN_GP':
            self.opt_d = tf.train.AdamOptimizer(1e-5, beta1=0.5, beta2=0.9).minimize(self.d_loss, var_list=d_vars)
            self.opt_g = tf.train.AdamOptimizer(1e-5, beta1=0.5, beta2=0.9).minimize(self.g_loss, var_list=g_vars)
        else:
            print ('model can only be "DCGAN","WGAN_GP" !')
            return
        # test
        test_logits, _ = self.discriminator('dis', self.x, reuse=True)
        test_logits = tf.nn.softmax(test_logits)
        temp = tf.reshape(test_logits[:, -1],shape=[self.batch_size, 1])
        for i in range(10):
            temp = tf.concat([temp, tf.reshape(test_logits[:, -1],shape=[self.batch_size, 1])], axis=1)
        test_logits -= temp
        self.prediction = tf.nn.in_top_k(test_logits, tf.argmax(s_label, axis=1), 1)

        self.saver = tf.train.Saver()
        if not self.load_model:
            init = tf.global_variables_initializer()
            self.sess.run(init)
        elif self.load_model:
            self.saver.restore(self.sess, os.getcwd()+'/model_saved/model.ckpt')
            print 'model load done'
        self.sess.graph.finalize()

    def train(self):
        if not os.path.exists('model_saved'):
            os.mkdir('model_saved')
        if not os.path.exists('gen_picture'):
            os.mkdir('gen_picture')
        noise = np.random.normal(-1, 1, [self.batch_size, 128])
        temp = 0.80
        print 'training'
        for epoch in range(self.EPOCH):
            # iters = int(156191//self.batch_size)
            iters = 50000//self.batch_size
            flag2 = 1  # if epoch>10 else 0
            for idx in range(iters):
                start_t = time.time()
                flag = 1 if idx < 4 else 0 # set we use 2*batch_size=200 train data labeled.
                batchx, batchl = mnist.train.next_batch(self.batch_size)
                # batchx, batchl = self.sess.run([batchx, batchl])
                g_opt = [self.opt_g, self.g_loss]
                d_opt = [self.opt_d, self.d_loss, self.d_l_1, self.d_l_2]
                feed = {self.x:batchx, self.z:noise, self.label:batchl, self.flag:flag, self.flag2:flag2}
                # update the Discrimater k times
                _, loss_d, d1,d2 = self.sess.run(d_opt, feed_dict=feed)
                # update the Generator one time
                _, loss_g = self.sess.run(g_opt, feed_dict=feed)
                print ("[%3f][epoch:%2d/%2d][iter:%4d/%4d],loss_d:%5f,loss_g:%4f, d1:%4f, d2:%4f"%
                       (time.time()-start_t, epoch, self.EPOCH,idx,iters, loss_d, loss_g,d1,d2)), 'flag:',flag
                plot.plot('d_loss', loss_d)
                plot.plot('g_loss', loss_g)
                if ((idx+1) % 100) == 0:  # flush plot picture per 1000 iters
                    plot.flush()
                plot.tick()
                if (idx+1)%500==0:
                    print ('images saving............')
                    img = self.sess.run(self.G_img, feed_dict=feed)
                    save_img.save_images(img, os.getcwd()+'/gen_picture/'+'sample{}_{}.jpg'\
                                         .format(epoch, (idx+1)/500))
                    print 'images save done'
            test_acc = self.test()
            plot.plot('test acc', test_acc)
            plot.flush()
            plot.tick()
            print 'test acc:{}'.format(test_acc), 'temp:%3f'%(temp)
            if test_acc > temp:
                print ('model saving..............')
                path = os.getcwd() + '/model_saved'
                save_path = os.path.join(path, "model.ckpt")
                self.saver.save(self.sess, save_path=save_path)
                print ('model saved...............')
                temp = test_acc

# output = conv2d('Z_cona{}'.format(i), output, 3, 64, stride=1, padding='SAME')

    def generator(self,name, noise, reuse):
        with tf.variable_scope(name,reuse=reuse):
            l = self.batch_size
            output = fc('g_dc', noise, 2*2*64)
            output = tf.reshape(output, [-1, 2, 2, 64])
            output = tf.nn.relu(self.bn('g_bn1',output))
            output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
            output = tf.nn.relu(self.bn('g_bn2',output))

            output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
            output = tf.nn.relu(self.bn('g_bn3', output))

            output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
            output = tf.nn.relu(self.bn('g_bn4', output))

            output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
            output = tf.image.resize_images(output, (28, 28))
            # output = tf.nn.relu(self.bn('g_bn4', output))
            return tf.nn.tanh(output)

    def discriminator(self, name, inputs, reuse):
        l = tf.shape(inputs)[0]
        inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
        with tf.variable_scope(name,reuse=reuse):
            out = []
            output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
            output1 = lrelu(self.bn('d_bn1',output))
            out.append(output1)
            # output1 = tf.contrib.keras.layers.GaussianNoise
            output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
            output2 = lrelu(self.bn('d_bn2', output))
            out.append(output2)
            output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
            output3 = lrelu(self.bn('d_bn3', output))
            out.append(output3)
            output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
            output4 = lrelu(self.bn('d_bn4', output))
            out.append(output4)
            output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
            output = fc('d_fc', output, self.num_class)
            # output = tf.nn.softmax(output)
            return output, out

    def bn(self, name, input):
        val = tf.contrib.layers.batch_norm(input, decay=0.9,
                                           updates_collections=None,
                                           epsilon=1e-5,
                                           scale=True,
                                           is_training=True,
                                           scope=name)
        return val

    # def get_loss(self, logits, layer_out):
    def test(self):
        count = 0.
        print 'testing................'
        for i in range(10000//self.batch_size):
            testx, textl = mnist.test.next_batch(self.batch_size)
            prediction = self.sess.run(self.prediction, feed_dict={self.x:testx, self.label:textl})
            count += np.sum(prediction)
        return count/10000.

args是傳進來的引數,主要包括三個,一個是args.model,選擇DCGAN模式還是WGAN-GP模式,二者的不同主要在於損失函式不同和優化器的學習率不同,其他都一樣。第二個引數是args.trainable,訓練還是測試,訓練時為True,測試是False。Loadmodel表示是否選擇載入訓練好的權重。
Build_model函式裡面主要包括了網路訓練前的準備工作,主要包括損失函式的設計和優化器的設計。下文將詳細做出介紹,尤其是損失函式部分。
首先,建立了五個placeholder,flag表示兩個標誌位,只有0-1兩種情況,注意到我num_class是11,也就是做11分類,但是lable的placeholder中shape是(batchsize,10)。為了方便,我將生成器的生成結果和真實資料X級聯在一起作為判別器的輸入,輸出再把他它們結果split分開。
d_regular 表示正則化,這裡我將判別器中所有的weights做了l2正則。
監督學習的損失函式使用常見的交叉熵損失函式,對生成器生成的影象的label的one_hot型為:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
將原始的label擴充套件到(batchsize,11)後再和生成器生成的假資料的label再第一維度concat到一起得到batchl,另外乘以0.9,做單邊標籤平滑(one side smoothing),由此計算得到監督學習的損失函式值s_l,。

生成器G的損失函式

生成器G的損失函式包括兩部分,一個是來自GAN訓練的部分,另外一個是feature matching , 論文中提到的feature matching意思是特徵匹配,主要思想是希望生成器生成的假資料輸入到判別器,經過判別器每一層計算的結果和將真實資料X輸入到判別器,判別器每一層的結果儘可能的相似,公式如下:
這裡寫圖片描述
其中f(x)是D的每一層的輸出。Feature matching 是指導G進行訓練,所以我將他放在了G的損失函式裡。

分類器D的損失函式:

相比較G的損失函式,D的損失函式就比較麻煩了
接下來介紹無監督學習的損失函式實現:
在前面介紹的無監督學習的損失函式中,有一部分和GAN的損失函式很相似,所以再程式碼中我們使用了
這裡寫圖片描述
無監督學習的時候沒有標籤的指導,此時判別器或者稱為分類器D無法正確對輸入進行分類,此時只要求D能夠區分真假就可以了,由此我們得到了無監督學習的損失un_s,直觀上也很好理解,假設輸入給判別器D真影象,它結果經過soft max後輸出類似下面表格的形式,其中前十個黃色區域表示對0-9的分類概率,最後一個灰色的表示對假影象的分類概率,由於無監督學習中判別器D並不知道具體是哪一類資料,所以乾脆D的損失函式最小化輸出假影象的概率就可以了,當輸入為生成器生成的假影象時,只要最小化D輸出為真影象的概率,由此我們得到了un_s.。但是此時有一個問題,即是有監督學習的時候不就沒有用了嗎,因為這時候應該使用s_l.為了解決這個問題,我使用了一個標誌位flag作為控制他們之間的使用,具體程式碼:

flag*s_l + ( 1 – flag)*un_s

有標籤的時候flag是1,表示使用s_l,無監督的時候flag是0,表示使用無監督損失函式。此時已經完成了判別器D損失函式的一部分設計,剩下的一部分和GAN中的D的損失一樣,在程式碼中我給出了兩種損失函式,一個是原始GAN的交叉熵損失函式,和DCGAN使用的一樣,另外一個是improved wgan論文中使用的損失函式,但是在做了對比之後,我強烈建議使用DCGAN來做,improved wgan的損失函式雖然在生成結果的優化上有很大幫助,但是並不適合半監督學習中。

訓練

接下來就是訓練部分:
此時可能有一個疑問,我們是如何實現只使用200帶標籤的資料訓練的,答案就在flag這個標誌位裡,在訓練部分程式碼中,當迭代次數小於2的時候,flag=1, 此時表示使用s_l作為損失函式的一部分,當flag=0的時候,un_s起作用而s_l並沒有起作用,這時,即使我們feed了正確的標籤資料,但是s_l不起作用,就相當於沒有使用標籤。
flag2的作用本來是使用他控制feature matching是否工作的,因為這部分損失相當的大,後來發現影響不大,暫時就放在這裡了。

測試

 def test(self):
        count = 0.
        print 'testing................'
        for i in range(10000//self.batch_size):
            testx, textl = mnist.test.next_batch(self.batch_size)
            prediction = self.sess.run(self.prediction, feed_dict={self.x:testx, self.label:textl})
            count += np.sum(prediction)
        return count/10000.
            test_acc = self.test()
            plot.plot('test acc', test_acc)
            plot.flush()
            plot.tick()
            print 'test acc:{}'.format(test_acc), 'temp:%3f'%(temp)
            if test_acc > temp:
                print ('model saving..............')
                path = os.getcwd() + '/model_saved'
                save_path = os.path.join(path, "model.ckpt")
                self.saver.save(self.sess, save_path=save_path)
                print ('model saved...............')
                temp = test_acc

測試精度結果變化圖

這裡寫圖片描述

本文實驗程式碼

備註

詳細程式碼請以github中為準,另關於結果不理想的問題,可能和之前做的遷移學習有關,下面是最近跑出來的結果,最好的精度是0.95,這個問題有時間會慢慢解決。另:連結中的模型精度是很高的,可以直接呼叫
這裡寫圖片描述

對機器學習和人工智慧感興趣,請掃碼關注微信公眾號!
這裡寫圖片描述