1. 程式人生 > >netty實戰-netty client連線池設計

netty實戰-netty client連線池設計

概述

最近有很多網友在諮詢netty client中,netty的channel連線池應該如何設計。這是個稍微有些複雜的主題,牽扯到蠻多技術點,要想在網上找到相關的又相對完整的參考文章,確實不太容易。

在本篇文章中,會給出其中一種解決方案,並且附帶完整的可執行的程式碼。如果網友有更好的方案,可以回覆本文,我們一起討論討論,一起開闊思路和眼界。

閱讀本文之前需要具備一些基礎知識

1、知道netty的一些基礎知識,比如ByteBuf類的相關api;
2、知道netty的執行流程;
3、 必須閱讀過我之前寫的netty實戰-自定義解碼器處理半包訊息,因為本文部分程式碼來自這篇文章。

現在微服務非常的熱門,也有很多公司在用。微服務框架中,如果是使用thrift、grpc來作為資料序列化框架的話,通常都會生成一個SDK給客戶端使用者使用。客戶端只要使用這個SDK,就可以方便的呼叫服務端的微服務介面。本文討論的就是使用SDK的netty客戶端,它的netty channel連線池的設計方案。至於netty http client的channel連線池設計,基於http的,是另外一個主題了,需要另外寫文章來討論的。

netty channel連線池設計

DB連線池中,當某個執行緒獲取到一個db connection後,在讀取資料或者寫資料時,如果執行緒沒有操作完,這個db connection一直被該執行緒獨佔著,直到執行緒執行完任務。如果netty client的channel連線池設計也是使用這種獨佔的方式的話,有幾個問題。

1、netty中channel的writeAndFlush方法,呼叫完後是不用等待返回結果的,writeAndFlush一被呼叫,馬上返回。對於這種情況,是完全沒必要讓執行緒獨佔一個channel的。
2、使用類似DB pool的方式,從池子裡拿連線,用完後返回,這裡的一進一出,需要考慮併發鎖的問題。另外,如果請求量很大的時候,連線會不夠用,其他執行緒也只能等待其他執行緒釋放連線。

因此不太建議使用上面的方式來設計netty channel連線池,channel獨佔的代價太大了。可以使用Channel陣列的形式, 複用netty的channel。當執行緒要需要Channel的時候,隨機從陣列選中一個Channel,如果Channel還未建立,則建立一個。如果執行緒選中的Channel已經建立了,則複用這個Channel。

這裡寫圖片描述

假設channel陣列的長度為4

private Channel[] channels = new Channel[4];

當外部系統請求client的時候,client從channels陣列中隨機挑選一個channel,如果該channel尚未建立,則觸發建立channel的邏輯。無論有多少請求,都是複用這4個channel。假設有10個執行緒,那麼部分執行緒可能會使用相同的channel來發送資料和接收資料。因為是隨機選擇一個channel的,多個執行緒命中同一個channel的機率還是很大的。如下圖

這裡寫圖片描述

10個執行緒中,可能有3個執行緒都是使用channel2來發送資料的。這個會引入另外一個問題。thread1通過channel2傳送一條訊息msg1到服務端,thread2也通過channel2傳送一條訊息msg2到服務端,當服務端處理完資料,通過channel2返回資料給客戶端的時候,如何區分哪條訊息是哪個執行緒的呢?如果不做區分,萬一thread1拿到的結果其實是thread2要的結果,怎麼辦?

那麼如何做到讓thread1和thread2拿到它們自己想要的結果呢?

之前我在netty實戰-自定義解碼器處理半包訊息一文中提到,自定義訊息的時候,通常會在訊息中加入一個序列號,用來唯一標識訊息的。當thread1傳送訊息時,往訊息中插入一個唯一的訊息序列號,同時為thread1建立一個callback回撥程式,當服務端返回訊息的時候,根據訊息中的序列號從對應的callback程式獲取結果。這樣就可以解決上面說到的問題。

訊息格式

這裡寫圖片描述

訊息、訊息seq以及callback對應關係

這裡寫圖片描述

這裡寫圖片描述

OK,下面就基於上面的設計來進行編碼。

程式碼

先來實現netty客戶端,設定10個執行緒併發獲取channel,為了達到真正的併發,利用CountDownLatch來做開關,同時channel連線池設定4個channel。

package nettyinaction.nettyclient.channelpool.client;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
import nettyinaction.nettyclient.channelpool.IntegerFactory;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;

