使用 Netty 实现 WebSocket


public class WebSocketClientCache {

    /** Live WebSocket connections, keyed by channel id ({@code ctx.channel().id().asLongText()}). */
    private static final ConcurrentHashMap<String, WebSocketClientNode> channelsMap = new ConcurrentHashMap<>();

    /**
     * Channels that were sent a heartbeat ping and have not answered yet.
     * The only value ever stored is {@link #PING_SENT}.
     */
    private static final ConcurrentHashMap<String, Integer> pingPongChannelsMap = new ConcurrentHashMap<>();

    /** Tracked tag id registered by each channel (channel id -> tag id). */
    private static final ConcurrentHashMap<String, Long> registTrackIdMap = new ConcurrentHashMap<String, Long>();

    /** Marker stored in {@link #pingPongChannelsMap}: ping sent, pong still pending. */
    private static final int PING_SENT = 1;

    /** Heartbeat tick period in seconds; ticks alternate between "send ping" and "evict silent channels". */
    private final static int SCHEDULE_SECONDS = 180;
    private static final ScheduledExecutorService scheduleService = Executors.newScheduledThreadPool(1);

    /** Phase of the next heartbeat tick: true = send pings, false = evict channels that did not answer. */
    private static volatile boolean isSent = true;

    static {
        scheduleService.scheduleAtFixedRate(new Runnable() {
            @Override
            public void run() {
                LogUtils.logDebug(WebSocketClientCache.class, "WebSocketClientCache scheduleWithFixedDelay starting  ...");
                // scheduleAtFixedRate suppresses every later run once a run throws, so an
                // escaping exception would stop the heartbeat for good. Log and keep going.
                try {
                    if (isSent) {
                        isSent = false;
                        sendPingMessageToAll();
                    } else {
                        isSent = true;
                        clearNotPingPongMessage();
                    }
                } catch (Exception e) {
                    LogUtils.logError(WebSocketClientCache.class, e);
                }
            }
        }, 1L, SCHEDULE_SECONDS, TimeUnit.SECONDS);
    }


    /**
     * Registers a connection unless one is already cached for {@code channelId}.
     * A fresh registration also clears any pending-ping mark for that id.
     */
    public static void putChannelHandlerContext(String channelId, ChannelHandlerContext channelHandlerContext,
                                                WebSocketServerHandshaker handshaker) {
        try {
            WebSocketClientNode webSocketClientNode = new WebSocketClientNode();
            webSocketClientNode.handshaker = handshaker;
            webSocketClientNode.ctx = channelHandlerContext;
            // putIfAbsent makes "first registration wins" atomic (containsKey + put was racy).
            if (channelsMap.putIfAbsent(channelId, webSocketClientNode) == null) {
                pingPongChannelsMap.remove(channelId);
            }
        } catch (Exception e) {
            LogUtils.logError(WebSocketClientCache.class, e);
        }
    }

    /** Returns the cached context for {@code channelId}, or null when none is cached. */
    public static ChannelHandlerContext getChannelHandlerContext(String channelId) {
        WebSocketClientNode webSocketClientNode = channelsMap.get(channelId);
        return webSocketClientNode == null ? null : webSocketClientNode.ctx;
    }


    /**
     * Removes a connection from every map. If the channel is still open, this sends a
     * WebSocket close frame (when a handshaker is known) and closes the channel.
     */
    public static void removeChannelHandlerContext(String channelId) {
        WebSocketClientNode webSocketClientNode = channelsMap.remove(channelId);
        Channel channel = channelOf(webSocketClientNode);
        if (channel != null && channel.isOpen()) {
            // the handshaker may be null if the node was registered without one
            if (webSocketClientNode.handshaker != null) {
                webSocketClientNode.handshaker.close(channel, new CloseWebSocketFrame());
            }
            channel.close();
        }
        registTrackIdMap.remove(channelId);
        pingPongChannelsMap.remove(channelId);
    }


    /**
     * Returns a snapshot of all cached contexts.
     *
     * @return the contexts, or null when there are none (kept for existing callers)
     */
    public static List<ChannelHandlerContext> getChannelHandlerContextList() {
        List<ChannelHandlerContext> list = new LinkedList<>();
        // Iterate values() directly: an entry may disappear between keySet() and get().
        for (WebSocketClientNode webSocketClientNode : channelsMap.values()) {
            if (webSocketClientNode.ctx != null) {
                list.add(webSocketClientNode.ctx);
            }
        }
        return list.isEmpty() ? null : list;
    }


