程式人生 > Spring Boot 環境下 WebSocket 開發簡單示例

Spring Boot 環境下 WebSocket 開發簡單示例

示例如下。需要注意以下兩點:

  1. 端點類上需要加上 @Scope("prototype") 註解,否則 socket 端點就是單例的(所有連線共用同一個實例)。
  2. @ServerEndpoint 的 configurator 屬性所指定的配置類需要實現 ApplicationContextAware;按照下方程式碼完成配置後,server 中才能成功注入 Spring 管理的物件。

 

package com.xiaogang.websocketdemo.web.socket;

import com.alibaba.fastjson.JSON;
import com.xiaogang.websocketdemo.config.ClickServerEndpointConf;
import com.xiaogang.websocketdemo.dto.ApiResult;
import com.xiaogang.websocketdemo.dto.UserActionCount;
import com.xiaogang.websocketdemo.runnable.SendAllUserClickMsg;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Scope;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * @author xiaogang
 * @date 2018/11/30 11:12
 */
@Slf4j
@Component
@Scope("prototype")
@ServerEndpoint(value = "/socket/order/msg/{userName}",configurator = ClickServerEndpointConf.class)
@Order(11111)
public class ClickServer {

    @Autowired
    private volatile ThreadPoolExecutor threadPoolExecutor;

    public ClickServer() {
        System.out.println("ClickServer.ClickServer");
    }

    /**
     * 和某個客戶端會話的唯一關聯
     */
    private volatile Session session;

    /**
     * 整個ClickServer的全域性物件
     */
//    private ServerEndpointConfig serverEndpointConfig;

    public static Set<Session> sessions = new HashSet<>();

    private String userName;

    @OnOpen
    public void onOpen(Session session,@PathParam("userName")String userName) throws IOException {
        ApiResult result = new ApiResult();
        this.session = session;
        this.userName = userName;
        log.info("使用者:{}連結成功",userName);
        try {
            validUserUnique(userName);
        } catch (Exception e) {
            e.printStackTrace();
            result.setCode(0);
            result.setMsg(e.getMessage());
            String jsonString = JSON.toJSONString(result);
            this.session.getBasicRemote().sendText(jsonString);
            this.session.close();
            return;
        }
//        this.serverEndpointConfig = (ServerEndpointConfig) endpointConfig;
        Map<String, Object> userProperties = this.session.getUserProperties();
        UserActionCount userActionCount = new UserActionCount();
        userActionCount.setUserName(userName);
        userActionCount.setActionCount(0);
        userProperties.put("UserActionCount", userActionCount);
        sessions.add(session);
    }

    private void validUserUnique(String userName) {
        Assert.isTrue(!StringUtils.isEmpty(userName),"使用者名稱不能為空");
        Assert.isTrue(!(userName.length() > 6),"使用者名稱長度不能超過6個字元");
        for (Session openSession : sessions) {
            UserActionCount userActionCount = (UserActionCount) openSession.getUserProperties().get("UserActionCount");
            String existuserName = userActionCount.getUserName();
            Assert.isTrue(!existuserName.equals(userName),"使用者已經報名,無法重複報名");
        }
    }

    @OnMessage
    public void onMessage(String msg) throws IOException {
        log.info("接收到訊息:{}",msg);
        if (msg.equals("restart")) {
            for (Session session : sessions) {
                Map<String, Object> userProperties = session.getUserProperties();
                UserActionCount userActionCount = (UserActionCount) userProperties.get("UserActionCount");
                userActionCount.setActionCount(0);
                userProperties.put("UserActionCount",userActionCount);
                ApiResult result = new ApiResult();
                result.setCode(1);
                result.setMsg("ok");
                result.setData(new ArrayList());
                String jsonString = JSON.toJSONString(result);
                session.getBasicRemote().sendText(jsonString);
            }
            for (Session session : sessions) {
                session.close();
            }
        }else{
            addCount();
            sendAllUserClickData();
        }
    }

    @OnClose
    public void onClose() throws IOException {
        if (this.session == null) {
            sendAllUserClickData();
            return;
        }
        boolean open = this.session.isOpen();
        if (open) {
            this.session.close();
        }
        sessions.remove(this.session);
        sendAllUserClickData();
    }

    @OnError
    public void onError(Throwable throwable) throws IOException {
        throwable.printStackTrace();
        boolean open = this.session.isOpen();
        if (open) {
            this.session.close();
        }
    }

    private void sendAllUserClickData() throws IOException {
        log.info("給所有使用者傳送所有使用者的點選次數");
        Set<UserActionCount> userActionCounts = obtainAllUserClickData();
        SendAllUserClickMsg sendAllUserClickMsg = new SendAllUserClickMsg(sessions, userActionCounts);
        threadPoolExecutor.execute(sendAllUserClickMsg);
    }


    private void addCount() {
        log.info("為使用者:{}新增一次點選次數",userName);
        UserActionCount userActionCount = (UserActionCount) this.session.getUserProperties().get("UserActionCount");
        int actionCount = userActionCount.getActionCount();
        userActionCount.setActionCount(actionCount + 1);
        this.session.getUserProperties().put("UserActionCount",userActionCount);
        log.info("點選次數新增完成");
    }

    private Set<UserActionCount> obtainAllUserClickData() throws IOException {
        log.info("獲取所有使用者的點選資料");
        Set<UserActionCount> allUserClickData = new HashSet();

        for (Session openSession : sessions) {
            if (!openSession.isOpen()) {
                openSession.close();
                sessions.remove(openSession);
            }
            UserActionCount userActionCount = (UserActionCount) openSession.getUserProperties().get("UserActionCount");
            allUserClickData.add(userActionCount);
        }
        log.info("所有使用者的點選次數獲取完成:{}",allUserClickData);
        return allUserClickData;
    }

}


 

 

 

package com.xiaogang.websocketdemo.config;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Configuration;

import javax.websocket.server.ServerEndpointConfig.Configurator;

/**
 * WebSocket {@link Configurator} that obtains endpoint instances from the Spring
 * {@link ApplicationContext} instead of letting the container instantiate them directly.
 * This is what allows {@code @Autowired} fields on the endpoint to be populated. Combined
 * with {@code @Scope("prototype")} on the endpoint, each lookup yields a new bean.
 *
 * <p>The context is held in a static field because the WebSocket container may create its
 * own instance of this class (it is referenced by class in
 * {@code @ServerEndpoint(configurator = ...)}). That instance is separate from the Spring
 * bean that receives {@link #setApplicationContext}.
 *
 * @author xiaogang
 * @date 2018/11/30 16:03
 */
@Configuration
public class ClickServerEndpointConf extends Configurator implements ApplicationContextAware {

    private static volatile ApplicationContext applicationContext;

    /**
     * Returns the Spring bean of type {@code clazz} to serve as the endpoint instance.
     *
     * @param clazz the endpoint class requested by the container
     * @return the bean obtained from the application context
     * @throws InstantiationException if the application context has not been injected yet, or
     *                                if the bean cannot be obtained (original cause attached)
     */
    @Override
    public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
        // Read the volatile field once so the null check and the lookup see the same value.
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new InstantiationException(
                    "ApplicationContext not yet injected; cannot create endpoint " + clazz.getName());
        }
        try {
            return context.getBean(clazz);
        } catch (BeansException e) {
            InstantiationException wrapped = new InstantiationException(
                    "Failed to obtain endpoint bean " + clazz.getName() + " from Spring context");
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Stores the Spring context in the static field shared by every instance of this class.
     *
     * @param applicationContext the context supplied by Spring
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ClickServerEndpointConf.applicationContext = applicationContext;
    }
}