Netty(4) 使用Netty做心跳检测

230 阅读3分钟

[TOC]

背景

使用 socket 通信时,经常需要在多台服务器之间做心跳检测。一般来讲,维护一个服务器集群时,会有一台(或多台)主机(Master)和 N 台从机(Slave),主机需要时刻掌握其下各台服务器的运行状况(如 CPU、内存等),从而实现实时监控。

说明

  1. Server 接收 Client 的连接,并对Client进行认证
  2. Client 认证成功后,定时(每 2 秒)向 Server 发送心跳消息(携带本机的 CPU、内存使用情况);由于连接上持续有数据往来,TCP 连接就不会因空闲而断开

1. Server.java

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;

public class Server {
	public static void main(String[] args) throws Exception{		
		EventLoopGroup pGroup = new NioEventLoopGroup();
		EventLoopGroup cGroup = new NioEventLoopGroup();
		
		ServerBootstrap b = new ServerBootstrap();
		b.group(pGroup, cGroup)
		 .channel(NioServerSocketChannel.class)
		 .option(ChannelOption.SO_BACKLOG, 1024)
		 //设置日志
		 .handler(new LoggingHandler(LogLevel.INFO))
		 .childHandler(new ChannelInitializer<SocketChannel>() {
			protected void initChannel(SocketChannel sc) throws Exception {
				sc.pipeline().addLast(MarshallingCodeCFactory.buildMarshallingDecoder());
				sc.pipeline().addLast(MarshallingCodeCFactory.buildMarshallingEncoder());
				sc.pipeline().addLast(new ServerHeartBeatHandler());
			}
		});
		
		ChannelFuture cf = b.bind(8765).sync();
		
		cf.channel().closeFuture().sync();
		pGroup.shutdownGracefully();
		cGroup.shutdownGracefully();
		
	}
}

2. ServerHeartBeatHandler.java

import java.util.HashMap;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

/**
 * Server-side handler: authenticates a client from an {@code "ip,key"} String message
 * and then prints the CPU / memory metrics the client reports in {@link RequestInfo}
 * heartbeats.
 *
 * <p>This handler is not {@code @Sharable}; Server creates one instance per accepted
 * channel, so the {@code authenticated} flag tracks a single connection.
 */
public class ServerHeartBeatHandler extends ChannelInboundHandlerAdapter {
    /** key: client ip, value: expected auth key. */
    private static final HashMap<String, String> AUTH_IP_MAP = new HashMap<String, String>();
    private static final String SUCCESS_KEY = "auth_success_key";

    static {
        AUTH_IP_MAP.put("192.168.192.1", "1234");
    }

    /** True once this connection has passed {@link #auth}; heartbeats are rejected until then. */
    private boolean authenticated;

    /**
     * Checks an {@code "ip,key"} message against {@link #AUTH_IP_MAP}.
     * On success replies with {@link #SUCCESS_KEY}; on failure (including a malformed
     * message) replies with an error and closes the channel.
     *
     * <p>NOTE(review): the ip is self-reported by the client and is not compared with the
     * channel's remote address, so anyone knowing a valid ip/key pair can authenticate.
     *
     * @param ctx the channel context
     * @param msg the auth message; must be a String
     * @return true when authentication succeeded
     */
    private boolean auth(ChannelHandlerContext ctx, Object msg) {
        String[] ret = ((String) msg).split(",");
        // A message without a comma would otherwise throw ArrayIndexOutOfBoundsException.
        String auth = ret.length >= 2 ? AUTH_IP_MAP.get(ret[0]) : null;
        if (auth != null && auth.equals(ret[1])) {
            authenticated = true;
            ctx.writeAndFlush(SUCCESS_KEY);
            return true;
        }
        ctx.writeAndFlush("auth failure !").addListener(ChannelFutureListener.CLOSE);
        return false;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof String) {
            auth(ctx, msg);
        } else if (msg instanceof RequestInfo && authenticated) {
            printInfo((RequestInfo) msg);
            ctx.writeAndFlush("info received!");
        } else {
            // Unknown message type, or a heartbeat sent before successful authentication.
            ctx.writeAndFlush("connect failure!").addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Prints the reported host metrics to stdout; missing maps print as null values. */
    private static void printInfo(RequestInfo info) {
        System.out.println("--------------------------------------------");
        System.out.println("当前主机ip为: " + info.getIp());
        System.out.println("当前主机cpu情况: ");
        HashMap<String, Object> cpu = orEmpty(info.getCpuPercMap());
        System.out.println("总使用率: " + cpu.get("combined"));
        System.out.println("用户使用率: " + cpu.get("user"));
        System.out.println("系统使用率: " + cpu.get("sys"));
        System.out.println("等待率: " + cpu.get("wait"));
        System.out.println("空闲率: " + cpu.get("idle"));

        System.out.println("当前主机memory情况: ");
        HashMap<String, Object> memory = orEmpty(info.getMemoryMap());
        System.out.println("内存总量: " + memory.get("total"));
        System.out.println("当前内存使用量: " + memory.get("used"));
        System.out.println("当前内存剩余量: " + memory.get("free"));
        System.out.println("--------------------------------------------");
    }

    /** Returns {@code map}, or an empty map when it is null (avoids NPE on partial heartbeats). */
    private static HashMap<String, Object> orEmpty(HashMap<String, Object> map) {
        return map != null ? map : new HashMap<String, Object>();
    }
}

3. Client.java

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;

/**
 * Heartbeat demo client: connects to the server at 127.0.0.1:8765 and lets
 * ClienHeartBeatHandler authenticate and send periodic heartbeats.
 */
public class Client {
	public static void main(String[] args) throws Exception {
		EventLoopGroup group = new NioEventLoopGroup();
		try {
			Bootstrap b = new Bootstrap();
			b.group(group)
			 .channel(NioSocketChannel.class)
			 .handler(new ChannelInitializer<SocketChannel>() {
				@Override
				protected void initChannel(SocketChannel sc) throws Exception {
					// Same codec pair as the server pipeline, followed by the client logic.
					sc.pipeline().addLast(MarshallingCodeCFactory.buildMarshallingDecoder());
					sc.pipeline().addLast(MarshallingCodeCFactory.buildMarshallingEncoder());
					sc.pipeline().addLast(new ClienHeartBeatHandler());
				}
			});

			ChannelFuture cf = b.connect("127.0.0.1", 8765).sync();
			// Block until the connection is closed (e.g. the server rejects authentication).
			cf.channel().closeFuture().sync();
		} finally {
			// Release event-loop threads even when connect() fails (e.g. server not running).
			group.shutdownGracefully();
		}
	}
}

4. ClienHeartBeatHandler.java

import java.net.InetAddress;
import java.util.HashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.hyperic.sigar.CpuPerc;
import org.hyperic.sigar.Mem;
import org.hyperic.sigar.Sigar;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

/**
 * Client-side handler: sends an {@code "ip,key"} auth message when the channel becomes
 * active and, once the server answers with {@link #SUCCESS_KEY}, periodically sends a
 * {@link RequestInfo} heartbeat with this host's CPU and memory metrics.
 */
public class ClienHeartBeatHandler extends ChannelInboundHandlerAdapter {
    /**
     * Runs the heartbeat task. Its thread is non-daemon, so it is shut down in
     * {@link #channelInactive} to avoid keeping the JVM alive after the connection ends.
     */
    private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);