public class SocketClient {
    public static void main(String[] args) throws InterruptedException {
        //當所有執行緒都準備後,開閘,讓所有執行緒併發的去獲取netty的channel
        final CountDownLatch countDownLatchBegin = new CountDownLatch(1);

        //當所有執行緒都執行完任務後,釋放主執行緒,讓主執行緒繼續執行下去
        final CountDownLatch countDownLatchEnd = new CountDownLatch(10);

        //netty channel池
        final NettyChannelPool nettyChannelPool = new NettyChannelPool();

        final Map<String, String> resultsMap = new HashMap<>();
        //使用10個執行緒,併發的去獲取netty channel
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        //先讓執行緒block住
                        countDownLatchBegin.await();

                        Channel channel = null;
                        try {
                            channel = nettyChannelPool.syncGetChannel();
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }

                        //為每個執行緒建立一個callback,當訊息返回的時候,在callback中獲取結果
                        CallbackService callbackService = new CallbackService();
                        //給訊息分配一個唯一的訊息序列號
                        int seq = IntegerFactory.getInstance().incrementAndGet();
                        //利用Channel的attr方法,建立訊息與callback的對應關係
                        ChannelUtils.putCallback2DataMap(channel,seq,callbackService);

                        synchronized (callbackService) {
                            UnpooledByteBufAllocator allocator = new UnpooledByteBufAllocator(false);
                            ByteBuf buffer = allocator.buffer(20);
                            buffer.writeInt(ChannelUtils.MESSAGE_LENGTH);

                            buffer.writeInt(seq);
                            String threadName = Thread.currentThread().getName();
                            buffer.writeBytes(threadName.getBytes());
                            buffer.writeBytes("body".getBytes());

                            //給netty 服務端傳送訊息,非同步的,該方法會立刻返回
                            channel.writeAndFlush(buffer);

                            //等待返回結果
                            callbackService.wait();

                            //解析結果,這個result在callback中已經解析到了。
                            ByteBuf result = callbackService.result;
                            int length = result.readInt();
                            int seqFromServer = result.readInt();

                            byte[] head = new byte[8];
                            result.readBytes(head);
                            String headString = new String(head);

                            byte[] body = new byte[4];
                            result.readBytes(body);
                            String bodyString = new String(body);
                            resultsMap.put(threadName, seqFromServer + headString + bodyString);
                        }
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    finally {
                        countDownLatchEnd.countDown();
                    }
                }
            }).start();
        }

        //開閘,讓10個執行緒併發獲取netty channel
        countDownLatchBegin.countDown();

        //等10個執行緒執行完後,列印最終結果
        countDownLatchEnd.await();
        System.out.println("resultMap="+resultsMap);
    }

    public static class CallbackService{
        public volatile ByteBuf result;
        public void receiveMessage(ByteBuf receiveBuf) throws Exception {
            synchronized (this) {
                result = receiveBuf;
                this.notify();
            }
        }
    }
}

其中IntegerFactory類用於生成訊息的唯一序列號

package nettyinaction.nettyclient.channelpool;


import java.util.concurrent.atomic.AtomicInteger;

public class IntegerFactory {
    private static class SingletonHolder {
        private static final AtomicInteger INSTANCE = new AtomicInteger();
    }

    private IntegerFactory(){}

    public static final AtomicInteger getInstance() {
        return SingletonHolder.INSTANCE;
    }
}

而ChannelUtils類則用於建立channel、訊息序列號和callback程式的對應關係。

package nettyinaction.nettyclient.channelpool;

import io.netty.channel.Channel;
import io.netty.util.AttributeKey;

import java.util.Map;

public class ChannelUtils {
    public static final int MESSAGE_LENGTH = 16;
    public static final AttributeKey<Map<Integer, Object>> DATA_MAP_ATTRIBUTEKEY = AttributeKey.valueOf("dataMap");
    public static <T> void putCallback2DataMap(Channel channel, int seq, T callback) {
        channel.attr(DATA_MAP_ATTRIBUTEKEY).get().put(seq, callback);
    }

    public static <T> T removeCallback(Channel channel, int seq) {
        return (T) channel.attr(DATA_MAP_ATTRIBUTEKEY).get().remove(seq);
    }
}

NettyChannelPool則負責建立netty的channel。

package nettyinaction.nettyclient.channelpool.client;


import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.Attribute;
import nettyinaction.nettyclient.channelpool.ChannelUtils;
import nettyinaction.nettyclient.channelpool.SelfDefineEncodeHandler;

import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

public class NettyChannelPool {
    private Channel[] channels;
    private Object [] locks;
    private static final int MAX_CHANNEL_COUNT = 4;