    /**
     * Sends a ping frame to every open connection and marks every cached connection as
     * "ping sent". A client that answers with a pong is unmarked via {@link #getPongMessage}.
     */
    public static void sendPingMessageToAll() {
        for (String key : channelsMap.keySet()) {
            WebSocketClientNode webSocketClientNode = channelsMap.get(key);
            if (webSocketClientNode == null) {
                continue; // removed concurrently
            }
            Channel channel = channelOf(webSocketClientNode);
            if (channel != null && channel.isOpen()) {
                channel.writeAndFlush(new PingWebSocketFrame());
            }
            pingPongChannelsMap.put(key, PING_SENT);
        }
    }


    /** Clears the pending-ping mark of {@code channelId}; a null id is ignored. */
    public static void getPongMessage(String channelId) {
        if (channelId == null) {
            return;
        }
        pingPongChannelsMap.remove(channelId);
    }


    /**
     * Evicts every connection that still has a pending ping.
     * {@link #removeChannelHandlerContext} sends the close frame and closes the channel;
     * closing here as well would send a second close frame.
     */
    public static void clearNotPingPongMessage() {
        for (String key : pingPongChannelsMap.keySet()) {
            // re-check: the pong may have arrived while iterating
            if (pingPongChannelsMap.get(key) != null) {
                removeChannelHandlerContext(key);
            }
        }
    }

    /** Returns the channel of a cache node, or null when the node or its context is missing. */
    private static Channel channelOf(WebSocketClientNode node) {
        return (node == null || node.ctx == null) ? null : node.ctx.channel();
    }


    /** Cached state of one connection. */
    private static class WebSocketClientNode {
        WebSocketServerHandshaker handshaker;
        ChannelHandlerContext ctx;
    }


    /**
     * Pushes a text message to every cached connection.
     * <p>
     * Closed channels are fully removed from the cache. Open channels that are temporarily
     * not writable skip this message but stay registered, so they keep receiving later
     * messages and remain subject to the heartbeat.
     */
    public static void sendMessageToAll(String message) {
        for (String key : channelsMap.keySet()) {
            WebSocketClientNode webSocketClientNode = channelsMap.get(key);
            if (webSocketClientNode == null) {
                continue; // removed concurrently
            }
            Channel channel = channelOf(webSocketClientNode);
            if (channel == null || !channel.isOpen()) {
                removeChannelHandlerContext(key);
            } else if (channel.isWritable()) {
                // build the frame only when it will actually be written, so it is never leaked
                channel.writeAndFlush(new TextWebSocketFrame(message));
            }
        }
    }


    /** Associates tag id {@code id} with channel id {@code key}. */
    public static void registTrackTagId(String key, Long id) {
        registTrackIdMap.put(key, id);
    }

    /** Removes the tag id associated with channel id {@code key}. */
    public static void removeTrackTagId(String key) {
        registTrackIdMap.remove(key);
    }

    /** Returns whether any connection has registered {@code tagId}. */
    public static boolean isRegistTrackTagId(Long tagId) {
        return registTrackIdMap.containsValue(tagId);
    }


}

 

package com.hikvision.energy.websocket.server;