    /** Handle of the scheduled heartbeat task, or null when none is running. */
    private ScheduledFuture<?> heartBeat;
    /** Local host address reported to the server; set in {@link #channelActive}. */
    private InetAddress addr;
    private static final String SUCCESS_KEY = "auth_success_key";
    /** Delay between heartbeats, in seconds. */
    private static final long HEARTBEAT_INTERVAL_SECONDS = 2;

	/** Actively sends the authentication message to the server. */
	@Override
	public void channelActive(ChannelHandlerContext ctx) throws Exception {
		addr = InetAddress.getLocalHost();
		String ip = addr.getHostAddress();
		System.err.println(ip);
		String key = "1234";
		// Auth message format checked by the server: "<ip>,<key>".
		ctx.writeAndFlush(ip + "," + key);
	}

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        try {
            if (msg instanceof String) {
                // Handshake succeeded: start heartbeats, at most once per connection.
                if (SUCCESS_KEY.equals(msg) && heartBeat == null) {
                    heartBeat = scheduler.scheduleWithFixedDelay(
                            new HeartBeatTask(ctx), 0, HEARTBEAT_INTERVAL_SECONDS, TimeUnit.SECONDS);
                }
                System.out.println(msg);
            }
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }

    /** Stops heartbeats and the scheduler thread once the connection is gone. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        cancelHeartBeat();
        scheduler.shutdownNow();
        ctx.fireChannelInactive();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        cancelHeartBeat();
        ctx.fireExceptionCaught(cause);
    }

    /** Cancels the running heartbeat task, if any. */
    private void cancelHeartBeat() {
        if (heartBeat != null) {
            heartBeat.cancel(true);
            heartBeat = null;
        }
    }

    /** Collects host metrics via Sigar and writes them to the channel as a RequestInfo. */
    private class HeartBeatTask implements Runnable {
        private final ChannelHandlerContext ctx;

        public HeartBeatTask(final ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }

        @Override
        public void run() {
            Sigar sigar = null;
            try {
                RequestInfo info = new RequestInfo();
                info.setIp(addr.getHostAddress());
                sigar = new Sigar();

                // CPU percentages
                CpuPerc cpuPerc = sigar.getCpuPerc();
                HashMap<String, Object> cpuPercMap = new HashMap<String, Object>();
                cpuPercMap.put("combined", cpuPerc.getCombined());
                cpuPercMap.put("user", cpuPerc.getUser());
                cpuPercMap.put("sys", cpuPerc.getSys());
                cpuPercMap.put("wait", cpuPerc.getWait());
                cpuPercMap.put("idle", cpuPerc.getIdle());

                // Memory, divided by 1024 (presumably bytes -> KiB; verify against Sigar docs)
                Mem mem = sigar.getMem();
                HashMap<String, Object> memoryMap = new HashMap<String, Object>();
                memoryMap.put("total", mem.getTotal() / 1024L);
                memoryMap.put("used", mem.getUsed() / 1024L);
                memoryMap.put("free", mem.getFree() / 1024L);

                info.setCpuPercMap(cpuPercMap);
                info.setMemoryMap(memoryMap);
                ctx.writeAndFlush(info);
            } catch (Exception e) {
                // Swallowed so a single failed sample does not cancel the fixed-delay schedule.
                e.printStackTrace();
            } finally {
                // Sigar holds native resources; release them after every sample.
                if (sigar != null) {
                    sigar.close();
                }
            }
        }
    }
}

