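I recently read 《Netty权威指南》 (The Definitive Guide to Netty). Chapter 14 is devoted entirely to designing and implementing a simple private protocol. The content is good, but the code snippets the author provides contain many errors and cannot possibly compile as printed.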
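For example, MarshallingEncoder is Netty's adapter class for JBoss Marshalling. Its encode method is protected, not public, and the ChannelBufferByteOutput class it relies on is package-private, so code outside Netty's package cannot reference it. Netty is designed this way because MarshallingEncoder is itself a ChannelHandler. It is meant to be added to a pipeline, where Netty calls encode on your behalf, not to be called directly as a utility. To call it from your own code you must subclass it. The book's code calls it directly anyway, and I don't see how that could have compiled.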
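The code in NettyMessageEncoder that computes the message length is also wrong. In addition, NettyMessageDecoder never sets lengthAdjustment, so the example cannot possibly run, and I don't know how the result screenshots in the book were produced. In the code below, the length field carries the length of the entire frame, including the 8 bytes of crcCode and the length field itself. LengthFieldBasedFrameDecoder, by default, treats the value as the number of bytes that follow the length field, so the decoder needs lengthAdjustment = -8.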
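So I modified the book's code and ran it successfully on my machine. It covers the message structure, the message encoder and decoder, extensions of the JBoss Marshalling encoder/decoder, LoginAuthReqHandler / LoginAuthRespHandler, NettyServer and NettyClient. The code is posted below so anyone who needs it can run it locally.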
Define the message class NettyMessage and the message header Header
package com.netty.test.netty4;
public class NettyMessage {
private Header header;
private Object body;
public Header getHeader() {
return header;
}
public void setHeader(Header header) {
this.header = header;
}
public Object getBody() {
return body;
}
public void setBody(Object body) {
this.body = body;
}
@Override
public String toString(){
return "NettyMessage [header=" + header + "]";
}
}
package com.netty.test.netty4;
import java.util.HashMap;
import java.util.Map;
public class Header {
private int crcCode = 0xabef0101;
private int length;
private long sessionID;
private byte type;
private byte priority;
private Map<String, Object> attachment = new HashMap<String, Object>();
public int getCrcCode() {
return crcCode;
}
public void setCrcCode(int crcCode) {
this.crcCode = crcCode;
}
public int getLength() {
return length;
}
public void setLength(int length) {
this.length = length;
}
public long getSessionID() {
return sessionID;
}
public void setSessionID(long sessionID) {
this.sessionID = sessionID;
}
public byte getType() {
return type;
}
public void setType(byte type) {
this.type = type;
}
public byte getPriority() {
return priority;
}
public void setPriority(byte priority) {
this.priority = priority;
}
public Map<String, Object> getAttachment() {
return attachment;
}
public void setAttachment(Map<String, Object> attachment) {
this.attachment = attachment;
}
@Override
public String toString(){
return "Header [crcCode=" + crcCode + ", length=" + length + ", sessionID=" + sessionID
+ ", type=" + type + ", priority=" + priority + ", attachment=" + attachment + "]";
}
}
Extend MarshallingEncoder and MarshallingDecoder, widening the protected encode/decode methods to public so they can be called
package com.netty.test.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.marshalling.MarshallerProvider;
import io.netty.handler.codec.marshalling.MarshallingEncoder;
public class NettyMarshallingEncoder extends MarshallingEncoder{
public NettyMarshallingEncoder(MarshallerProvider provider) {
super(provider);
}
// Widen encode() from protected to public so NettyMessageEncoder can call it directly
@Override
public void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception{
super.encode(ctx, msg, out);
}
}
package com.netty.test.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.marshalling.MarshallingDecoder;
import io.netty.handler.codec.marshalling.UnmarshallerProvider;
public class NettyMarshallingDecoder extends MarshallingDecoder{
public NettyMarshallingDecoder(UnmarshallerProvider provider) {
super(provider);
}
public NettyMarshallingDecoder(UnmarshallerProvider provider, int maxObjectSize){
super(provider, maxObjectSize);
}
// Widen decode() from protected to public so NettyMessageDecoder can call it directly
@Override
public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
return super.decode(ctx, in);
}
}
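NettyMessageEncoder and NettyMessageDecoder can call these methods several times on one buffer. This works because MarshallingEncoder writes each object as a 4-byte length followed by the marshalled bytes: it reserves a placeholder and patches the length afterwards. MarshallingDecoder, which is itself a LengthFieldBasedFrameDecoder, then reads one such length-prefixed object at a time. The sketch below imitates only that framing. It uses plain JDK serialization (ObjectOutputStream) instead of JBoss Marshalling, so the bytes differ from the real codec, and the class name LengthPrefixedSerialization is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;

public class LengthPrefixedSerialization {

    // Appends [4-byte length][serialized bytes] to out, back-patching the length
    // the way MarshallingEncoder does (JDK serialization stands in for JBoss Marshalling)
    public static void encode(Object value, ByteBuffer out) throws IOException {
        int lengthPos = out.position();
        out.putInt(0); // length placeholder, patched below
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bytes)) {
            oos.writeObject(value);
        }
        out.put(bytes.toByteArray());
        out.putInt(lengthPos, out.position() - lengthPos - 4);
    }

    // Reads exactly one length-prefixed object and leaves the buffer positioned after it
    public static Object decode(ByteBuffer in) throws IOException, ClassNotFoundException {
        int length = in.getInt();
        byte[] bytes = new byte[length];
        in.get(bytes);
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return ois.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        ByteBuffer buf = ByteBuffer.allocate(1024);
        encode("It is request", buf);
        encode(Byte.valueOf((byte) 3), buf);
        buf.flip();
        System.out.println(decode(buf)); // It is request
        System.out.println(decode(buf)); // 3
    }
}
```

Because every object carries its own length, a reader can decode several objects from one buffer in sequence without knowing their sizes in advance. This is why the attachment values and the body can follow each other in the same frame.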
Define the factory class MarshallingCodeCFactory to build the JBoss Marshalling encoder and decoder
package com.netty.test.netty4;
import io.netty.handler.codec.marshalling.DefaultMarshallerProvider;
import io.netty.handler.codec.marshalling.DefaultUnmarshallerProvider;
import io.netty.handler.codec.marshalling.MarshallerProvider;
import io.netty.handler.codec.marshalling.UnmarshallerProvider;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
public class MarshallingCodeCFactory {
public static NettyMarshallingDecoder buildMarshallingDecoder(){
MarshallerFactory marshallerFactory = Marshalling.getProvidedMarshallerFactory("serial");
MarshallingConfiguration configuration = new MarshallingConfiguration();
configuration.setVersion(5);
UnmarshallerProvider provider = new DefaultUnmarshallerProvider(marshallerFactory, configuration);
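// maxObjectSize = 1024: each marshalled object plus its 4-byte length prefix must fit in 1024 bytes
NettyMarshallingDecoder decoder = new NettyMarshallingDecoder(provider, 1024);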
return decoder;
}
public static NettyMarshallingEncoder buildMarshallingEncoder(){
MarshallerFactory marshallerFactory = Marshalling.getProvidedMarshallerFactory("serial");
MarshallingConfiguration configuration = new MarshallingConfiguration();
configuration.setVersion(5);
MarshallerProvider provider = new DefaultMarshallerProvider(marshallerFactory, configuration);
NettyMarshallingEncoder encoder = new NettyMarshallingEncoder(provider);
return encoder;
}
}
Define the NettyMessage encoder. Note how the message length is computed, and that the encoded buffer is passed on at the end.
package com.netty.test.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import java.util.List;
import java.util.Map;
public class NettyMessageEncoder extends MessageToMessageEncoder<NettyMessage>{
private NettyMarshallingEncoder marshallingEncoder;
public NettyMessageEncoder(){
marshallingEncoder = MarshallingCodeCFactory.buildMarshallingEncoder();
}
@Override
protected void encode(ChannelHandlerContext ctx, NettyMessage msg,
List<Object> out) throws Exception {
if(msg == null || msg.getHeader() == null){
throw new Exception("The encode message is null");
}
ByteBuf sendBuf = Unpooled.buffer();
sendBuf.writeInt(msg.getHeader().getCrcCode());
sendBuf.writeInt(msg.getHeader().getLength());
sendBuf.writeLong(msg.getHeader().getSessionID());
sendBuf.writeByte(msg.getHeader().getType());
sendBuf.writeByte(msg.getHeader().getPriority());
sendBuf.writeInt(msg.getHeader().getAttachment().size());
String key = null;
byte[] keyArray = null;
Object value = null;
for(Map.Entry<String, Object> param: msg.getHeader().getAttachment().entrySet()){
key = param.getKey();
keyArray = key.getBytes("UTF-8");
sendBuf.writeInt(keyArray.length);
sendBuf.writeBytes(keyArray);
value = param.getValue();
marshallingEncoder.encode(ctx, value, sendBuf);
}
if(msg.getBody() != null){
marshallingEncoder.encode(ctx, msg.getBody(), sendBuf);
}
// Write the total frame length (header included) into the length field at byte offset 4
int readableBytes = sendBuf.readableBytes();
sendBuf.setInt(4, readableBytes);
// Add the encoded buffer to the output list so it is passed to the next handler
out.add(sendBuf);
}
}
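The frame layout produced by this encoder can be reproduced with java.nio.ByteBuffer to check the length arithmetic. In the hypothetical FrameLayoutSketch below, raw byte arrays stand in for the marshalled attachment values and body. Each array still gets the 4-byte length prefix MarshallingEncoder would write, so the sizes match the real encoder but the bytes do not. The buffer capacity of 4096 is an arbitrary assumption. The fixed part of the header is 22 bytes: crcCode (4), length (4), sessionID (8), type (1), priority (1) and the attachment count (4).

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

public class FrameLayoutSketch {

    // Byte offset of the length field: it follows the 4-byte crcCode
    static final int LENGTH_OFFSET = 4;

    // Writes the same fields in the same order as NettyMessageEncoder; the byte[] values
    // stand in for marshalled objects and get the 4-byte prefix MarshallingEncoder writes
    public static ByteBuffer encode(int crcCode, long sessionId, byte type, byte priority,
                                    Map<String, byte[]> attachment, byte[] body) {
        ByteBuffer buf = ByteBuffer.allocate(4096);
        buf.putInt(crcCode);
        buf.putInt(0); // length placeholder, patched at the end
        buf.putLong(sessionId);
        buf.put(type);
        buf.put(priority);
        buf.putInt(attachment.size());
        for (Map.Entry<String, byte[]> e : attachment.entrySet()) {
            byte[] key = e.getKey().getBytes(StandardCharsets.UTF_8);
            buf.putInt(key.length);
            buf.put(key);
            buf.putInt(e.getValue().length);
            buf.put(e.getValue());
        }
        if (body != null) {
            buf.putInt(body.length);
            buf.put(body);
        }
        // Like sendBuf.setInt(4, readableBytes): the total frame length, header included
        buf.putInt(LENGTH_OFFSET, buf.position());
        buf.flip();
        return buf;
    }

    public static void main(String[] args) {
        Map<String, byte[]> attachment = new LinkedHashMap<>();
        attachment.put("k", new byte[2]);
        ByteBuffer frame = encode(0xabef0101, 1L, (byte) 1, (byte) 0, attachment, new byte[5]);
        // 22 fixed header bytes + (4 + 1 + 4 + 2) for the attachment + (4 + 5) for the body
        System.out.println(frame.getInt(LENGTH_OFFSET)); // 42
    }
}
```

Since the value stored at offset 4 counts the 8 bytes that precede and include the length field, the decoder has to subtract them again. That is the role of lengthAdjustment = -8 in NettyClient and NettyServer.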
Define NettyMessageDecoder. Pay attention to the key LengthFieldBasedFrameDecoder parameters, because they directly determine the decoding result.
package com.netty.test.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.util.HashMap;
import java.util.Map;
public class NettyMessageDecoder extends LengthFieldBasedFrameDecoder{
private NettyMarshallingDecoder marshallingDecoder;
public NettyMessageDecoder(int maxFrameLength, int lengthFieldOffset,
int lengthFieldLength,int lengthAdjustment, int initialBytesToStrip) {
super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
marshallingDecoder = MarshallingCodeCFactory.buildMarshallingDecoder();
}
@Override
public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception{
// Let LengthFieldBasedFrameDecoder cut one complete frame out of the stream; null means not enough bytes yet
ByteBuf frame = (ByteBuf)super.decode(ctx, in);
if(frame == null){
return null;
}
try {
NettyMessage message = new NettyMessage();
Header header = new Header();
header.setCrcCode(frame.readInt());
header.setLength(frame.readInt());
header.setSessionID(frame.readLong());
header.setType(frame.readByte());
header.setPriority(frame.readByte());
int size = frame.readInt();
if(size > 0){
Map<String, Object> attach = new HashMap<String, Object>(size);
for(int i=0; i<size; i++){
int keySize = frame.readInt();
byte[] keyArray = new byte[keySize];
// Read the key from the extracted frame, not from the original input buffer
frame.readBytes(keyArray);
String key = new String(keyArray, "UTF-8");
attach.put(key, marshallingDecoder.decode(ctx, frame));
}
header.setAttachment(attach);
}
if(frame.readableBytes() > 0){
message.setBody(marshallingDecoder.decode(ctx, frame));
}
message.setHeader(header);
return message;
} finally {
// The frame returned by super.decode() holds its own reference (a copy or a retained slice,
// depending on the Netty version), so release it once the message has been built
frame.release();
}
}
}
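The arithmetic LengthFieldBasedFrameDecoder performs with these parameters can be checked without Netty. The hypothetical FrameSplitterSketch below is a simplified stand-in on java.nio.ByteBuffer, not Netty's implementation. It assumes a 4-byte big-endian length field, ignores maxFrameLength and initialBytesToStrip, and applies the same formula Netty uses: frameLength = lengthFieldValue + lengthAdjustment + lengthFieldOffset + lengthFieldLength.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

public class FrameSplitterSketch {

    // Cuts complete frames out of accumulated bytes with the same arithmetic as
    // LengthFieldBasedFrameDecoder for a 4-byte length field:
    // frameLength = lengthFieldValue + lengthAdjustment + lengthFieldOffset + 4
    public static List<byte[]> split(ByteBuffer in, int lengthFieldOffset, int lengthAdjustment) {
        List<byte[]> frames = new ArrayList<>();
        int lengthFieldEnd = lengthFieldOffset + 4;
        while (in.remaining() >= lengthFieldEnd) {
            int start = in.position();
            int frameLength = in.getInt(start + lengthFieldOffset) + lengthAdjustment + lengthFieldEnd;
            if (frameLength < lengthFieldEnd) {
                throw new IllegalStateException("Corrupted frame length: " + frameLength);
            }
            if (in.remaining() < frameLength) {
                break; // half packet: leave the bytes in place and wait for more
            }
            byte[] frame = new byte[frameLength];
            in.get(frame);
            frames.add(frame);
        }
        return frames;
    }

    // Builds a dummy frame whose length field holds the total frame length,
    // as NettyMessageEncoder does with sendBuf.setInt(4, readableBytes)
    public static byte[] frame(int totalLength) {
        ByteBuffer b = ByteBuffer.allocate(totalLength);
        b.putInt(0xabef0101);
        b.putInt(totalLength);
        return b.array();
    }

    public static void main(String[] args) {
        ByteBuffer stream = ByteBuffer.allocate(64);
        // Two complete frames followed by the first 10 bytes of a third one
        stream.put(frame(22)).put(frame(30)).put(frame(26), 0, 10);
        stream.flip();
        System.out.println(split(stream, 4, -8).size()); // 2
        System.out.println(stream.remaining());          // 10
        // Without lengthAdjustment the splitter expects 22 + 8 = 30 bytes and never emits the frame
        System.out.println(split(ByteBuffer.wrap(frame(22)), 4, 0).size()); // 0
    }
}
```

With the parameters used in this post (offset 4, length 4, adjustment -8), the two complete frames are cut out, and the 10 bytes of the half-arrived frame stay in the buffer until more data comes in. With lengthAdjustment = 0, a lone 22-byte frame is never emitted, because the decoder waits for 8 more bytes that never arrive.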
Define LoginAuthReqHandler, the client-side business ChannelHandler that sends the login request
package com.netty.test.netty4;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
// In Netty 4.x, inbound callbacks such as channelActive/channelRead live in ChannelInboundHandlerAdapter;
// ChannelHandlerAdapter only carried them in the discontinued Netty 5.0 alpha API
public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
// Send the login request as soon as the connection is established
ctx.writeAndFlush(buildLoginReq());
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
throws Exception {
NettyMessage message = (NettyMessage)msg;
if(message.getHeader() != null && message.getHeader().getType() == (byte)2){
System.out.println("Received response from server");
}
ctx.fireChannelRead(msg);
}
private NettyMessage buildLoginReq() {
NettyMessage message = new NettyMessage();
Header header = new Header();
header.setType((byte)1);
message.setHeader(header);
message.setBody("It is request");
return message;
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
}
Define LoginAuthRespHandler, the server-side business ChannelHandler that answers the login request
package com.netty.test.netty4;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
// In Netty 4.x, inbound callbacks such as channelRead live in ChannelInboundHandlerAdapter;
// ChannelHandlerAdapter only carried them in the discontinued Netty 5.0 alpha API
public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
throws Exception {
NettyMessage message = (NettyMessage)msg;
if(message.getHeader() != null && message.getHeader().getType() == (byte)1){
System.out.println("Login is OK");
String body = (String)message.getBody();
System.out.println("Received message body from client is " + body);
// Answer login requests (type 1) with a login response (type 2)
ctx.writeAndFlush(buildLoginResponse((byte)3));
} else {
// Not a login request: pass it on to the next handler
ctx.fireChannelRead(msg);
}
}
private NettyMessage buildLoginResponse(byte result) {
NettyMessage message = new NettyMessage();
Header header = new Header();
header.setType((byte)2);
message.setHeader(header);
message.setBody(result);
return message;
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
}
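Both handlers compare the header type against the magic bytes 1 (login request) and 2 (login response). The sketch below gives those bytes names through a hypothetical MessageType enum, which is not part of the book's or this post's code. In this sketch only a login request gets a reply.

```java
public class MessageTypeSketch {

    // Hypothetical names for the type bytes the handlers use
    public enum MessageType {
        LOGIN_REQ((byte) 1),
        LOGIN_RESP((byte) 2);

        private final byte value;

        MessageType(byte value) {
            this.value = value;
        }

        public byte value() {
            return value;
        }

        // Returns null for a byte that has no name
        public static MessageType fromByte(byte b) {
            for (MessageType t : values()) {
                if (t.value == b) {
                    return t;
                }
            }
            return null;
        }
    }

    // A login request is answered with a login response; anything else gets no reply (null)
    public static MessageType replyTo(byte requestType) {
        return MessageType.fromByte(requestType) == MessageType.LOGIN_REQ ? MessageType.LOGIN_RESP : null;
    }

    public static void main(String[] args) {
        System.out.println(replyTo((byte) 1)); // LOGIN_RESP
        System.out.println(replyTo((byte) 2)); // null
    }
}
```

With such an enum, header.setType(MessageType.LOGIN_REQ.value()) would replace header.setType((byte)1), and the handlers could compare against named constants instead of bare bytes.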
Define NettyClient
package com.netty.test.netty4;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
public class NettyClient {
public void connect(String remoteServer, int port) throws Exception {
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
Bootstrap b = new Bootstrap();
b.group(workerGroup)
.channel(NioSocketChannel.class)
.handler(new ChildChannelHandler());
ChannelFuture f = b.connect(remoteServer,port).sync();
System.out.println("Netty client connected to " + remoteServer + ":" + port);
f.channel().closeFuture().sync();
} finally {
workerGroup.shutdownGracefully();
}
}
public static class ChildChannelHandler extends
ChannelInitializer<SocketChannel> {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
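// lengthAdjustment = -8: the length field holds the whole frame length, including the 8 bytes up to
// the end of the length field, so subtract them; initialBytesToStrip = 0 keeps the header in the frame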
ch.pipeline().addLast(new NettyMessageDecoder(1024 * 1024, 4, 4, -8, 0))
.addLast(new NettyMessageEncoder())
.addLast(new LoginAuthReqHandler());
}
}
public static void main(String[] args){
try {
new NettyClient().connect("127.0.0.1", 9080);
} catch (Exception e) {
e.printStackTrace();
}
}
}
Define NettyServer
package com.netty.test.netty4;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
public class NettyServer {
public void bind(int port) throws Exception {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 1024)
.childHandler(new ChildChannelHandler());
ChannelFuture f = b.bind(port).sync();
System.out.println("Netty server started at port " + port);
f.channel().closeFuture().sync();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
public static class ChildChannelHandler extends
ChannelInitializer<SocketChannel> {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(new NettyMessageDecoder(1024 * 1024, 4, 4, -8, 0))
.addLast(new NettyMessageEncoder())
.addLast(new LoginAuthRespHandler());
}
}
public static void main(String[] args){
try {
new NettyServer().bind(9080);
} catch (Exception e) {
e.printStackTrace();
}
}
}
Run results (screenshots: server side, client side)
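Jars required at runtime: Netty, plus jboss-marshalling and jboss-marshalling-serial (the latter provides the "serial" marshaller factory used by MarshallingCodeCFactory)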
Sample code download:
http://pan.baidu.com/s/1kT1PwO3