import com.hikvision.energy.core.util.log.LogUtils;
import com.hikvision.energy.websocket.handler.WebSocketServerHandler;
import com.hikvision.energy.websocket.handler.WebSocketServerHandlerFactory;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Netty-based WebSocket server.
 * <p>
 * {@link #startNettyServer()} binds {@link #port} on a dedicated background thread. That thread
 * stays blocked until the server channel closes, then shuts down both event loop groups.
 * <p>
 * Created by zhuangjiesen on 2017/8/8.
 */
public class WebSocketNettyServer {
    /**
     * TCP port the server binds to.
     **/
    private int port;
    @Autowired
    private WebSocketServerHandlerFactory webSocketServerHandlerFactory;

    /** Single background thread that runs the blocking server lifecycle. */
    private ExecutorService executorService = Executors.newSingleThreadExecutor();

    public WebSocketNettyServer() {
        super();
    }


    /**
     * Starts the server asynchronously on the single-thread executor and returns immediately.
     * Bind and runtime failures happen on the background thread, so they are logged, not thrown.
     */
    public void startNettyServer() {
        executorService.execute(new Runnable() {
            @Override
            public void run() {
                // A single server channel only ever needs one acceptor thread.
                EventLoopGroup boss = new NioEventLoopGroup(1);
                EventLoopGroup worker = new NioEventLoopGroup();

                try {
                    ServerBootstrap bootstrap = new ServerBootstrap();

                    bootstrap.group(boss, worker);
                    bootstrap.channel(NioServerSocketChannel.class);
                    // Length of the accept queue for pending connections (not a cap on open connections).
                    bootstrap.option(ChannelOption.SO_BACKLOG, 1024);
                    // TCP_NODELAY and SO_KEEPALIVE are per-connection options, so they belong on the
                    // accepted child channels. As option() on the server socket, TCP_NODELAY was ignored.
                    bootstrap.childOption(ChannelOption.TCP_NODELAY, true);
                    bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true);
                    bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel)
                                throws Exception {
                            ChannelPipeline p = socketChannel.pipeline();
                            p.addLast("http-codec", new HttpServerCodec());
                            p.addLast("aggregator", new HttpObjectAggregator(65536));
                            p.addLast("http-chunked", new ChunkedWriteHandler());
                            p.addLast("handler", webSocketServerHandlerFactory.newWebSocketServerHandler());
                            p.addLast("outbound-handler", webSocketServerHandlerFactory.newWebSocketServerOutboundHandler());
                        }
                    });

                    ChannelFuture f = bootstrap.bind(port).sync();
                    if (f.isSuccess()) {
                        LogUtils.logDebug(WebSocketNettyServer.class, " WebSocketNettyServer start successfully ....");
                    }
                    // Keep this background thread parked until the server channel is closed.
                    f.channel().closeFuture().sync();
                    LogUtils.logDebug(WebSocketNettyServer.class, " WebSocketNettyServer closed ....");
                } catch (InterruptedException e) {
                    // restore the interrupt flag so the executor can observe it
                    Thread.currentThread().interrupt();
                    LogUtils.logError(WebSocketNettyServer.class, e);
                } catch (Exception e) {
                    LogUtils.logError(WebSocketNettyServer.class, e);
                } finally {
                    LogUtils.logError(WebSocketNettyServer.class, " WebSocketNettyServer shutdownGracefully ....");
                    boss.shutdownGracefully();
                    worker.shutdownGracefully();
                }
            }
        });
    }


    public int getPort() {
        return port;
    }

    public void setPort(int port) {
        this.port = port;
    }

    public WebSocketServerHandlerFactory getWebSocketServerHandlerFactory() {
        return webSocketServerHandlerFactory;
    }

    public void setWebSocketServerHandlerFactory(WebSocketServerHandlerFactory webSocketServerHandlerFactory) {
        this.webSocketServerHandlerFactory = webSocketServerHandlerFactory;
    }
}
public class WebSocketServerOutboundHandler extends ChannelOutboundHandlerAdapter {

    /**
     * Intercepts every close request that travels through this handler. It drops the channel
     * from {@link WebSocketClientCache}, then forwards the close to the rest of the pipeline.
     */
    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        WebSocketClientCache.removeChannelHandlerContext(ctx.channel().id().asLongText());
        super.close(ctx, promise);
    }
}
public class WebSocketServerHandlerFactory {

    /** Message handler handed to every inbound handler this factory builds. */
    private WebSocketMessageHandler webSocketMessageHandler;

    /**
     * Builds a new inbound handler wired to the configured message handler.
     * A fresh instance is returned on every call.
     */
    public WebSocketServerHandler newWebSocketServerHandler() {
        final WebSocketServerHandler inbound = new WebSocketServerHandler();
        inbound.setWebSocketMessageHandler(this.webSocketMessageHandler);
        return inbound;
    }

    /** Builds a new outbound handler; a fresh instance is returned on every call. */
    public WebSocketServerOutboundHandler newWebSocketServerOutboundHandler() {
        return new WebSocketServerOutboundHandler();
    }

    public WebSocketMessageHandler getWebSocketMessageHandler() {
        return this.webSocketMessageHandler;
    }

    public void setWebSocketMessageHandler(WebSocketMessageHandler webSocketMessageHandler) {
        this.webSocketMessageHandler = webSocketMessageHandler;
    }
}
package com.hikvision.energy.websocket.handler;


import com.alibaba.fastjson.JSON;
import com.hikvision.energy.core.util.log.LogUtils;
import com.hikvision.energy.util.CfgMgr;
import com.hikvision.energy.websocket.cache.WebSocketClientCache;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;

import java.util.concurrent.ConcurrentHashMap;

public class WebSocketServerHandler extends ChannelInboundHandlerAdapter {

    public final static String WEBSOCKET = "websocket";

    public final static String Upgrade = "Upgrade";

    public final static String WEBSOCKET_ADDRESS_FORMAT = "ws://${host}/web/websocket";

    /** Handshaker for this channel; set once the HTTP upgrade request has been answered. */
    private WebSocketServerHandshaker webSocketServerHandshaker;