5. RequestInfo.java

import java.io.Serializable;
import java.util.HashMap;

/**
 * Heartbeat payload sent from client to server via Java serialization:
 * the client's ip plus CPU and memory metric maps.
 */
public class RequestInfo implements Serializable {
    /**
     * Explicit UID so client and server builds compiled separately agree on the class
     * version; otherwise the compiler-computed default may differ and deserialization fails.
     */
    private static final long serialVersionUID = 1L;

    private String ip;
    /** CPU metrics keyed by "combined", "user", "sys", "wait", "idle" (as filled by the client). */
    private HashMap<String, Object> cpuPercMap;
    /** Memory metrics keyed by "total", "used", "free" (as filled by the client). */
    private HashMap<String, Object> memoryMap;
    // ... other fields

    public String getIp() {
        return ip;
    }

    public void setIp(String ip) {
        this.ip = ip;
    }

    public HashMap<String, Object> getCpuPercMap() {
        return cpuPercMap;
    }

    public void setCpuPercMap(HashMap<String, Object> cpuPercMap) {
        this.cpuPercMap = cpuPercMap;
    }

    public HashMap<String, Object> getMemoryMap() {
        return memoryMap;
    }

    public void setMemoryMap(HashMap<String, Object> memoryMap) {
        this.memoryMap = memoryMap;
    }
}

6. MarshallingCodeCFactory.java

import io.netty.handler.codec.marshalling.DefaultMarshallerProvider;
import io.netty.handler.codec.marshalling.DefaultUnmarshallerProvider;
import io.netty.handler.codec.marshalling.MarshallerProvider;
import io.netty.handler.codec.marshalling.MarshallingDecoder;
import io.netty.handler.codec.marshalling.MarshallingEncoder;
import io.netty.handler.codec.marshalling.UnmarshallerProvider;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;

/**
 * Marshalling factory: builds the Netty JBoss Marshalling encoder/decoder pair.
 * Both ends of a connection must use codecs built with the same provider and
 * configuration, which is why both builders share the private helpers below.
 */
public final class MarshallingCodeCFactory {
    /** Default max size, in bytes, of a single serialized message accepted by the decoder (1 MiB). */
    private static final int DEFAULT_MAX_OBJECT_SIZE = 1024 * 1024;

    /** JBoss Marshalling provider name; "serial" selects the Java-serialization-compatible factory. */
    private static final String SERIAL_PROVIDER = "serial";

    private MarshallingCodeCFactory() {
        // static utility; not instantiable
    }

    /**
     * Creates a JBoss Marshalling decoder with the default 1 MiB max object size.
     *
     * @return MarshallingDecoder
     */
    public static MarshallingDecoder buildMarshallingDecoder() {
        return buildMarshallingDecoder(DEFAULT_MAX_OBJECT_SIZE);
    }

    /**
     * Creates a JBoss Marshalling decoder.
     *
     * @param maxObjectSize maximum length, in bytes, of a single serialized message
     * @return MarshallingDecoder
     * @throws IllegalStateException if the "serial" marshaller provider is not on the classpath
     */
    public static MarshallingDecoder buildMarshallingDecoder(int maxObjectSize) {
        UnmarshallerProvider provider = new DefaultUnmarshallerProvider(marshallerFactory(), configuration());
        return new MarshallingDecoder(provider, maxObjectSize);
    }

    /**
     * Creates a JBoss Marshalling encoder, which serializes POJOs implementing
     * {@link java.io.Serializable} into bytes.
     *
     * @return MarshallingEncoder
     * @throws IllegalStateException if the "serial" marshaller provider is not on the classpath
     */
    public static MarshallingEncoder buildMarshallingEncoder() {
        MarshallerProvider provider = new DefaultMarshallerProvider(marshallerFactory(), configuration());
        return new MarshallingEncoder(provider);
    }

    /** Looks up the "serial" marshaller factory, failing fast with a clear message if it is missing. */
    private static MarshallerFactory marshallerFactory() {
        MarshallerFactory factory = Marshalling.getProvidedMarshallerFactory(SERIAL_PROVIDER);
        if (factory == null) {
            // The lookup yields null when no provider is registered; fail here instead of with an NPE later.
            throw new IllegalStateException("JBoss Marshalling provider '" + SERIAL_PROVIDER
                    + "' not found on the classpath");
        }
        return factory;
    }

    /** Builds a fresh configuration with version 5; encoder and decoder must use the same version. */
    private static MarshallingConfiguration configuration() {
        MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return configuration;
    }
}