    public NettyChannelPool() {
        this.channels = new Channel[MAX_CHANNEL_COUNT];
        this.locks = new Object[MAX_CHANNEL_COUNT];
        for (int i = 0; i < MAX_CHANNEL_COUNT; i++) {
            this.locks[i] = new Object();
        }
    }

    /**
     * 同步獲取netty channel
     */
    public Channel syncGetChannel() throws InterruptedException {
        //產生一個隨機數,隨機的從陣列中獲取channel
        int index = new Random().nextInt(MAX_CHANNEL_COUNT);
        Channel channel = channels[index];
        //如果能獲取到,直接返回
        if (channel != null && channel.isActive()) {
            return channel;
        }

        synchronized (locks[index]) {
            channel = channels[index];
            //這裡必須再次做判斷,當鎖被釋放後,之前等待的執行緒已經可以直接拿到結果了。
            if (channel != null && channel.isActive()) {
                return channel;
            }

            //開始跟服務端互動,獲取channel
            channel = connectToServer();

            channels[index] = channel;
        }

        return channel;
    }

    private Channel connectToServer() throws InterruptedException {
        EventLoopGroup eventLoopGroup = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(eventLoopGroup)
                 .channel(NioSocketChannel.class)
                 .option(ChannelOption.SO_KEEPALIVE, Boolean.TRUE)
                 .option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
                 .handler(new LoggingHandler(LogLevel.INFO))
                 .handler(new ChannelInitializer<SocketChannel>() {
                     @Override
                     protected void initChannel(SocketChannel ch) throws Exception {
                         ChannelPipeline pipeline = ch.pipeline();
                         pipeline.addLast(new SelfDefineEncodeHandler());
                         pipeline.addLast(new SocketClientHandler());
                     }
                 });

        ChannelFuture channelFuture = bootstrap.connect("localhost", 8899);
        Channel channel = channelFuture.sync().channel();

        //為剛剛建立的channel,初始化channel屬性
        Attribute<Map<Integer,Object>> attribute = channel.attr(ChannelUtils.DATA_MAP_ATTRIBUTEKEY);
        ConcurrentHashMap<Integer, Object> dataMap = new ConcurrentHashMap<>();
        attribute.set(dataMap);
        return channel;
    }
}

先使用構造方法,初始化channels陣列,長度為4。NettyChannelPool類有兩個關鍵的地方。
第一個是獲取channel的時候必須加上鎖。另外一個是當獲取到channel後,利用channel的屬性,建立一個Map,後面需要利用這個Map建立訊息序列號和callback程式的對應關係。

//初始化channel屬性
        Attribute<Map<Integer,Object>> attribute = channel.attr(ChannelUtils.DATA_MAP_ATTRIBUTEKEY);
        ConcurrentHashMap<Integer, Object> dataMap = new ConcurrentHashMap<>();
        attribute.set(dataMap);

這個map就是我們上面看到的
這裡寫圖片描述

Map的put的動作,就是在SocketClient類中的

ChannelUtils.putCallback2DataMap(channel,seq,callbackService);

執行的。客戶端處理訊息還需要兩個hanlder輔助,一個是處理半包問題,一個是接收服務端的返回的訊息。

SelfDefineEncodeHandler類用於處理半包訊息

package nettyinaction.nettyclient.channelpool;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.util.List;

public class SelfDefineEncodeHandler extends ByteToMessageDecoder {
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf bufferIn, List<Object> out) throws Exception {
        if (bufferIn.readableBytes() < 4) {
            return;
        }

        int beginIndex = bufferIn.readerIndex();
        int length = bufferIn.readInt();

        if (bufferIn.readableBytes() < length) {
            bufferIn.readerIndex(beginIndex);
            return;
        }

        bufferIn.readerIndex(beginIndex + 4 + length);

        ByteBuf otherByteBufRef = bufferIn.slice(beginIndex, 4 + length);

        otherByteBufRef.retain();

        out.add(otherByteBufRef);
    }
}

SocketClientHandler類用於接收服務端返回的訊息,並且根據訊息序列號獲取對應的callback程式

package nettyinaction.nettyclient.channelpool.client;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import nettyinaction.nettyclient.channelpool.ChannelUtils;

public class SocketClientHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Channel channel = ctx.channel();

        ByteBuf responseBuf = (ByteBuf)msg;
        responseBuf.markReaderIndex();

        int length = responseBuf.readInt();
        int seq = responseBuf.readInt();

        responseBuf.resetReaderIndex();

        //獲取訊息對應的callback
        SocketClient.CallbackService callbackService = ChannelUtils.<SocketClient.CallbackService>removeCallback(channel, seq);
        callbackService.receiveMessage(responseBuf);
    }
}

