手写简单的RPC(一)(Socket版)
1. 创建RPC注解
引入Hessian依赖(用于序列化请求与响应对象)
<dependency>
<groupId>com.caucho</groupId>
<artifactId>hessian</artifactId>
<version>4.0.63</version>
</dependency>
/**
 * Marks a service implementation class that the RPC server should export.
 *
 * <p>Retained at runtime so it can be detected reflectively with
 * {@code Class.isAnnotationPresent(Rpc.class)}; applicable to types only.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Rpc {
}
2.创建请求参数对象
/**
 * Describes one remote method call: the target service interface, the method name,
 * the parameter types (used to select the right overload) and the argument values.
 * Written by the client and read back by the server.
 */
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1L;

    // Fully-qualified name of the service interface being called.
    private String className;
    // Name of the method to invoke.
    private String methodName;
    // Declared parameter types of the method, in order.
    private Class<?>[] paramsType;
    // Actual argument values, in the same order as paramsType.
    private Object[] params;

    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public String getMethodName() {
        return methodName;
    }

    public void setMethodName(String methodName) {
        this.methodName = methodName;
    }

    public Class<?>[] getParamsType() {
        return paramsType;
    }

    public void setParamsType(Class<?>[] paramsType) {
        this.paramsType = paramsType;
    }

    public Object[] getParams() {
        return params;
    }

    public void setParams(Object[] params) {
        this.params = params;
    }
}
3.创建响应参数对象
/**
 * Result of one remote call, sent from the server back to the client. Either
 * {@code error} is set (the call failed) or {@code result} holds the return value.
 */
public class RpcResponse implements Serializable {
    private static final long serialVersionUID = 1L;

    // Non-null when the server could not complete the call.
    private Throwable error;
    // Return value of the remote method; meaningful only when error is null.
    private Object result;

    public Throwable getError() {
        return error;
    }

    public void setError(Throwable error) {
        this.error = error;
    }

    public Object getResult() {
        return result;
    }

    public void setResult(Object result) {
        this.result = result;
    }
}
4.创建服务提供方
/**
 * Socket-based RPC server. It scans packages for {@link Rpc}-annotated classes,
 * instantiates them, and hands each accepted connection to a thread pool.
 */
public class RpcServer {
    // Fully-qualified names of every .class file found by scanPackage.
    private final List<String> classNames = new ArrayList<String>();
    // Keeps the accept loop running; volatile so a stop() from another thread is seen.
    private volatile boolean flag = true;
    // Listening socket, kept in a field so stop() can close it and unblock accept().
    private volatile ServerSocket serverSocket;

    /**
     * Recursively collects the fully-qualified names of all classes under a package.
     * Only packages backed by a directory on the class path are supported.
     *
     * @param packageName dotted package name, e.g. {@code com.example.service}
     * @throws IllegalArgumentException if the package is not found or cannot be listed
     */
    private void scanPackage(String packageName) {
        URL url = this.getClass().getClassLoader().getResource(replaceTo(packageName));
        if (url == null) {
            throw new IllegalArgumentException("package not found on class path: " + packageName);
        }
        File dir;
        try {
            // toURI() decodes escapes such as %20 that url.getFile() would leave in the path.
            dir = new File(url.toURI());
        } catch (java.net.URISyntaxException | IllegalArgumentException e) {
            dir = new File(url.getFile());
        }
        File[] children = dir.listFiles();
        if (children == null) {
            throw new IllegalArgumentException("cannot list package directory: " + dir);
        }
        for (File child : children) {
            String name = child.getName();
            if (child.isDirectory()) {
                scanPackage(packageName + "." + name);
            } else if (name.endsWith(".class")) {
                classNames.add(packageName + "." + name.substring(0, name.length() - ".class".length()));
            }
        }
    }

    /** Converts a dotted package name into a resource path ({@code a.b} becomes {@code a/b}). */
    private String replaceTo(String packageName) {
        return packageName.replace('.', '/');
    }

