1. 程式人生 > >Deeplearning4j 實戰 (9):強化學習 -- Cartpole任務的訓練和效果測試

Deeplearning4j 實戰 (9):強化學習 -- Cartpole任務的訓練和效果測試

在之前的部落格中,我用Deeplearning4j構建深度神經網路來解決監督、無監督的機器學習問題。但除了這兩類問題外,強化學習也是機器學習中一個重要的分支,並且Deeplearning4j的子專案--Rl4j提供了對部分強化學習演算法的支援。這裡,就以強化學習中的經典任務--Cartpole問題作為學習Rl4j的入門例子,講解從環境搭建、模型訓練再到最後的效果評估的結果。

Cartpole描述的問題可以認為是:在一輛小車上豎立一根杆子,然後給小車一個推或者拉的力,使得杆子儘量保持平衡不滑倒。更詳細的描述可參見openai官網上關於Cartpole問題的解釋:https://gym.openai.com/envs/CartPole-v0

接著給出強化學習的一些概念:environment,action,reward

environment:描述強化學習問題中的外部環境,比如:Cartpole問題中杆子的角度,小車的位置、速度等。

action:在不同外部環境條件下采取的動作,比如:Cartpole問題中對於小車施加推或者拉的力。action可以是離散的集合,也可以是連續的。

reward:對於agent/network作出的action後獲取的回報/評價。比如:Cartpole問題中如果施加的力可以繼續讓杆子保持平衡,那reward就+1。

在描述reward這個概念時,提到了agent這個概念,在實際應用中,agent可以用神經網路來實現。

對於強化學習訓練後的agent來說,學習到的是如何在變化中的environment和reward選擇action的能力。通常有兩種學習策略可以選擇:Policy-Based和Value-Based。 Policy-Based直接學習action,通過Policy Gradient來更新模型引數,而相對的,Value-Based是最優化action所帶來的reward(action-value function,Q-function)來間接選取action。一般認為如果action是離散的,那麼Value-Based會優於Policy-Based,而連續的action則相反。在這裡主要討論Value-Based的學習策略,或者更具體的說Q-learning的問題。對於Policy-Based還有Model-Based不做討論。

Q-learning的概念早在20多年前就已經提出,再與近年來流行的深度神經網路結合產生了DQN的概念。Q-learning的目標是最大化Q值從而學習到選取action的策略。Q-leaning學習的策略公式:

Q(st,at)Q(st,at)+α[rt+1+λmaxaQ(st+1,a)Q(st,at)]

對於這裡主要討論的Catpole問題,我們也採用Q-learning來實現。

可以看到,與監督學習相比,強化學習多了action,environment等概念。雖然可以將reward類比成監督學習中的label(或者反過來,label也可以認為是強化學習中最終的reward),但通過action與environment不斷的互動甚至改變environment這一特點,是監督學習中所沒有的。在構建應用的時候,監督學習的學習的目標:label,灌入的資料都是一個定值。比如,影象的分類的問題,在用CNN訓練的時候,圖片本身不發生變化,label也不會發生變化,唯一變化的是神經網路中的權重值。但強化學習在訓練的時候,除了神經網路中的權重會發生變化(如果用NN建模的話),environment、reward等都會發生動態的變化。這樣構建合適正確的訓練資料會比較麻煩,容易出錯,所以對於CartPole問題,我們可以採用openAI提供的強化學習開發環境gym來訓練/測試agent。

gym的官方地址:https://gym.openai.com/