到此客戶端程式編寫完畢。至於服務端的程式碼,則比較簡單,這裡直接貼上程式碼。

package nettyinaction.nettyclient.channelpool.server;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import nettyinaction.nettyclient.channelpool.SelfDefineEncodeHandler;

public class SocketServer {
    public static void main(String[] args) throws InterruptedException {
        EventLoopGroup parentGroup = new NioEventLoopGroup();
        EventLoopGroup childGroup = new NioEventLoopGroup();

        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.group(parentGroup, childGroup)
                           .channel(NioServerSocketChannel.class)
                           .handler(new LoggingHandler(LogLevel.INFO))
                           .childHandler(new ChannelInitializer<SocketChannel>() {
                                @Override
                                protected void initChannel(SocketChannel ch) throws Exception {
                                    ChannelPipeline pipeline = ch.pipeline();
                                    pipeline.addLast(new SelfDefineEncodeHandler());
                                    pipeline.addLast(new BusinessServerHandler());
                                }
                           });

            ChannelFuture channelFuture = serverBootstrap.bind(8899).sync();
            channelFuture.channel().closeFuture().sync();
        }
        finally {
            parentGroup.shutdownGracefully();
            childGroup.shutdownGracefully();
        }
    }
}

package nettyinaction.nettyclient.channelpool.server;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import nettyinaction.nettyclient.channelpool.ChannelUtils;

public class BusinessServerHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Channel channel = ctx.channel();
        ByteBuf buf = (ByteBuf)msg;
        //1、讀取訊息長度
        int length = buf.readInt();

        //2、讀取訊息序列號
        int seq = buf.readInt();

        //3、讀取訊息頭部
        byte[] head = new byte[8];
        buf.readBytes(head);
        String headString = new String(head);

        //4、讀取訊息體
        byte[] body = new byte[4];
        buf.readBytes(body);
        String bodyString = new String(body);

        //5、新建立一個快取區,寫入內容,返回給客戶端
        UnpooledByteBufAllocator allocator = new UnpooledByteBufAllocator(false);
        ByteBuf responseBuf = allocator.buffer(20);
        responseBuf.writeInt(ChannelUtils.MESSAGE_LENGTH);
        responseBuf.writeInt(seq);
        responseBuf.writeBytes(headString.getBytes());
        responseBuf.writeBytes(bodyString.getBytes());

        //6、將資料寫回到客戶端
        channel.writeAndFlush(responseBuf);
    }
}

執行服務端程式碼和客戶端程式碼,期望的結果是

10個執行緒傳送訊息後,能從服務端獲取到正確的對應的返回資訊,這些資訊不會發生錯亂,各個執行緒都能拿到自己想要的結果,不會發生錯讀的情況。

執行後的結果如下

Thread-3=9Thread-3body,
Thread-4=8Thread-4body,
Thread-5=5Thread-5body,
Thread-6=1Thread-6body,
Thread-7=3Thread-7body,
Thread-8=10Thread-8body,
Thread-9=4Thread-9body,
Thread-0=7Thread-0body,
Thread-1=6Thread-1body,
Thread-2=2Thread-2body

通過觀察結果,可以知道10個執行緒併發獲取channel後,部分執行緒共享一個channel,但是10個執行緒能仍然能正確獲取到結果。

程式碼細節解析

1、等待服務端的返回

由於 channel.writeAndFlush是非同步的,必須有一種機制來讓執行緒等待服務端返回結果。這裡採用最原始的wait和notify方法。當writeAndFlush呼叫後,立刻讓當前執行緒wait住,放置在callbackservice物件的等待列表中,當伺服器端返回訊息時,客戶端的SocketClientHandler類中的channelRead方法會被執行,解析完資料後,從channel的attr屬性中獲取DATA_MAP_ATTRIBUTEKEY 這個key對應的map。並根據解析出來的seq從map中獲取事先放置好的callbackservice物件,執行它的receiveMessage方法。將receiveBuf這個存放結果的快取區物件賦值到callbackservice的result屬性中。並呼叫callbackservice物件的notify方法,喚醒wait在callbackservice物件的執行緒,讓其繼續往下執行。

2、產生訊息序列號

                        int seq = IntegerFactory.getInstance().incrementAndGet();

為了演示的方便,這裡是產生單伺服器全域性唯一的序列號。如果請求量大的話,就算是AtomicInteger是CAS操作,也會產生很多的競爭。建議產生channel級別的唯一序列號,降低競爭。只要保證在一個channel內的訊息的序列號是不重複的即可。

至於其他的一些程式碼細節,讀者可以自己再細看。