1. 程式人生 > >強化學習Sarsa演算法走迷宮小例子

強化學習Sarsa演算法走迷宮小例子

Sarsa演算法:

Sarsa演算法與Q-learing演算法的不同之處是什麼?

一個簡單的解釋,引用莫凡大神的話:

  • 他在當前 state 已經想好了 state 對應的 action, 而且想好了 下一個 state_ 和下一個 action_ (Qlearning 還沒有想好下一個 action_)
  • 更新 Q(s,a) 的時候基於的是下一個 Q(s_, a_) (Qlearning 是基於 maxQ(s_))

對於第二句話,可以從走迷宮的程式碼中只管體現出來:(程式碼來自於莫凡大神編寫地址:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/3_Sarsa_maze/RL_brain.py

# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


# on-policy
class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

可以看出二者的q_target不同,Q-learing取得是最大值,但是實際不一定會選,而Sarsa則是直接取到下一個a_,也就是下一個狀態的動作,這個動作是下一次一定要做的。