    /**
     * Scans the given packages and instantiates every {@link Rpc}-annotated class.
     *
     * @param packageNames comma-separated package names
     * @return map from the implementation's first interface name to the instance
     * @throws RuntimeException "服务器初始化失败" if scanning or instantiation fails
     *         (the original failure is attached as the cause)
     */
    public Map<String, Object> getService(String packageNames) {
        try {
            Map<String, Object> map = new ConcurrentHashMap<String, Object>();
            // Rescan from scratch so repeated calls do not register classes twice.
            classNames.clear();
            for (String path : packageNames.split(",")) {
                String pkg = path.trim();
                if (!pkg.isEmpty()) {
                    scanPackage(pkg);
                }
            }
            for (String className : classNames) {
                Class<?> clazz = Class.forName(className);
                if (!clazz.isAnnotationPresent(Rpc.class)) {
                    continue;
                }
                Class<?>[] interfaces = clazz.getInterfaces();
                if (interfaces.length == 0) {
                    throw new IllegalStateException("@Rpc class implements no interface: " + className);
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                // Clients address a service by interface name, so the first interface is the key.
                map.put(interfaces[0].getName(), instance);
            }
            return map;
        } catch (Exception e) {
            throw new RuntimeException("服务器初始化失败", e);
        }
    }

    /**
     * Registers the services found in {@code packageName} and serves requests on
     * {@code port} until {@link #stop()} is called. Blocks the calling thread.
     *
     * @param port        TCP port to listen on
     * @param packageName comma-separated packages to scan for {@link Rpc} services
     */
    public void start(int port, String packageName) {
        ThreadPoolExecutor executor = null;
        try {
            Map<String, Object> service = getService(packageName);
            serverSocket = new ServerSocket(port);
            // Up to 10 worker threads plus 10 queued connections; beyond that, execute() rejects.
            executor = new ThreadPoolExecutor(5, 10, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<Runnable>(10));
            while (flag) {
                Socket accept = serverSocket.accept();
                try {
                    executor.execute(new RpcServiceRunnable(accept, service));
                } catch (java.util.concurrent.RejectedExecutionException e) {
                    // Overloaded: drop this connection instead of letting the exception end the server.
                    closeQuietly(accept);
                }
            }
        } catch (IOException e) {
            // After stop() closes the socket, accept() fails by design; report only unexpected failures.
            if (flag) {
                e.printStackTrace();
            }
        } finally {
            CloseUtils.close(serverSocket, executor, null, null, null, null);
        }
    }

    /**
     * Stops the accept loop started by {@link #start(int, String)}. It closes the listening
     * socket, so a thread blocked in {@code accept()} wakes up and the loop exits.
     */
    public void stop() {
        this.flag = false;
        ServerSocket socket = serverSocket;
        if (socket != null) {
            closeQuietly(socket);
        }
    }

    /** Closes a socket; failures are ignored because the socket is being discarded anyway. */
    private static void closeQuietly(java.io.Closeable closeable) {
        try {
            closeable.close();
        } catch (IOException ignored) {
            // best-effort cleanup
        }
    }
}
5. RpcServiceRunnable(具体业务逻辑)
public class RpcServiceRunnable implements Runnable{
private Socket client;
private Map<String, Object> service;
public RpcServiceRunnable(Socket client, Map<String, Object> service) {
this.client=client;
this.service=service;
}
public void run() {
InputStream in=null;
OutputStream out=null;
HessianInput hessianInput=null;
HessianOutput hessianOutput=null;
RpcResponse response=new RpcResponse();
try {
// 获取请求的流信息
in=client.getInputStream();
hessianInput=new HessianInput(in);
// 准备响应相关的流和序列化技术
out=client.getOutputStream();
hessianOutput=new HessianOutput(out);
// 获取请求的具体信息,转换成RPCRequest对象
Object object = hessianInput.readObject();
if (!(object instanceof RpcRequest)){
setError(response,hessianOutput,"非法参数");
return;
}
// 请求信息对象
RpcRequest request= (RpcRequest) object;
// 找到接口的实现类
Object impl = service.get(request.getClassName());
if (impl==null){
setError(response,hessianOutput,"没有对应的实体类");
return;
}
Method method = impl.getClass().getMethod(request.getMethodName(), request.getParamsType());
if (method==null){
setError(response,hessianOutput,"没有找到对应的方法");
return;
}
// 返回结果
Object result = method.invoke(impl, request.getParams());
response.setResult(result);
hessianOutput.writeObject(response);
} catch (Exception e) {
e.printStackTrace();
}finally {
CloseUtils.close(null,null,in,out,hessianInput,hessianOutput);
}
}
private void setError(RpcResponse response,HessianOutput out,String error){
response.setError(new Exception(error));
try {
out.writeObject(response);
} catch (IOException e) {
e.printStackTrace();
}
}
}
6.CloseUtils
/**
 * Null-safe cleanup for the server and connection resources used by the RPC classes.
 * Any argument may be {@code null}. A failure closing one resource is printed and
 * does not prevent the remaining resources from being closed.
 */
public class CloseUtils {
    public static void close(ServerSocket serverSocket, ThreadPoolExecutor executor, InputStream in, OutputStream out, HessianInput input, HessianOutput output) {
        closeAndReport(serverSocket);
        if (executor != null) {
            // Lets already-submitted tasks finish but accepts no new ones.
            executor.shutdown();
        }
        // Close the Hessian wrappers before the raw streams they wrap.
        if (output != null) {
            try {
                output.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        if (input != null) {
            try {
                // Previously this called in.close(): NPE when in == null, and input was never closed.
                input.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        closeAndReport(out);
        closeAndReport(in);
    }

    /** Closes {@code closeable} if non-null, printing (not rethrowing) any I/O failure. */
    private static void closeAndReport(java.io.Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
7.客户端(远程调用方)
/**
 * Blocking RPC client that performs one request/response round trip per call,
 * over a fresh socket.
 */
public class RpcClient {
    /**
     * Sends {@code request} to the server at {@code host:port} and returns the remote result.
     *
     * @param request the call to perform
     * @param host    server host name or address
     * @param port    server port
     * @return the value returned by the remote method (may be {@code null})
     * @throws RuntimeException "客户端发生异常" on any failure, with the underlying cause attached.
     *         When the server reports an error, the cause is a "服务器繁忙" exception whose own
     *         cause is the server's error.
     */
    public Object start(RpcRequest request, String host, int port) throws Exception {
        // try-with-resources closes the socket, and with it both streams, on every path.
        try (Socket socket = new Socket(host, port)) {
            HessianOutput hessianOutput = new HessianOutput(socket.getOutputStream());
            hessianOutput.writeObject(request);
            HessianInput hessianInput = new HessianInput(socket.getInputStream());
            Object object = hessianInput.readObject();
            if (!(object instanceof RpcResponse)) {
                throw new RuntimeException("服务器参数格式不正确");
            }
            RpcResponse response = (RpcResponse) object;
            if (response.getError() != null) {
                // Keep the server-reported error as the cause instead of discarding it.
                throw new RuntimeException("服务器繁忙", response.getError());
            }
            return response.getResult();
        } catch (Exception e) {
            throw new RuntimeException("客户端发生异常", e);
        }
    }
}
8.创建客户端代理对象
/**
 * JDK dynamic-proxy handler that turns calls on a service interface into remote
 * calls made through {@link RpcClient}.
 */
public class RpcClientProxy implements InvocationHandler {
    private String host;
    private Integer port;

    public RpcClientProxy(String host, Integer port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Creates a proxy implementing {@code clazz} whose interface methods run on the server.
     *
     * @param clazz the service interface (must be an interface)
     * @return the proxy instance
     */
    @SuppressWarnings("unchecked") // Proxy.newProxyInstance implements exactly clazz
    public <T> T getProxyObject(Class<T> clazz) {
        return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, this);
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // The JDK proxy also routes equals/hashCode/toString here. Answer them locally so that
        // printing the proxy or using it as a map key does not open a network connection.
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        RpcRequest request = new RpcRequest();
        request.setClassName(method.getDeclaringClass().getName());
        request.setMethodName(method.getName());
        request.setParamsType(method.getParameterTypes());
        request.setParams(args);
        // One round trip per call through a fresh client.
        RpcClient rpc = new RpcClient();
        return rpc.start(request, host, port);
    }

    /** Local implementations of the Object methods that the JDK proxy forwards. */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        // Remaining forwarded Object method: toString.
        return "RpcClientProxy(" + host + ":" + port + ")";
    }
}
9.测试实体类
public class Item implements Serializable {
private Long id;
private String name;
// Getter and Setter
}
10.测试类接口
/**
 * Sample remote service contract. Clients call it through a proxy; the server looks up
 * the implementation registered under this interface's fully-qualified name.
 */
public interface TestServer {
    /**
     * Sample remote operation.
     *
     * @param id the id supplied by the caller
     * @return an {@link Item} produced by the server-side implementation
     */
    Item hello(Long id);
}
11.测试类实现
@Rpc
public class TestServiceImpl implements TestServer {
public Item hello(Long id) {
System.out.println("id = " + id);
return new Item(1L,"小米手机");
}
}
12.服务提供方
public class ServiceClient {
public static void main(String[] args) {
RpcServer server=new RpcServer();
server.start(8888,"com.MyRpc.service");
System.out.println();
}
}
13.服务调用方
/**
 * Calls {@code TestServer.hello(3)} on the server at 127.0.0.1:8888 through a client
 * proxy and prints the returned item.
 */
public class ClientTest {
    public static void main(String[] args) {
        TestServer remote = new RpcClientProxy("127.0.0.1", 8888).getProxyObject(TestServer.class);
        Item item = remote.hello(3L);
        System.out.println(item);
    }
}