【手写框架】1. Mybatis持久层框架

本文深入剖析了自定义Mybatis持久层框架的实现过程,从数据库配置信息、SQL语句抽取,到参数结果集封装及数据库连接池管理,详细介绍了自定义框架的设计思路与代码实现,包括Configuration、MapperStatement等核心组件的定义与使用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自定义Mybatis持久层框架

分析

直接使用jdbc进行数据库操作存在以下问题

  • 数据库配置信息硬编码
  • sql语句、参数硬编码、结果集硬编码
  • 频繁的创建、销毁数据库资源,造成资源浪费,影响系统性能

为了解决以上问题我们需要使用一定的架构对jdbc操作进行封装,下面是我的实现思路

实现思路

  • 数据库配置信息抽出到配置文件中
  • SQL语句抽出来到到配置文件中
  • 参数结果集使用反射(内省)进行封装
  • 使用数据库连接池来管理数据库资源

代码讲解

服务端需要提供的数据

  1. 数据库操作语句存放到mapper文件中,命名以实体类为前缀命名,例:User的映射文件我们起名userMapper.xml
<configuration>
    <dataSource>
        <property name="driverClass" value="com.mysql.cj.jdbc.Driver"/>
        <property name="jdbcUrl" value="jdbc:mysql://127.0.0.1:3306/test"/>
        <property name="username" value="root"/>
        <property name="password" value="root"/>
    </dataSource>
    <mappers>
        <mapper resource="userMapper.xml"/>
    </mappers>
</configuration>
  1. 数据库配置信息存放到sqlMapconfig.xml文件中,注入mapper文件配置信息
