public final class TcpClient {
private static final Logger LOG = LoggerFactory.getLogger(TcpClient.class);
private static final int DEFAULT_TIMEOUT_SECONDS = 5;
// 默认响应超时时间(秒)
private static final String DEFAULT_HOST = "127.0.0.1";
// 默认服务器地址
private static final int DEFAULT_PORT = 8080;
// 默认服务器端口
private final String host;
// 目标服务器地址
private final int port;
// 目标服务器端口
private Channel channel;
// 客户端TCP通道
private EventLoopGroup group;
// 客户端事件循环组
/**
* 使用默认地址(127.0.0.1:8080)初始化客户端
*/
public TcpClient() {
this(DEFAULT_HOST, DEFAULT_PORT);
}
/**
* 指定服务器地址和端口初始化客户端
*
* @param host 服务器IP地址(非空)
* @param port 服务器端口(1-65535)
*/
public TcpClient(String host, int port) {
this.host = host;
this.port = port;
}
/**
* 连接服务器(阻塞直到连接成功或失败)
*
* @throws InterruptedException 线程中断异常(连接过程被中断)
*/
public void connect() throws InterruptedException {
group = new NioEventLoopGroup();
Bootstrap b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
// 客户端编解码器需与服务器严格对称(处理接收和发送的消息)
ch.pipeline().addLast(
// 接收方向:解析长度头 → 解码JSON → GreetResponse
new LengthFieldFrameDecoder(),
// 解析4字节长度头(服务器响应)
new JsonResponseDecoder(),
// JSON字节 → GreetResponse对象
new LengthFieldFrameEncoder(),
// 为JSON字节添加4字节长度头
new JsonRequestEncoder()
// GreetRequest对象 → JSON字节
);
}
});
ChannelFuture connectFuture = b.connect(host, port).sync();
if (connectFuture.isSuccess()) {
this.channel = connectFuture.channel();
LOG.info("客户端连接成功,目标地址:{}:{}", host, port);
} else {
LOG.error("客户端连接失败,目标地址:{}:{}", host, port);
throw new IllegalStateException("连接服务器失败");
}
}
/**
* 发送greet请求并同步等待响应(带超时机制)
*
* @param request 要发送的GreetRequest对象(非空)
* @return 服务器返回的GreetResponse对象(超时或失败返回null)
* @throws InterruptedException 线程中断异常(等待过程被中断)
*/
public GreetResponse sendRequest(GreetRequest request) throws InterruptedException {
if (channel == null || !channel.isActive()) {
LOG.error("无法发送请求:通道未激活或未连接");
return null;
}
CountDownLatch responseLatch = new CountDownLatch(1);
TcpClientHandler responseHandler = new TcpClientHandler(responseLatch);
// 动态添加临时处理器(避免多个请求响应交叉)
channel.pipeline().addLast(responseHandler);
try {
// 发送请求(自动触发编码流程:GreetRequest → JSON → 长度头)
channel.writeAndFlush(request).sync();
LOG.debug("已发送greet请求,name={}", request.getName());
// 等待响应(默认超时5秒)
if (responseLatch.await(DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
return responseHandler.getResponse();
} else {
LOG.warn("等待响应超时,name={}", request.getName());
return null;
}
} finally {
// 移除临时处理器
channel.pipeline().remove(responseHandler);
}
}
/**
* 关闭客户端资源(释放线程组和通道)
*/
public void shutdown() {
if (channel != null) {
channel.close().syncUninterruptibly();
}
if (group != null) {
group.shutdownGracefully();
}
LOG.info("客户端已关闭");
}
/**
* 主方法:客户端使用流程
*
* @param args 命令行参数(未使用)
* @throws InterruptedException 线程中断异常
*/
public static void main(String[] args) throws InterruptedException {
TcpClient client = new TcpClient();
try {
client.connect();
GreetRequest request = new GreetRequest();
request.setName("Netty User");
GreetResponse response = client.sendRequest(request);
if (response != null) {
LOG.info("收到服务器响应:{}", response.getMessage());
} else {
LOG.error("未收到有效响应");
}
} finally {
client.shutdown();
}
}
}public final class TcpClientHandler extends SimpleChannelInboundHandler<GreetResponse> {
private static final Logger LOG = LoggerFactory.getLogger(TcpClientHandler.class);
private final CountDownLatch responseLatch;
// 同步锁(等待响应)
private GreetResponse response;
// 存储服务器响应
public TcpClientHandler(CountDownLatch responseLatch) {
this.responseLatch = responseLatch;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, GreetResponse response) {
this.response = response;
LOG.debug("收到服务器响应: message={}", response.getMessage());
responseLatch.countDown(); // 响应到达后释放锁
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOG.error("客户端通道异常: {}", cause.getMessage());
ctx.close();
}
/**
* 获取服务器返回的响应对象
* @return GreetResponse对象(可能为null)
*/
public GreetResponse getResponse() {
return response;
}
}public class JsonRequestDecoder extends ByteToMessageDecoder {
private static final Logger LOG = LoggerFactory.getLogger(JsonRequestDecoder.class);
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
try {
// 将ByteBuf转换为字节数组
byte[] bytes = new byte[in.readableBytes()];
in.readBytes(bytes);
// 反序列化为GreetRequest对象
GreetRequest request = objectMapper.readValue(bytes, GreetRequest.class);
out.add(request);
LOG.debug("解码请求:{}", request.getName());
} catch (Exception e) {
LOG.error("JSON解码失败", e);
ctx.close(); // 解码失败关闭连接
}
}
}public class JsonRequestEncoder extends MessageToByteEncoder<GreetRequest> {
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
protected void encode(ChannelHandlerContext ctx, GreetRequest request, ByteBuf out) throws Exception {
// 将GreetRequest对象序列化为JSON字节数组
byte[] jsonBytes = objectMapper.writeValueAsBytes(request);
// 将字节数组写入Netty的ByteBuf(必须使用out.writeBytes())
out.writeBytes(jsonBytes);
}
}public class JsonResponseDecoder extends ByteToMessageDecoder {
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
// 将ByteBuf中的字节转换为JSON字符串
byte[] jsonBytes = new byte[in.readableBytes()];
in.readBytes(jsonBytes);
// 将JSON字符串反序列化为GreetResponse对象
GreetResponse response = objectMapper.readValue(jsonBytes, GreetResponse.class);
out.add(response);
}
}public class JsonResponseEncoder extends MessageToByteEncoder<GreetResponse> {
private static final Logger LOG = LoggerFactory.getLogger(JsonResponseEncoder.class);
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
protected void encode(ChannelHandlerContext ctx, GreetResponse response, ByteBuf out) {
try {
// 序列化为JSON字节数组
byte[] bytes = objectMapper.writeValueAsBytes(response);
out.writeBytes(bytes);
LOG.debug("编码响应:{}", response.getMessage());
} catch (Exception e) {
LOG.error("JSON编码失败", e);
ctx.close(); // 编码失败关闭连接
}
}
}public class LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder {
private static final int MAX_FRAME_LENGTH = 1024 * 1024;
// 1MB
private static final int LENGTH_FIELD_OFFSET = 0;
// 长度字段偏移量
private static final int LENGTH_FIELD_LENGTH = 4;
// 长度字段占4字节
private static final int LENGTH_ADJUSTMENT = 0;
// 长度字段不包含自身
private static final int INITIAL_BYTES_TO_STRIP = 4;
// 剥离前4字节长度头
public LengthFieldFrameDecoder() {
super(MAX_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH,
LENGTH_ADJUSTMENT, INITIAL_BYTES_TO_STRIP);
}
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
// 调用父类方法完成帧解码
return super.decode(ctx, in);
}
}public class LengthFieldFrameEncoder extends LengthFieldPrepender {
public LengthFieldFrameEncoder() {
super(4);
// 长度字段占4字节
}
}public class GreetRequest {
private String name;
/**
* 获取请求中的名称
* @return 名称字符串
*/
public String getName() {
return name;
}
/**
* 设置请求中的名称
* @param name 名称字符串
*/
public void setName(String name) {
this.name = name;
}
}public class GreetResponse {
private String message;
/**
* 获取响应消息内容
* @return 消息字符串
*/
public String getMessage() {
return message;
}
/**
* 设置响应消息内容
* @param message 消息字符串
*/
public void setMessage(String message) {
this.message = message;
}
}public class GreetServerHandler extends SimpleChannelInboundHandler<GreetRequest> {
private static final Logger LOG = LoggerFactory.getLogger(GreetServerHandler.class);
@Override
protected void channelRead0(ChannelHandlerContext ctx, GreetRequest request) {
LOG.info("收到greet请求,name={}", request.getName());
// 生成响应
GreetResponse response = new GreetResponse();
response.setMessage("Hello, " + request.getName() + "!");
ctx.writeAndFlush(response);
// 自动触发编码器添加长度头
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOG.error("服务器通道异常: {}", cause.getMessage());
ctx.close();
}
}public class TcpServer {
private static final Logger LOG = LoggerFactory.getLogger(TcpServer.class);
private static final int SERVER_PORT = 8080;
// 服务器端口(常量大写)
private final EventLoopGroup bossGroup;
// Boss线程组(接收连接)
private final EventLoopGroup workerGroup;
// Worker线程组(处理IO)
public TcpServer() {
this.bossGroup = new NioEventLoopGroup(1);
// 单线程Boss组
this.workerGroup = new NioEventLoopGroup();
// 默认线程数(CPU核心数×2)
}
/**
* 启动TCP服务器
* @throws InterruptedException 线程中断异常
*/
public void start() throws InterruptedException {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 128)
// 连接队列大小
.childOption(ChannelOption.SO_KEEPALIVE, true)
// 保持长连接
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
// 按顺序添加编解码器和业务处理器
ch.pipeline().addLast(
new LengthFieldFrameDecoder(),
// 解析长度头+拆包
new JsonRequestDecoder(),
// JSON → GreetRequest
new LengthFieldFrameEncoder(),
// 添加长度头
new JsonResponseEncoder(),
// GreetResponse → JSON
new GreetServerHandler()
// 业务处理
);
}
});
ChannelFuture f = b.bind(SERVER_PORT).sync();
LOG.info("TCP服务器启动,监听端口:{}", SERVER_PORT);
f.channel().closeFuture().sync(); // 阻塞等待服务器关闭
}
/**
* 关闭服务器资源
*/
public void shutdown() {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
LOG.info("服务器已关闭");
}
public static void main(String[] args) throws InterruptedException {
TcpServer server = new TcpServer();
try {
server.start();
} finally {
server.shutdown();
}
}
}请针对这个简易的tcp服务程序写出单元测试,并在pom引入合适版本的依赖
最新发布