    private WebSocketMessageHandler webSocketMessageHandler;

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        // Nothing to do until the HTTP upgrade request arrives.
    }

    /**
     * Dispatches an HTTP upgrade request or a WebSocket frame.
     * <p>
     * Messages are not propagated further down the pipeline, so this handler owns them
     * and releases them once handling is done. Close frames are WebSocketFrames too and
     * are handled by {@link #handlerWebSocketFrame}.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        try {
            if (msg instanceof FullHttpRequest) {
                handleHttpRequest(ctx, (FullHttpRequest) msg);
            } else if (msg instanceof WebSocketFrame) {
                handlerWebSocketFrame(ctx, (WebSocketFrame) msg);
            } else {
                throw new RuntimeException("无法处理的请求");
            }
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        // Intentionally empty: writes in this handler flush explicitly.
    }

    /**
     * Clears the cache entry when the connection goes inactive.
     * A peer that drops the TCP connection may not trigger an outbound close through the
     * pipeline; without this hook, cleanup waited for the heartbeat eviction.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        WebSocketClientCache.removeChannelHandlerContext(ctx.channel().id().asLongText());
        super.channelInactive(ctx);
    }

    /**
     * Logs the failure and closes the connection instead of silently swallowing it.
     * The close goes through the channel (not ctx) so it traverses the whole pipeline,
     * including the outbound handler that clears the cache entry.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LogUtils.logError(WebSocketServerHandler.class , "WebSocketServerHandler exceptionCaught ..."  , " exception : " , cause );
        ctx.channel().close();
    }


    /**
     * Handles the HTTP upgrade request.
     * Undecodable requests and requests without an {@code Upgrade: websocket} header
     * get a 400 response and the connection is closed. Otherwise the handshake is
     * performed and the channel is registered in {@link WebSocketClientCache}.
     */
    public void handleHttpRequest(ChannelHandlerContext ctx,
                                  FullHttpRequest req) {
        String headVal = req.headers().get(Upgrade);
        // equalsIgnoreCase is locale-independent and treats a missing header as "not websocket".
        if (!req.decoderResult().isSuccess() || !WEBSOCKET.equalsIgnoreCase(headVal)) {
            sendBadRequest(ctx);
            return;
        }

        String host = req.headers().get("Host");
        if (host == null) {
            host = CfgMgr.getInstallIp();
        }

        String webAddress = WEBSOCKET_ADDRESS_FORMAT.replace("${host}", host);
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                webAddress, null, false);
        // The factory picks a handshaker matching the request's WebSocket version.
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
        if (handshaker == null) {
            WebSocketServerHandshakerFactory
                    .sendUnsupportedVersionResponse(ctx.channel());
        } else {
            handshaker.handshake(ctx.channel(), req);
            webSocketServerHandshaker = handshaker;
            WebSocketClientCache.putChannelHandlerContext(ctx.channel().id().asLongText(), ctx, handshaker);
        }
    }

    /** Writes a 400 response whose body is the status line, then closes the connection. */
    private static void sendBadRequest(ChannelHandlerContext ctx) {
        HttpResponseStatus status = HttpResponseStatus.BAD_REQUEST;
        ByteBuf body = Unpooled.copiedBuffer(status.toString(), CharsetUtil.UTF_8);
        DefaultFullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, body);
        ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Handles one WebSocket frame.
     * <ul>
     * <li>Close: tears the connection down.</li>
     * <li>Ping: answered with a pong.</li>
     * <li>Pong: clears the heartbeat mark.</li>
     * <li>Text: passed to the message handler.</li>
     * <li>Anything else: rejected with UnsupportedOperationException.</li>
     * </ul>
     * The frame is released by {@link #channelRead} after this returns. A message handler
     * that keeps the frame beyond {@code onMessage} would need to retain() it.
     */
    public void handlerWebSocketFrame(ChannelHandlerContext ctx,
                                      WebSocketFrame frame) {
        if (frame == null) {
            return;
        }
        if (frame instanceof CloseWebSocketFrame) {
            LogUtils.logDebug(WebSocketServerHandler.class, "关闭连接....");
            onWebSocketFrameColsed(ctx, frame);
            return;
        }

        if (frame instanceof PingWebSocketFrame) {
            LogUtils.logDebug(WebSocketServerHandler.class, "i am ping message ....");
            // Flush right away; a plain write() left the pong sitting in the outbound buffer.
            ctx.channel().writeAndFlush(
                    new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof PongWebSocketFrame) {
            LogUtils.logDebug(WebSocketServerHandler.class, "i am pong message ....");
            // A pong only answers our heartbeat ping. Replying with another ping (as before)
            // starts an endless ping/pong exchange with clients that answer every ping.
            WebSocketClientCache.getPongMessage(ctx.channel().id().asLongText());
            return;
        }
        LogUtils.logDebug(WebSocketServerHandler.class, "frame class : " + frame.getClass().getName());

        // Only text frames are supported; binary and other frames are rejected.
        if (!(frame instanceof TextWebSocketFrame)) {
            LogUtils.logDebug(WebSocketServerHandler.class, "not support other frame data : ", JSON.toJSONString(frame));
            throw new UnsupportedOperationException(String.format(
                    "%s frame types not supported", frame.getClass().getName()));
        }

        String channelId = ctx.channel().id().asLongText();
        // Re-register on every message so the channel is cached even if it was dropped earlier.
        WebSocketClientCache.putChannelHandlerContext(channelId, ctx, webSocketServerHandshaker);
        // Any client message also proves liveness for the heartbeat.
        WebSocketClientCache.getPongMessage(channelId);
        webSocketMessageHandler.onMessage((TextWebSocketFrame) frame, ctx);
    }


    /**
     * Handles a close frame from the client. It removes the channel from the cache, then
     * echoes the client's close frame through the handshaker.
     */
    public void onWebSocketFrameColsed(ChannelHandlerContext ctx,
                                       WebSocketFrame frame) {
        if (webSocketServerHandshaker != null) {
            WebSocketClientCache.removeChannelHandlerContext(ctx.channel().id().asLongText());
            // retain(): channelRead releases the incoming frame after this returns
            webSocketServerHandshaker.close(ctx.channel(), (CloseWebSocketFrame) frame
                    .retain());
        }
    }


    public WebSocketMessageHandler getWebSocketMessageHandler() {
        return webSocketMessageHandler;
    }

    public void setWebSocketMessageHandler(WebSocketMessageHandler webSocketMessageHandler) {
        this.webSocketMessageHandler = webSocketMessageHandler;
    }
}
package com.hikvision.energy.websocket.handler.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.hikvision.energy.core.util.log.LogUtils;
import com.hikvision.energy.websocket.cache.WebSocketClientCache;
import com.hikvision.energy.websocket.handler.WebSocketMessageHandler;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;

/**
 * Interprets text frames of the form {@code <name>=<trackTagId>} and records which tag id
 * the sending channel wants to track in {@link WebSocketClientCache}.
 * <p>
 * Created by zhuangjiesen on 2017/8/9.
 */

public class TextFrameWebSocketMessageHandler implements WebSocketMessageHandler {

    /**
     * Registers or clears the tracked tag id of the sending channel.
     * <p>
     * Blank messages are ignored. Otherwise the message is split on '='. Note that
     * {@link String#split(String)} drops trailing empty parts.
     * <ul>
     * <li>Exactly two parts, second part a decimal number that fits in a long: that id is
     *     registered for the channel.</li>
     * <li>Exactly two parts, second part not such a number: the registration is unchanged.</li>
     * <li>Any other part count: the channel's registration is removed.</li>
     * </ul>
     * Parsing failures are logged, not thrown.
     *
     * @param webSocketFrame expected to be a {@link TextWebSocketFrame}
     * @param ctx            context of the sending channel
     */
    @Override
    public void onMessage(WebSocketFrame webSocketFrame, ChannelHandlerContext ctx) {
        TextWebSocketFrame textWebSocketFrame = (TextWebSocketFrame) webSocketFrame;
        String message = textWebSocketFrame.text();
        if (StringUtils.isNotBlank(message)) {
            try {
                String channelId = ctx.channel().id().asLongText();

                String[] splitArr = message.split("=");
                if (splitArr.length == 2) {
                    String trackTagIdStr = splitArr[1];
                    if (NumberUtils.isDigits(trackTagIdStr)) {
                        // Long.parseLong throws on values beyond Long.MAX_VALUE (caught and logged
                        // below). NumberUtils.toLong silently returned 0 there and registered tag id 0.
                        WebSocketClientCache.registTrackTagId(channelId, Long.parseLong(trackTagIdStr));
                    }
                } else {
                    WebSocketClientCache.removeTrackTagId(channelId);
                }
            } catch (Exception e) {
                LogUtils.logError(this, e, "analyze message error", message);
            }
        }
        LogUtils.logDebug(TextFrameWebSocketMessageHandler.class , "textWebSocketFrame : " + message);
    }
}

 

posted @ 2018-05-22 20:29  灬Silence灬  阅读(271)  评论(0)  编辑  收藏  举报