<mapper namespace="com.wtf.mapper.IUserMapper">
    <select id="getUserList" resultType="com.wtf.pojo.User">
        select id,username from users
    </select>

    <select id="getUserById" parameterType="java.lang.Integer" resultType="com.wtf.pojo.User">
        select id,username from users where id = #{id}
    </select>

    <select id="getUser" parameterType="com.wtf.pojo.User" resultType="com.wtf.pojo.User">
        select id,username from users where id = #{id} or username = #{username}
    </select>

    <insert id="addUser" parameterType="com.wtf.pojo.User">
        insert into users(id, username) values(#{id}, #{username})
    </insert>

    <update id="update" parameterType="com.wtf.pojo.User">
        update users(id, username) set username=#{username} where id = #{id}
    </update>

    <delete id="deleteUser" parameterType="java.lang.Integer">
        delete from users where id = #{id}
    </delete>
</mapper>

自定义框架的实现

  1. 定义Resources类,接收配置文件,并生成InputStream
public class Resources {

    public static InputStream getResourceAsStream(String path) {
        log.info("loading file {}", path);
        return Resources.class.getClassLoader().getResourceAsStream(path);
    }
}
  1. 定义SqlSessionFactoryBuilder类,用于创建SqlSessionFactory。其中用到了XmlConfigBuilder类,用于解析数据库配置文件生成Configuration实体类。
public class SqlSessionFactoryBuilder {
    public SqlSessionFactory build(InputStream inputStream) {
        //解析配置文件,生成Configuration
        Configuration configuration = new XmlConfigBuilder().parseConfig(inputStream);
        //返回默认SqlSessionFactory
        return new DefaultSqlSessionFactory(configuration);
    }
}

Configuration实体包含两大块:数据源和一个HashMap,HashMap用于存储Mapper配置文件中的select/update/insert/delete节点的信息

@Data
public class Configuration {
    /**
     * 数据源信息
     */
    private DataSource dataSource;

    /**
     * key:   Statement id
     * value: MapperStatement
     */
    private Map<String, MapperStatement> statementMap = new HashMap<String, MapperStatement>();
}

MapperStatement的结构如下

@Data
public class MapperStatement {
    /**
     * namespace + id
     */
    private String id;

    /**
     * 请求对象
     */
    private Class<?> parameterType;

    /**
     * 返回对象
     */
    private Class<?> resultType;

    /**
     * sql语句
     */
    private String sql;
}
  1. 上面提到SqlSessionFactoryBuilder用于创建SqlSessionFactory,下面我们说下SqlSessionFactory,看名字我们就能看出来它是用来创建SqlSession的。
public interface SqlSessionFactory {
    /**
     * 创建一个sqlSession对象
     *
     * @return SqlSession
     */
    SqlSession openSession();
}

下面我们来看一下它的一个实现类DefaultSqlSessionFactory。

public class DefaultSqlSessionFactory implements SqlSessionFactory {
    private Configuration configuration;

    public DefaultSqlSessionFactory(Configuration configuration) {
        this.configuration = configuration;
    }
    public SqlSession openSession() {
        SimpleExecutor simpleExecutor = new SimpleExecutor();
        return new DefaultSqlSession(configuration, simpleExecutor);
    }
}

DefaultSqlSession的openSession()方法中会去创建一个SqlSession,把Configuration配置类和SimpleExecutor一个执行器传递给DefaultSqlSession。

  1. 下面我们来看下SqlSession接口中有哪些方法
public interface SqlSession {

    /**
     * 获取mapper接口的代理对象
     *
     * @param t mapper接口
     * @return 代理对象
     */
    <T> T getMapper(Class<T> t);

    /**
     * 执行查询list方法
     *
     * @param statementId statementId
     * @param params      请求参数
     * @return 查询结果
     */
    <T> List<T> selectList(String statementId, Object params);

    /**
     * 执行查询单条数据方法
     *
     * @param statementId statementId
     * @param params      请求参数
     * @return 查询结果
     */
    <T> T selectOne(String statementId, Object params);

    /**
     * 插入数据
     *
     * @param statementId statementId
     * @param params      请求参数
     */
    <T> void insert(String statementId, Object params);

    /**
     * 修改数据
     *
     * @param statementId statementId
     * @param params      请求参数
     */
    <T> void update(String statementId, Object params);

    /**
     * 删除数据
     *
     * @param statementId statementId
     * @param params      请求参数
     * @param <T>
     */
    <T> void delete(String statementId, Object params);
}

可以大概看出来,SqlSession主要是进行数据库操作的,另外SqlSession还提供了一个getMapper(Class<T> t)方法,这个方法接收了一个Mapper接口,主要用户对Mapper接口进行代理,实现Mapper接口中定义的方法。我们来看下SqlSession的一个默认实现类DefaultSqlSession

public class DefaultSqlSession implements SqlSession {

    private final Configuration configuration;

    private final Executor executor;

    public DefaultSqlSession(Configuration configuration, Executor executor) {
        this.configuration = configuration;
        this.executor = executor;
    }

    public <T> List<T> selectList(String statementId, Object params) {
        List<Object> query = executor.query(configuration, getMapperStatement(statementId), params);
        return (List<T>) query;
    }

    public <T> T selectOne(String statementId, Object params) {
        List<Object> objects = this.selectList(statementId, params);
        if (objects != null && objects.size() > 1) {
            throw new RuntimeException("查询结果不唯一");
        }
        if (objects == null || objects.size() <= 0) {
            return (T) null;
        }
        return (T) objects.get(0);
    }

    public <T> void insert(String statementId, Object params) {
        executor.update(configuration, getMapperStatement(statementId), params);
    }

    public <T> void update(String statementId, Object params) {
        executor.update(configuration, getMapperStatement(statementId), params);
    }

    public <T> void delete(String statementId, Object params) {
        executor.update(configuration, getMapperStatement(statementId), params);
    }

    public <T> T getMapper(Class<T> t) {
        Object proxyInstance = Proxy.newProxyInstance(this.getClass().getClassLoader(), new Class[]{t}, new InvocationHandler() {
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                //代理对象的方法名
                String name = method.getName();
                //代理对象的类名
                String className = method.getDeclaringClass().getName();
                //statement id
                String statementId = String.format("%s.%s", className, name);

                MapperStatement mappedStatement = configuration.getStatementMap().get(statementId);
                if (mappedStatement == null) {
                    log.error("mappedStatement {} is not found", statementId);
                    throw new RuntimeException("找不到需要执行的方法");
                }
                String sql = mappedStatement.getSql();
                //根据sql语句包含的操作,来指向对象的方法
                if (sql.toLowerCase().contains(SELECT)) {
                    Type returnType = method.getGenericReturnType();
                    if (returnType instanceof ParameterizedType) {
                        return selectList(statementId, args);
                    }
                    return selectOne(statementId, args);
                }
                if (sql.toLowerCase().contains(INSERT)) {
                    insert(statementId, args);
                }
                if (sql.toLowerCase().contains(UPDATE)) {
                    update(statementId, args);
                }
                if (sql.toLowerCase().contains(DELETE)) {
                    delete(statementId, args);
                }
                return null;
            }
        });
        return (T) proxyInstance;
    }
}

