Implementing a Lightweight Distributed RPC Framework (Continued)

1. Background

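Recently, while searching for articles on Netty and ZooKeeper, I came across the article "A Lightweight Distributed RPC Framework" (《輕量級分散式 RPC 框架》), in which the author builds a lightweight distributed RPC framework with ZooKeeper, Netty, and Spring. I spent some time reading the code: it is clean and simple, and the framework can be considered a simplified version of Dubbo. It is small, but it has everything an RPC framework needs, and it is worth studying if you are interested.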

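I added the following features to this simplified RPC framework:

* Support for asynchronous service calls and callback functions

* Long-lived client connections (one connection shared across multiple calls)

* Asynchronous, multithreaded processing of RPC requests on the server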

2. Overview

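RPC (Remote Procedure Call) lets a program call a service on a remote computer as if it were a local service. RPC decouples systems well; WebService, for example, is a form of RPC built on the HTTP protocol.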

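The overall architecture of this RPC framework is shown below:

[Figure: overall architecture of the RPC framework]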

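The technologies used in this framework, and the problems they solve:

Service publishing and subscription: the server registers service addresses in ZooKeeper, and the client obtains available service addresses from ZooKeeper.

Communication: Netty is used as the communication framework.

Spring: Spring is used to configure services, load Beans, and scan annotations.

Dynamic proxy: the client uses the proxy pattern to make service calls transparent.

Message encoding and decoding: Protostuff serializes and deserializes messages.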

3. Publishing services on the server

Services to be published are marked with an annotation.

The service annotation:

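import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.springframework.stereotype.Component;

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface RpcService {
    Class<?> value(); // the service interface that the annotated bean publishes
}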

A service interface:
public interface HelloService {

    String hello(String name);

    String hello(Person person);
}
A service implementation, marked with the annotation:
@RpcService(HelloService.class)
public class HelloServiceImpl implements HelloService {

    @Override
    public String hello(String name) {
        return "Hello! " + name;
    }

    @Override
    public String hello(Person person) {
        return "Hello! " + person.getFirstName() + " " + person.getLastName();
    }
}

At startup, the server scans for all service interfaces and their implementations:
    @Override
    public void setApplicationContext(ApplicationContext ctx) throws BeansException {
        // Collect every bean annotated with @RpcService, keyed by the name of the interface it publishes
        Map<String, Object> serviceBeanMap = ctx.getBeansWithAnnotation(RpcService.class);
        if (MapUtils.isNotEmpty(serviceBeanMap)) {
            for (Object serviceBean : serviceBeanMap.values()) {
                String interfaceName = serviceBean.getClass().getAnnotation(RpcService.class).value().getName();
                handlerMap.put(interfaceName, serviceBean);
            }
        }
    }
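To illustrate how a map like handlerMap can be used once it is filled, here is a minimal sketch that dispatches a call by interface name using plain java.lang.reflect. Greeter, GreeterImpl, ReflectiveDispatcher, and dispatch() are hypothetical names for this example only; the framework's actual server-side handler may look up and invoke methods differently.

```java
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-ins for HelloService / HelloServiceImpl.
interface Greeter {
    String hello(String name);
}

class GreeterImpl implements Greeter {
    @Override
    public String hello(String name) {
        return "Hello! " + name;
    }
}

// Hypothetical dispatcher: interface name -> bean, like handlerMap above.
class ReflectiveDispatcher {
    private final Map<String, Object> handlerMap = new ConcurrentHashMap<>();

    void register(Class<?> iface, Object bean) {
        handlerMap.put(iface.getName(), bean);
    }

    // Finds the bean by class name, resolves the method by name and parameter
    // types (the same fields RpcRequest carries), then invokes it reflectively.
    Object dispatch(String className, String methodName,
                    Class<?>[] parameterTypes, Object[] parameters) throws Exception {
        Object bean = handlerMap.get(className);
        if (bean == null) {
            throw new IllegalArgumentException("No service registered for " + className);
        }
        Method method = bean.getClass().getMethod(methodName, parameterTypes);
        return method.invoke(bean, parameters);
    }

    public static void main(String[] args) throws Exception {
        ReflectiveDispatcher dispatcher = new ReflectiveDispatcher();
        dispatcher.register(Greeter.class, new GreeterImpl());
        Object result = dispatcher.dispatch(Greeter.class.getName(), "hello",
                new Class<?>[]{String.class}, new Object[]{"Tom"});
        System.out.println(result); // prints "Hello! Tom"
    }
}
```

Resolving by parameter types, not just the method name, is what lets overloaded methods such as hello(String) and hello(Person) be told apart.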

Registering the service address on the ZooKeeper cluster:
public class ServiceRegistry {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);

    private CountDownLatch latch = new CountDownLatch(1);

    private String registryAddress;

    public ServiceRegistry(String registryAddress) {
        this.registryAddress = registryAddress;
    }

    public void register(String data) {
        if (data != null) {
            ZooKeeper zk = connectServer();
            if (zk != null) {
                AddRootNode(zk); // Add root node if not exist
                createNode(zk, data);
            }
        }
    }

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (IOException e) {
            LOGGER.error("", e);
        }
        catch (InterruptedException ex){
            LOGGER.error("", ex);
        }
        return zk;
    }

    private void AddRootNode(ZooKeeper zk){
        try {
            Stat s = zk.exists(Constant.ZK_REGISTRY_PATH, false);
            if (s == null) {
                zk.create(Constant.ZK_REGISTRY_PATH, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);
            }
        } catch (KeeperException e) {
            LOGGER.error(e.toString());
        } catch (InterruptedException e) {
            LOGGER.error(e.toString());
        }
    }

    private void createNode(ZooKeeper zk, String data) {
        try {
            byte[] bytes = data.getBytes();
            String path = zk.create(Constant.ZK_DATA_PATH, bytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            LOGGER.debug("create zookeeper node ({} => {})", path, data);
        } catch (KeeperException e) {
            LOGGER.error("", e);
        }
        catch (InterruptedException ex){
            LOGGER.error("", ex);
        }
    }
}

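Compared with the original article, I added AddRootNode(), which checks whether the parent node for services exists and, if it does not, creates it as a PERSISTENT node. This adds a check at service startup, but the parent node no longer has to be created by hand.

For how ZooKeeper works, see the article "ZooKeeper Basic Principles" (《ZooKeeper基本原理》).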

4. Calling services from the client

Services are called through a proxy:

public class RpcProxy {

    private String serverAddress;
    private ServiceDiscovery serviceDiscovery;

    public RpcProxy(String serverAddress) {
        this.serverAddress = serverAddress;
    }

    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    @SuppressWarnings("unchecked")
    public <T> T create(Class<?> interfaceClass) {
        return (T) Proxy.newProxyInstance(
                interfaceClass.getClassLoader(),
                new Class<?>[]{interfaceClass},
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                        RpcRequest request = new RpcRequest();
                        request.setRequestId(UUID.randomUUID().toString());
                        request.setClassName(method.getDeclaringClass().getName());
                        request.setMethodName(method.getName());
                        request.setParameterTypes(method.getParameterTypes());
                        request.setParameters(args);

                        if (serviceDiscovery != null) {
                            serverAddress = serviceDiscovery.discover();
                        }
                        if(serverAddress != null){
                            String[] array = serverAddress.split(":");
                            String host = array[0];
                            int port = Integer.parseInt(array[1]);

                            RpcClient client = new RpcClient(host, port);
                            RpcResponse response = client.send(request);

                            if (response.isError()) {
                                throw new RuntimeException("Response error.",new Throwable(response.getError()));
                            } else {
                                return response.getResult();
                            }
                        }
                        else{
                            throw new RuntimeException("No server address found!");
                        }
                    }
                }
        );
    }
}

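Each remote call through the proxy fetches an available service address from ZooKeeper, sends a Request through RpcClient's send, and waits for that Request's Response. The original article has a fairly serious bug here that its simple test is unlikely to reveal. It used an object's wait() and notifyAll() to wait for the Response, which can hang forever: after a Request is sent, the Response may arrive before obj.wait() is called. In that case channelRead0 has already received the Response and called obj.notifyAll() before obj.wait(), so the wait() that follows send() is never woken up and the client blocks indefinitely. Using a CountDownLatch solves this, because once a latch has been counted down, any later await() returns immediately.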
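The CountDownLatch fix described above can be sketched as a one-shot response holder. ResponseFuture, complete(), and get() are hypothetical names chosen for this example, not the framework's classes; the point is only that a countDown() which happens before await() is remembered, whereas a notifyAll() before wait() is lost.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical one-shot holder for a single RPC response.
class ResponseFuture<T> {
    private final CountDownLatch latch = new CountDownLatch(1);
    private volatile T response;

    // Called from the I/O thread (e.g. channelRead0) when the response arrives.
    void complete(T value) {
        this.response = value;
        latch.countDown(); // the latch stays open, even if nobody is waiting yet
    }

    // Called by the sender after send(); returns null if the timeout expires.
    T get(long timeout, TimeUnit unit) throws InterruptedException {
        if (latch.await(timeout, unit)) {
            return response;
        }
        return null;
    }

    public static void main(String[] args) throws InterruptedException {
        ResponseFuture<String> future = new ResponseFuture<>();
        // Reproduce the race from the text: the response arrives before the caller waits.
        future.complete("pong");
        System.out.println(future.get(1, TimeUnit.SECONDS)); // prints "pong"
    }
}
```

Using the timed await() rather than the untimed one is a design choice of this sketch: it keeps a lost server from blocking the caller forever.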

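Note: here the connection to the server is only established when send is called, so every call uses a short-lived connection. Under high concurrency, short connections lead to a large number of connections and also hurt performance.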
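One way to avoid opening a connection per call is to keep one long-lived connection per server address and reuse it. The sketch below assumes a hypothetical Connection class standing in for a real client connection (for example a Netty Channel) and a hypothetical ConnectionCache; it shows only the reuse logic, not any networking. It also splits the address at the last colon instead of using split(":") as RpcProxy does.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for a long-lived client connection.
class Connection {
    static final AtomicInteger OPENED = new AtomicInteger(); // how many were actually opened
    final String host;
    final int port;

    Connection(String host, int port) {
        this.host = host;
        this.port = port;
        OPENED.incrementAndGet();
    }
}

// Hypothetical cache: one shared connection per "host:port" address.
class ConnectionCache {
    private final Map<String, Connection> connections = new ConcurrentHashMap<>();

    Connection get(String serverAddress) {
        // computeIfAbsent creates at most one connection per address, even when
        // several threads ask for the same address concurrently.
        return connections.computeIfAbsent(serverAddress, addr -> {
            int idx = addr.lastIndexOf(':');
            String host = addr.substring(0, idx);
            int port = Integer.parseInt(addr.substring(idx + 1));
            return new Connection(host, port);
        });
    }

    public static void main(String[] args) {
        ConnectionCache cache = new ConnectionCache();
        Connection a = cache.get("127.0.0.1:8000");
        Connection b = cache.get("127.0.0.1:8000");
        System.out.println(a == b);                  // prints "true"
        System.out.println(Connection.OPENED.get()); // prints "1"
    }
}
```

A real pool would also have to evict connections that have been closed. Once several requests share one connection, responses must also be matched back to their requests by requestId, since they may return in any order.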

Fetching service addresses from ZooKeeper:


public class ServiceDiscovery {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);

    private CountDownLatch latch = new CountDownLatch(1);

    private volatile List<String> dataList = new ArrayList<>();

    private String registryAddress;

    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;
        ZooKeeper zk = connectServer();
        if (zk != null) {
            watchNode(zk);
        }
    }

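    public String discover() {
        String data = null;
        // Read the volatile field once: watchNode() may replace the list at any time,
        // so size() and get() must be called on the same snapshot.
        List<String> list = this.dataList;
        int size = list.size();
        if (size > 0) {
            if (size == 1) {
                data = list.get(0);
                LOGGER.debug("using only data: {}", data);
            } else {
                data = list.get(ThreadLocalRandom.current().nextInt(size));
                LOGGER.debug("using random data: {}", data);
            }
        }
        return data;
    }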

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (IOException | InterruptedException e) {
            LOGGER.error("", e);
        }
        return zk;
    }

    private void watchNode(final ZooKeeper zk) {
        try {
            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_PATH, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getType() == Event.EventType.NodeChildrenChanged) {
                        watchNode(zk);
                    }
                }
            });
            List<String> dataList = new ArrayList<>();
            for (String node : nodeList) {
                byte[] bytes = zk.getData(Constant.ZK_REGISTRY_PATH + "/" + node, false, null);
                dataList.add(new String(bytes));
            }
            LOGGER.debug("node data: {}", dataList);
            this.dataList = dataList;
        } catch (KeeperException | InterruptedException e) {
            LOGGER.error("", e);
        }
    }
}

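ZooKeeper watches are one-time triggers, so every time the service address nodes change, watchNode() has to be called again, both to fetch the new list of service addresses and to register a fresh watch.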
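To see why watchNode() calls itself from inside its own watcher, here is a toy in-memory model of one-shot child watches. ToyRegistry and ToyDiscovery are hypothetical classes that only mimic the behaviour described above; they do not use the ZooKeeper API.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model (NOT ZooKeeper): a watcher registered via getChildren() fires once, then is dropped.
class ToyRegistry {
    private final List<String> children = new ArrayList<>();
    private final List<Runnable> watchers = new ArrayList<>();

    synchronized List<String> getChildren(Runnable watcher) {
        watchers.add(watcher);
        return new ArrayList<>(children); // snapshot of the current children
    }

    void addChild(String data) {
        List<Runnable> toFire;
        synchronized (this) {
            children.add(data);
            toFire = new ArrayList<>(watchers);
            watchers.clear(); // one-shot semantics
        }
        toFire.forEach(Runnable::run); // fire outside the lock
    }
}

class ToyDiscovery {
    private final ToyRegistry registry;
    private volatile List<String> dataList = new ArrayList<>();

    ToyDiscovery(ToyRegistry registry) {
        this.registry = registry;
        watchNode();
    }

    // Like watchNode() above: the watcher re-arms itself by calling watchNode() again,
    // which both re-reads the children and registers a new one-shot watcher.
    private void watchNode() {
        dataList = registry.getChildren(this::watchNode);
    }

    List<String> addresses() {
        return dataList;
    }

    public static void main(String[] args) {
        ToyRegistry registry = new ToyRegistry();
        ToyDiscovery discovery = new ToyDiscovery(registry);
        registry.addChild("127.0.0.1:8000");
        registry.addChild("127.0.0.1:8001");
        System.out.println(discovery.addresses()); // prints "[127.0.0.1:8000, 127.0.0.1:8001]"
    }
}
```

If the watcher refreshed the list without registering a new watch, the model would pick up only the first change and silently miss every later one.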

5. Message encoding

The request message:

public class RpcRequest {

    private String requestId;
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParameterTypes() {
        return parameterTypes;
    }

    public void setParameterTypes(Class<?>[] parameterTypes) {
        this.parameterTypes = parameterTypes;
    }

    public Object[] getParameters() {
        return parameters;
    }

    public void setParameters(Object[] parameters) {
        this.parameters = parameters;
    }
}

The response message:
public class RpcResponse {

    private String requestId;
    private String error;
    private Object result;

    public boolean isError() {
        return error != null;
    }

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public String getError() {
        return error;
    }

    public void setError(String error) {
        this.error = error;
    }

    public Object getResult() {
        return result;
    }

    public void setResult(Object result) {
        this.result = result;
    }
}

Message serialization and deserialization utility (based on Protostuff):