gym提供了棋類、視訊遊戲等強化學習問題的學習/測試/演算法效果比較的環境。這裡要處理的Cartpole問題,gym也提供了環境的支援。但是,除了python,gym對其他語言的支援不是很友好,所以為了可以獲取gym中的資料,RL4j提供了對gym-http-api(https://github.com/openai/gym-http-api)呼叫的包裝類。gym-http-api是為了方便除python外的其他語言也可以使用gym環境資料的一個REST介面。簡單來說,對於像RL4j這樣以Java實現的強化學習演算法庫可以通過gym-http-api獲取gym提供的資料。

gym的REST介面的安裝可以參見之前給出的github地址,裡面有詳細的描述。下面先給出gym-http-api的安裝和啟動過程的截圖:



下面就結合上面說的內容,給出RL4j的Catpole實現邏輯

1. 定義Q-learning的引數以及神經網路結構,兩者共同決定DQN的屬性

2. 定義讀取gym資料的包裝類物件

3. 訓練DQN並儲存模型

4. 載入儲存的模型並測試

這裡先貼下需要的Maven依賴以及程式碼版本

  <properties>
	  <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> 
	  <nd4j.version>0.8.0</nd4j.version>  
	  <dl4j.version>0.8.0</dl4j.version>  
	  <datavec.version>0.8.0</datavec.version>  
	  <rl4j.version>0.8.0</rl4j.version>
	  <scala.binary.version>2.10</scala.binary.version>  
  </properties>
  <dependencies>
	<dependency>  
		<groupId>org.nd4j</groupId>  
		<artifactId>nd4j-native</artifactId>   
		<version>${nd4j.version}</version>  
	</dependency>  
        <dependency>  
		<groupId>org.deeplearning4j</groupId>  
		<artifactId>deeplearning4j-core</artifactId>  
		<version>${dl4j.version}</version>  
	</dependency>  
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-gym</artifactId>
            <version>${rl4j.version}</version>
        </dependency>

  </dependencies>


第一部分的程式碼如下:

    public static QLearning.QLConfiguration CARTPOLE_QL =
            new QLearning.QLConfiguration(
                    123,    //Random seed
                    200,    //Max step By epoch
                    150000, //Max step
                    150000, //Max size of experience replay
                    32,     //size of batches
                    500,    //target update (hard)
                    10,     //num step noop warmup
                    0.01,   //reward scaling
                    0.99,   //gamma
                    1.0,    //td-error clipping
                    0.1f,   //min epsilon
                    1000,   //num step for eps greedy anneal
                    true    //double DQN
            );

    public static DQNFactoryStdDense.Configuration CARTPOLE_NET = DQNFactoryStdDense.Configuration.builder()            												.l2(0.001)            																        .learningRate(0.0005)
       								.numHiddenNodes(16)
           							.numLayer(3)
            							.build();

第一部分中定義Q-learning的引數,包括每一輪的訓練的可採取的action的步數,最大步數以及儲存過往action的最大步數等。除此以外,DQNFactoryStdDense用來定義基於MLP的DQN網路結構,包括網路深度等常見引數。這裡的程式碼定義的是一個三層(只有一層隱藏層)的全連線神經網路。

接下來,定義兩個方法分別用於訓練和測試。catpole方法用於訓練DQN,而loadCartpole則用於測試。

訓練:

    public static void cartPole() {

        //record the training data in rl4j-data in a new folder (save)
        DataManager manager = new DataManager(true);

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp = null;
        try {
            mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
        } catch (RuntimeException e){
            System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api.");
        }
        //define the training
        QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL, manager);

        //train
        dql.train();

        //get the final policy
        DQNPolicy<Box> pol = dql.getPolicy();

        //serialize and save (serialization showcase, but not required)
        pol.save("/tmp/pol1");

        //close the mdp (close http)
        mdp.close();

    }

測試:

    public static void loadCartpole(){

        //showcase serialization by using the trained agent on a new similar mdp (but render it this time)

        //define the mdp from gym (name, render)
        GymEnv<Box, Integer, DiscreteSpace> mdp2 = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", true, false);

        //load the previous agent
        DQNPolicy<Box> pol2 = DQNPolicy.load("/tmp/pol1");

        //evaluate the agent
        double rewards = 0;
        for (int i = 0; i < 1000; i++) {
            mdp2.reset();
            double reward = pol2.play(mdp2);
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        Logger.getAnonymousLogger().info("average: " + rewards/1000);
        
        mdp2.close();

    }

在訓練模型的方法中,包含了第二、三步的內容。首先需要定義gym資料讀取物件,即程式碼中的GymEnv<Box, Integer, DiscreteSpace> mdp。它會通過gym-http-api的介面讀取訓練資料。接著,將第一步中定義的Q-learning的相關引數,神經網路結構作為引數傳入DQN訓練的包裝類中。其中DataManager的作用是用來管理訓練資料。

測試部分的程式碼實現了之前說的第四步,即載入策略模型並進行測試的過程。在測試的過程中,將每次action的reward打上log,並最後求取平均的reward。

訓練的過程截圖如下:


最後我們其實最關心的還是這個模型的效果。純粹通過平均reward的數值大小可能並不是非常的直觀,因此這裡給出一張gif的效果圖:


總結一下Cartpole問題的整個解決過程。首先我們明確,這是一個強化學習的問題,而不是傳統的監督學習,因為涉及到與環境的互動等因素。然後,利用openAI提供的強化學習開發環境gym來構建訓練平臺,而RL4j則可以定義並訓練DQN。最後的效果就是上面這張gif圖片。需要注意的是,這張gif效果圖並非是RL4j直接生成的,而是通過xvfb命令擷取虛擬monitor的在每個action後的效果拼接起來的圖。具體可先查閱xvfb的相關內容。