通过代码我们可用看到在DefaultSqlSession中,真正执行调用的是Executor接口中的方法。

  1. Executor接口我定义的一个执行器,用来执行与数据的一些交互
public class SimpleExecutor implements Executor {

    public <E> List<E> query(Configuration configuration, MapperStatement mapperStatement, Object params) {
        Connection connection = null;
        try {
            //处理参数,重组SQL
            connection = getConnection(configuration);
            String sql = mapperStatement.getSql();
            //重组sql
            BoundSql boundSql = getBoundSql(sql);
            sql = sql.replaceAll("[\r\n]", "");
            log.info("==>  Preparing: {}", sql.trim());
            PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSql());

            //如果参数类型是Integer 或者String直接赋值
            setStatementParameter(mapperStatement, boundSql, preparedStatement, params);

            //重组返回对象
            Class<?> resultClass = mapperStatement.getResultType();
            ResultSet resultSet = preparedStatement.executeQuery();
            List<Object> list = new ArrayList<Object>();
            //获取元数据,包含参数的名称
            ResultSetMetaData metaData = resultSet.getMetaData();
            int columnCount = metaData.getColumnCount();
            while (resultSet.next()) {
                Object o = resultClass.newInstance();
                for (int i = 1; i <= columnCount; i++) {
                    //获取参数name
                    String columnName = metaData.getColumnName(i);
                    //获取参数value
                    Object fieldValue = resultSet.getObject(columnName);
                    //创建属性描述器,为属性生成读写方法
                    PropertyDescriptor descriptor = new PropertyDescriptor(columnName, resultClass);
                    Method writeMethod = descriptor.getWriteMethod();
                    writeMethod.invoke(o, fieldValue);
                }
                list.add(o);
            }
            return (List<E>) list;
        } catch (SQLException thenables) {
            thenables.printStackTrace();
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IntrospectionException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        } finally {
            closeSession(connection);
        }
        return null;
    }

    public void update(Configuration configuration, MapperStatement mapperStatement, Object params) {
        Connection connection = null;
        try {
            //处理参数,重组SQL
            connection = getConnection(configuration);
            String sql = mapperStatement.getSql();
            //重组sql
            BoundSql boundSql = getBoundSql(sql);
            sql = sql.replaceAll("[\r\n]", "");
            log.info("==>  Preparing: {}", sql.trim());
            PreparedStatement preparedStatement = connection.prepareStatement(boundSql.getSql());
            //如果参数类型是Integer 或者String直接赋值
            setStatementParameter(mapperStatement, boundSql, preparedStatement, params);
            preparedStatement.executeUpdate();

        } catch (SQLException thenables) {
            thenables.printStackTrace();
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } finally {
            closeSession(connection);
        }
    }

    private void setStatementParameter(MapperStatement mapperStatement, BoundSql boundSql, PreparedStatement preparedStatement, Object params) throws SQLException, NoSuchFieldException, IllegalAccessException, InstantiationException {
        Class<?> parameterClass = mapperStatement.getParameterType();
        if (parameterClass == null) {
            return;
        }

        Object[] param = (Object[]) params;
        if (isBaseType(parameterClass)) {
            preparedStatement.setObject(1, param[0]);
            return;
        }
        //设置请求参数
        List<ParameterMapping> parameterList = boundSql.getParameterMappingList();
        StringBuffer sb = new StringBuffer();
        sb.append("==> Parameters:  ");
        for (int i = 0; i < parameterList.size(); i++) {
            ParameterMapping parameterMapping = parameterList.get(i);
            String content = parameterMapping.getContent();
            Field field = parameterClass.getDeclaredField(content);
            field.setAccessible(true);
            preparedStatement.setObject(i + 1, field.get(param[0]));
            sb.append(field.get(param[0])).append(", ");
        }
        log.info(sb.substring(0, sb.length() - 2));
    }

    public static boolean isBaseType(Class<?> clazz) {
        if (clazz.equals(java.lang.Integer.class) || clazz.equals(java.lang.Byte.class) ||
                clazz.equals(java.lang.Long.class) || clazz.equals(java.lang.Double.class) ||
                clazz.equals(java.lang.Float.class) || clazz.equals(java.lang.Character.class) ||
                clazz.equals(java.lang.Short.class) || clazz.equals(java.lang.Boolean.class)) {
            return true;
        }
        return false;
    }

    private Connection getConnection(Configuration configuration) throws SQLException {
        DataSource dataSource = configuration.getDataSource();
        return dataSource.getConnection();
    }

    private void closeSession(Connection connection) {
        try {
            if (connection != null) {
                connection.close();
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * 完成sql替换以及解析工作
     */
    private BoundSql getBoundSql(String sql) {
        //表里处理类,配置标记解析器完成对占位符的解析工作
        ParameterMappingTokenHandler tokenHandler = new ParameterMappingTokenHandler();
        GenericTokenParser genericTokenParser = new GenericTokenParser("#{", "}", tokenHandler);
        String parse = genericTokenParser.parse(sql);
        List<ParameterMapping> parameterMappings = tokenHandler.getParameterMappings();
        return new BoundSql(parse, parameterMappings);

    }
}

我们可用看到,SimpleExecutor中的query() 方法和update()方法主要做了三件事

  • 封装参数
  • 执行数据库操作
  • 封装结果集

测试

数据准备
create database test;

create table users(id int primary key, username varchar(20), password varchar(10));
创建User类
@Data
public class User {
    private Integer id;

    private String username;
}
创建UserMapper 接口
public interface IUserMapper {

    List<User> getUserList();

    User getUserById(Integer id);

    User getUser(User user);

    void addUser(User user);

    void update(User user);

    void deleteUser(Integer id);
}
测试类
public class UserMapperTest {

    private IUserMapper userMapper;

    @Before
    public void getUserMapper() {
        InputStream inputStream = Resources.getResourceAsStream("sqlMapConfig.xml");
        SqlSessionFactory sessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
        SqlSession sqlSession = sessionFactory.openSession();
        userMapper = sqlSession.getMapper(IUserMapper.class);
    }

    @Test
    public void getUserList() {
        List<User> userList = userMapper.getUserList();
        for (User user : userList) {
            System.out.println(user);
        }
    }

    @After
    public void  getUserById() {
        User user = userMapper.getUserById(10);
        System.out.println(user);
    }

    @Test
    public void  getUser() {
        User user1 = new User();
        user1.setId(1);
        User user = userMapper.getUser(user1);
        System.out.println(user);
    }

    @Test
    public void addUser() {
        User user = new User();
        user.setId(10);
        user.setUsername("wudi");
        userMapper.addUser(user);
    }

    @Test
    public void update() {
        User user = new User();
        user.setId(10);
        user.setUsername("wuaaaadi");
        userMapper.update(user);
    }

    @Test
    public void deleteUser() {
        userMapper.deleteUser(10);
    }
}

至此,我们的自定义框架已经介绍完了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值