1. 程式人生 > >RocketMQ Consumer 負載均衡演算法原始碼學習 -- AllocateMessageQueueConsistentHash

RocketMQ Consumer 負載均衡演算法原始碼學習 -- AllocateMessageQueueConsistentHash

RocketMQ 提供了一致性hash 演算法來做Consumer 和 MessageQueue的負載均衡。 原始碼中一致性hash 環的實現是很優秀的,我們一步一步分析。

  1. 一個Hash環包含多個節點, 我們用 MyNode 去封裝節點, 方法 getKey() 封裝獲取節點的key。我們可以實現MyNode 去描述一個物理節點或虛擬節點。MyVirtualNode 實現 MyNode, 表示一個虛擬節點。這裡注意:一個虛擬節點是依賴於一個物理節點,所以MyVirtualNode 中封裝了 一個 泛型 T physicalNode。物理節點MyClientNode也是實現了這個MyNode介面,很好的設計。程式碼加註釋如下:

     /**
      * 表示hash環的一個節點
      */
     public interface MyNode {
     
         /**
          * @return 節點的key
          */
         String getKey();
     }
    
     		/**
      * 虛擬節點
      */
     public class MyVirtualNode<T extends MyNode> implements MyNode {
     
     final T physicalNode;  // 主節點
     final int replicaIndex;  // 虛節點下標
     
         public MyVirtualNode(T physicalNode, int replicaIndex) {
             this.physicalNode = physicalNode;
             this.replicaIndex = replicaIndex;
         }
     
         @Override
         public String getKey() {
             return physicalNode.getKey() + "-" + replicaIndex;
         }
     
         /**
          * thisMyVirtualNode 是否是pNode 的 虛節點
          */
         public boolean isVirtualNodeOf(T pNode) {
             return physicalNode.getKey().equals(pNode.getKey());
         }
     
         public T getPhysicalNode() {
             return physicalNode;
         }
     }
         private static class MyClientNode implements MyNode {
             private final String clientID;
             public MyClientNode(String clientID) {
                 this.clientID = clientID;
             }
             @Override
             public String getKey() {
                 return clientID;
             }
         }
    
  2. 上面實現了節點, 一致性hash 下一個問題是怎麼封裝hash演算法呢?RocketMQ 使用 MyHashFunction 介面定義hash演算法。使用MD5 + bit 位hash的方式實現hash演算法。我們完全可以自己實現hash演算法,具體見我的“常見的一些hash函式”文章。MyMD5Hash 演算法程式碼的如下:

     // MD5 hash 演算法, 這裡hash演算法可以用常用的 hash 演算法替換。
         private static class MyMD5Hash implements MyHashFunction {
             MessageDigest instance;
             public MyMD5Hash() {
                 try {
                     instance = MessageDigest.getInstance("MD5");
                 } catch (NoSuchAlgorithmException e) {
                 }
             }
     
             @Override
             public long hash(String key) {
                 instance.reset();
                 instance.update(key.getBytes());
                 byte[] digest = instance.digest();
     
                 long h = 0;
                 for (int i = 0; i < 4; i++) {
                     h <<= 8;
                     h |= ((int)digest[i]) & 0xFF;
                 }
                 return h;
             }
         }
    
  3. 現在,hash環的節點有了, hash演算法也有了,最重要的是描述一個一致性hash 環。 想一想,這個環可以由N 個物理節點, 每個物理節點對應m個虛擬節點,節點位置用hash演算法值描述。每個物理節點就是每個Consumer, 每個Consumer 的 id 就是 物理節點的key。 每個MessageQueue 的toString() 值 hash 後,用來找環上對應的最近的下一個物理節點。原始碼如下,這裡展示主要的程式碼,其中最巧妙地是routeNode 方法, addNode 方法 注意我的註釋:

    public class MyConsistentHashRouter<T extends MyNode> {
    
    private final SortedMap<Long, MyVirtualNode<T>> ring = new TreeMap<>(); // key是虛節點key的雜湊值, value 是虛節點
    private final MyHashFunction myHashFunction;
    /**
     * @param pNodes 物理節點集合
     * @param vNodeCount 每個物理節點對應的虛節點數量
     * @param hashFunction hash 函式 用於 hash 各個節點
     */
    public MyConsistentHashRouter(Collection<T> pNodes, int vNodeCount, MyHashFunction hashFunction) {
        if (hashFunction == null) {
            throw new NullPointerException("Hash Function is null");
        }
        this.myHashFunction = hashFunction;
        if (pNodes != null) {
            for (T pNode : pNodes) {
                this.addNode(pNode, vNodeCount);
            }
        }
    }
    /**
     * 新增物理節點和它的虛節點到hash環。
     * @param pNode 物理節點
     * @param vNodeCount 虛節點數量。
     */
    public void addNode(T pNode, int vNodeCount) {
        if (vNodeCount < 0) {
            throw new IllegalArgumentException("ill virtual node counts :" + vNodeCount);
        }
        int existingReplicas = this.getExistingReplicas(pNode);
        for (int i = 0; i < vNodeCount; i++) {
            MyVirtualNode<T> vNode = new MyVirtualNode<T>(pNode, i + existingReplicas); // 建立一個新的虛節點,位置是 i+existingReplicas
            ring.put(this.myHashFunction.hash(vNode.getKey()), vNode); // 將新的虛節點放到hash環中
        }
    }
    /**
     * 根據一個給定的key 在 hash環中 找到離這個key最近的下一個物理節點
     * @param key 一個key, 用於找這個key 在環上最近的節點
     */
    public T routeNode(String key) {
        if (ring.isEmpty()) {
            return null;
        }
        Long hashVal = this.myHashFunction.hash(key);
        SortedMap<Long, MyVirtualNode<T>> tailMap = ring.tailMap(hashVal);
        Long nodeHashVal = !tailMap.isEmpty() ? tailMap.firstKey() : ring.firstKey();
        return ring.get(nodeHashVal).getPhysicalNode();
    }
    
    /**
     * @param pNode 物理節點
     * @return 當前這個物理節點對應的虛節點的個數
     */
    public int getExistingReplicas(T pNode) {
        int replicas = 0;
        for (MyVirtualNode<T> vNode : ring.values()) {
            if (vNode.isVirtualNodeOf(pNode)) {
                replicas++;
            }
        }
        return replicas;
    }
    
  4. 現在一致性hash 環有了, 剩下的就是 和rocketmq 的 consumer, mq 構成負載均衡策略了。比較簡單, 程式碼如下:

     			/**
     	 * 基於一致性性hash環的Consumer負載均衡.
     	*/	 
    
     public class MyAllocateMessageQueueConsistentHash implements AllocateMessageQueueStrategy {
     
         // 每個物理節點對應的虛節點的個數
         private final int virtualNodeCnt;
         private final MyHashFunction customHashFunction;
     
         public MyAllocateMessageQueueConsistentHash() {
             this(10);   // 預設10個虛擬節點
         }
     
         public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt) {
             this(virtualNodeCnt, null);
     
         }
         public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt, MyHashFunction customHashFunction) {
             if (virtualNodeCnt < 0) {
                 throw new IllegalArgumentException("illegal virtualNodeCnt : " + virtualNodeCnt);
             }
             this.virtualNodeCnt = virtualNodeCnt;
             this.customHashFunction = customHashFunction;
         }
     
         @Override
         public List<MessageQueue> allocate(String consumerGroup, String currentCID, List<MessageQueue> mqAll, List<String> cidAll) {
             // 省去一系列非空校驗
             Collection<MyClientNode> cidNodes = new ArrayList<>();
             for (String cid : cidAll) {
                 cidNodes.add(new MyClientNode(cid));
             }
             final MyConsistentHashRouter<MyClientNode> router;
             if (this.customHashFunction != null) {
                 router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt, customHashFunction);
             }else {
                 router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt);
             }
             List<MessageQueue> results = new ArrayList<MessageQueue>();  // 當前 currentCID 對應的 mq
             // 將每個mq 根據一致性hash 演算法找到對應的物理節點(Consumer)
             for (MessageQueue mq : mqAll) {
                 MyClientNode clientNode = router.routeNode(mq.toString());   // 根據 mq toString() 方法做hash 和環上節點比較
                 if (clientNode != null && currentCID.equals(clientNode.getKey())) {
                     results.add(mq);
                 }
             }
             return results;
         }
     
         @Override
         public String getName() {
             return "CONSISTENT_HASH";
         }
     
         private static class MyClientNode implements MyNode {
             private final String clientID;
             public MyClientNode(String clientID) {
                 this.clientID = clientID;
             }
             @Override
             public String getKey() {
                 return clientID;
             }
         }
     
     }