RocketMQ Consumer 負載均衡演算法原始碼學習 -- AllocateMessageQueueConsistentHash
RocketMQ 提供了一致性hash 演算法來做Consumer 和 MessageQueue的負載均衡。 原始碼中一致性hash 環的實現是很優秀的,我們一步一步分析。
-
一個Hash環包含多個節點, 我們用 MyNode 去封裝節點, 方法 getKey() 封裝獲取節點的key。我們可以實現MyNode 去描述一個物理節點或虛擬節點。MyVirtualNode 實現 MyNode, 表示一個虛擬節點。這裡注意:一個虛擬節點是依賴於一個物理節點,所以MyVirtualNode 中封裝了 一個 泛型 T physicalNode。物理節點MyClientNode也是實現了這個MyNode介面,很好的設計。程式碼加註釋如下:
/** * 表示hash環的一個節點 */ public interface MyNode { /** * @return 節點的key */ String getKey(); } /** * 虛擬節點 */ public class MyVirtualNode<T extends MyNode> implements MyNode { final T physicalNode; // 主節點 final int replicaIndex; // 虛節點下標 public MyVirtualNode(T physicalNode, int replicaIndex) { this.physicalNode = physicalNode; this.replicaIndex = replicaIndex; } @Override public String getKey() { return physicalNode.getKey() + "-" + replicaIndex; } /** * thisMyVirtualNode 是否是pNode 的 虛節點 */ public boolean isVirtualNodeOf(T pNode) { return physicalNode.getKey().equals(pNode.getKey()); } public T getPhysicalNode() { return physicalNode; } } private static class MyClientNode implements MyNode { private final String clientID; public MyClientNode(String clientID) { this.clientID = clientID; } @Override public String getKey() { return clientID; } }
-
上面實現了節點, 一致性hash 下一個問題是怎麼封裝hash演算法呢?RocketMQ 使用 MyHashFunction 介面定義hash演算法。使用MD5 + bit 位hash的方式實現hash演算法。我們完全可以自己實現hash演算法,具體見我的“常見的一些hash函式”文章。MyMD5Hash 演算法程式碼的如下:
// MD5 hash 演算法, 這裡hash演算法可以用常用的 hash 演算法替換。 private static class MyMD5Hash implements MyHashFunction { MessageDigest instance; public MyMD5Hash() { try { instance = MessageDigest.getInstance("MD5"); } catch (NoSuchAlgorithmException e) { } } @Override public long hash(String key) { instance.reset(); instance.update(key.getBytes()); byte[] digest = instance.digest(); long h = 0; for (int i = 0; i < 4; i++) { h <<= 8; h |= ((int)digest[i]) & 0xFF; } return h; } }
-
現在,hash環的節點有了, hash演算法也有了,最重要的是描述一個一致性hash 環。 想一想,這個環可以由N 個物理節點, 每個物理節點對應m個虛擬節點,節點位置用hash演算法值描述。每個物理節點就是每個Consumer, 每個Consumer 的 id 就是 物理節點的key。 每個MessageQueue 的toString() 值 hash 後,用來找環上對應的最近的下一個物理節點。原始碼如下,這裡展示主要的程式碼,其中最巧妙地是routeNode 方法, addNode 方法 注意我的註釋:
public class MyConsistentHashRouter<T extends MyNode> { private final SortedMap<Long, MyVirtualNode<T>> ring = new TreeMap<>(); // key是虛節點key的雜湊值, value 是虛節點 private final MyHashFunction myHashFunction; /** * @param pNodes 物理節點集合 * @param vNodeCount 每個物理節點對應的虛節點數量 * @param hashFunction hash 函式 用於 hash 各個節點 */ public MyConsistentHashRouter(Collection<T> pNodes, int vNodeCount, MyHashFunction hashFunction) { if (hashFunction == null) { throw new NullPointerException("Hash Function is null"); } this.myHashFunction = hashFunction; if (pNodes != null) { for (T pNode : pNodes) { this.addNode(pNode, vNodeCount); } } } /** * 新增物理節點和它的虛節點到hash環。 * @param pNode 物理節點 * @param vNodeCount 虛節點數量。 */ public void addNode(T pNode, int vNodeCount) { if (vNodeCount < 0) { throw new IllegalArgumentException("ill virtual node counts :" + vNodeCount); } int existingReplicas = this.getExistingReplicas(pNode); for (int i = 0; i < vNodeCount; i++) { MyVirtualNode<T> vNode = new MyVirtualNode<T>(pNode, i + existingReplicas); // 建立一個新的虛節點,位置是 i+existingReplicas ring.put(this.myHashFunction.hash(vNode.getKey()), vNode); // 將新的虛節點放到hash環中 } } /** * 根據一個給定的key 在 hash環中 找到離這個key最近的下一個物理節點 * @param key 一個key, 用於找這個key 在環上最近的節點 */ public T routeNode(String key) { if (ring.isEmpty()) { return null; } Long hashVal = this.myHashFunction.hash(key); SortedMap<Long, MyVirtualNode<T>> tailMap = ring.tailMap(hashVal); Long nodeHashVal = !tailMap.isEmpty() ? tailMap.firstKey() : ring.firstKey(); return ring.get(nodeHashVal).getPhysicalNode(); } /** * @param pNode 物理節點 * @return 當前這個物理節點對應的虛節點的個數 */ public int getExistingReplicas(T pNode) { int replicas = 0; for (MyVirtualNode<T> vNode : ring.values()) { if (vNode.isVirtualNodeOf(pNode)) { replicas++; } } return replicas; }
-
現在一致性hash 環有了, 剩下的就是 和rocketmq 的 consumer, mq 構成負載均衡策略了。比較簡單, 程式碼如下:
/** * 基於一致性性hash環的Consumer負載均衡. */ public class MyAllocateMessageQueueConsistentHash implements AllocateMessageQueueStrategy { // 每個物理節點對應的虛節點的個數 private final int virtualNodeCnt; private final MyHashFunction customHashFunction; public MyAllocateMessageQueueConsistentHash() { this(10); // 預設10個虛擬節點 } public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt) { this(virtualNodeCnt, null); } public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt, MyHashFunction customHashFunction) { if (virtualNodeCnt < 0) { throw new IllegalArgumentException("illegal virtualNodeCnt : " + virtualNodeCnt); } this.virtualNodeCnt = virtualNodeCnt; this.customHashFunction = customHashFunction; } @Override public List<MessageQueue> allocate(String consumerGroup, String currentCID, List<MessageQueue> mqAll, List<String> cidAll) { // 省去一系列非空校驗 Collection<MyClientNode> cidNodes = new ArrayList<>(); for (String cid : cidAll) { cidNodes.add(new MyClientNode(cid)); } final MyConsistentHashRouter<MyClientNode> router; if (this.customHashFunction != null) { router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt, customHashFunction); }else { router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt); } List<MessageQueue> results = new ArrayList<MessageQueue>(); // 當前 currentCID 對應的 mq // 將每個mq 根據一致性hash 演算法找到對應的物理節點(Consumer) for (MessageQueue mq : mqAll) { MyClientNode clientNode = router.routeNode(mq.toString()); // 根據 mq toString() 方法做hash 和環上節點比較 if (clientNode != null && currentCID.equals(clientNode.getKey())) { results.add(mq); } } return results; } @Override public String getName() { return "CONSISTENT_HASH"; } private static class MyClientNode implements MyNode { private final String clientID; public MyClientNode(String clientID) { this.clientID = clientID; } @Override public String getKey() { return clientID; } } }