分享一个jpa的where条件生成工具类

jpa中自己拼接where条件写一大堆Predicate和CriteriaBuilder进行拼接写一大堆if-else 会有大量代码冗余,于是出现了帮助类来解决where条件生成的代码

jpa的where条件生成帮助类为两个:
Query用来指定使用什么方式查询;QueryHelp用来生成条件语句。

Query 类:

package com.starcity.utils;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.starcity.annotation.Query;
import lombok.extern.slf4j.Slf4j;
import org.hibernate.query.criteria.internal.ValueHandlerFactory;

import javax.persistence.criteria.*;
import java.lang.reflect.Field;
import java.util.*;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.joining;

/**
 * @author WeiMaoMao
 * @date
 */
@Slf4j
@SuppressWarnings({"unchecked", "all"})
public class QueryHelp {

    public static <R, Q> Predicate getPredicate(Root<R> root, Q query, CriteriaBuilder cb) {
        List<Predicate> list = new ArrayList<>();

        if (query == null) {
            return cb.and(list.toArray(new Predicate[0]));
        }
        try {
            List<Field> fields = getAllFields(query.getClass(), new ArrayList<>());
            for (Field field : fields) {
                boolean accessible = field.isAccessible();
                field.setAccessible(true);
                Query q = field.getAnnotation(Query.class);
                if (q != null) {
                    String propName = q.propName();
                    String joinName = q.joinName();
                    String blurry = q.blurry();
                    String attributeName = isBlank(propName) ? field.getName() : propName;
                    Class<?> fieldType = field.getType();
                    Object val = field.get(query);
                    if (ObjectUtil.isNull(val) || "".equals(val)) {
                        continue;
                    }
                    Join join = null;
                    // 模糊多字段
                    if (ObjectUtil.isNotEmpty(blurry)) {
                        String[] blurrys = blurry.split(",");
                        List<Predicate> orPredicate = new ArrayList<>();
                        for (String s : blurrys) {
                            orPredicate.add(cb.like(root.get(s)
                                    .as(String.class), "%" + val.toString() + "%"));
                        }
                        Predicate[] p = new Predicate[orPredicate.size()];
                        list.add(cb.or(orPredicate.toArray(p)));
                        continue;
                    }
                    if (ObjectUtil.isNotEmpty(joinName)) {
                        String[] joinNames = joinName.split(">");
                        for (String name : joinNames) {
                            switch (q.join()) {
                                case LEFT:
                                    if (ObjectUtil.isNotNull(join)) {
                                        join = join.join(name, JoinType.LEFT);
                                    } else {
                                        join = root.join(name, JoinType.LEFT);
                                    }
                                    break;
                                case RIGHT:
                                    if (ObjectUtil.isNotNull(join)) {
                                        join = join.join(name, JoinType.RIGHT);
                                    } else {
                                        join = root.join(name, JoinType.RIGHT);
                                    }
                                    break;
                                default:
                                    break;
                            }
                        }
                    }
                    switch (q.type()) {
                        case EQUAL:
                            list.add(cb.equal(getExpression(attributeName, join, root)
                                    .as((Class<? extends Comparable>) fieldType), val));
                            break;
                        case GREATER_THAN:
                            list.add(cb.greaterThanOrEqualTo(getExpression(attributeName, join, root)
                                    .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                            break;
                        case LESS_THAN:
                            list.add(cb.lessThanOrEqualTo(getExpression(attributeName, join, root)
                                    .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                            break;
                        case LESS_THAN_NQ:
                            list.add(cb.lessThan(getExpression(attributeName, join, root)
                                    .as((Class<? extends Comparable>) fieldType), (Comparable) val));
                            break;
                        case INNER_LIKE:
                            list.add(cb.like(getExpression(attributeName, join, root)
                                    .as(String.class), "%" + val.toString() + "%"));
                            break;
                        case LEFT_LIKE:
                            list.add(cb.like(getExpression(attributeName, join, root)
                                    .as(String.class), "%" + val.toString()));
                            break;
                        case RIGHT_LIKE:
                            list.add(cb.like(getExpression(attributeName, join, root)
                                    .as(String.class), val.toString() + "%"));
                            break;
                        case IN:
                            if (CollUtil.isNotEmpty((Collection<Long>) val)) {
                                list.add(getExpression(attributeName, join, root).in((Collection<Long>) val));
                            }
                            break;
                        case NOT_EQUAL:
                            list.add(cb.notEqual(getExpression(attributeName, join, root), val));
                            break;
                        case NOT_NULL:
                            list.add(cb.isNotNull(getExpression(attributeName, join, root)));
                            break;
                        case NULL:
                            if((boolean)val){
                                list.add(cb.isNull(getExpression(attributeName, join, root)));
                            }
                            break;
                        case BETWEEN:
                            List<Object> between = new ArrayList<>((List<Object>) val);
                            list.add(cb.between(getExpression(attributeName, join, root).as((Class<? extends Comparable>) between.get(0).getClass()),
                                    (Comparable) between.get(0), (Comparable) between.get(1)));
                            break;
                        default:
                            break;
                    }
                }
                field.setAccessible(accessible);
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
        int size = list.size();
        return cb.and(list.toArray(new Predicate[size]));
    }

    @SuppressWarnings("unchecked")
    private static <T, R> Expression<T> getExpression(String attributeName, Join join, Root<R> root) {
        if (ObjectUtil.isNotEmpty(join)) {
            return join.get(attributeName);
        } else {
            return root.get(attributeName);
        }
    }

    private static boolean isBlank(final CharSequence cs) {
        int strLen;
        if (cs == null || (strLen = cs.length()) == 0) {
            return true;
        }
        for (int i = 0; i < strLen; i++) {
            if (!Character.isWhitespace(cs.charAt(i))) {
                return false;
            }
        }
        return true;
    }

    private static List<Field> getAllFields(Class clazz, List<Field> fields) {
        if (clazz != null) {
            fields.addAll(Arrays.asList(clazz.getDeclaredFields()));
            getAllFields(clazz.getSuperclass(), fields);
        }
        return fields;
    }


    public static <Q> StringBuffer buildSqlWhere(Q query, StringBuffer sql) {

        if (query == null) {
            return sql;
        }
        try {
            List<Field> fields = getAllFields(query.getClass(), new ArrayList<>());
            for (Field field : fields) {
                boolean accessible = field.isAccessible();
                field.setAccessible(true);
                Query q = field.getAnnotation(Query.class);
                if (q != null) {
                    String propName = q.propName();
                    String joinName = q.joinName();
                    String blurry = q.blurry();
                    String attributeName = isBlank(propName) ? field.getName() : propName;
                    Class<?> fieldType = field.getType();
                    Object val = field.get(query);
                    if (ObjectUtil.isNull(val) || "".equals(val)) {
                        continue;
                    }
                    // 模糊多字段
                    if (ObjectUtil.isNotEmpty(blurry)) {
                        String[] blurrys = blurry.split(",");
                        sql.append(" and (");
                        for (int i = 0; i < blurrys.length; i++) {
                            if (i != 0 && i < blurrys.length) {
                                sql.append(" or ");
                            }
                            sql.append(" " + blurrys[i] + " like '%" + val.toString() + "%'");
                        }
                        sql.append(")");
                        continue;
                    }
                    if (ObjectUtil.isNotEmpty(joinName)) {
                        String[] joinNames = joinName.split(">");
                        for (String name : joinNames) {
                            switch (q.join()) {
                                case LEFT:
                                    sql.append(" left join " + name);
                                    break;
                                case RIGHT:
                                    sql.append(" right join " + name);
                                    break;
                                default:
                                    break;
                            }
                        }
                    }
                    switch (q.type()) {
                        case EQUAL:
                            checkStringAndNum(fieldType, sql, "=", attributeName, val);
                            break;
                        case GREATER_THAN:
                            checkStringAndNum(fieldType, sql, ">=", attributeName, val);
                            break;
                        case LESS_THAN:
                            checkStringAndNum(fieldType, sql, "<=", attributeName, val);
                            break;
                        case LESS_THAN_NQ:
                            checkStringAndNum(fieldType, sql, "<", attributeName, val);
                            break;
                        case INNER_LIKE:
                            sql.append(" and " + attributeName + " like '%" + val.toString() + "%' ");
                            break;
                        case LEFT_LIKE:
                            sql.append(" and " + attributeName + " like '%" + val.toString() + "' ");
                            break;
                        case RIGHT_LIKE:
                            sql.append(" and " + attributeName + " like '" + val.toString() + "%' ");
                            break;
                        case IN:
                            if (CollUtil.isNotEmpty((Collection<Object>) val)) {
                                Collection<Object> list = (Collection<Object>) val;
                                Object o = list.stream().findFirst().orElse(null);
                                if (ValueHandlerFactory.isNumeric(o.getClass())) {
                                    Collection<Number> numberList = (Collection<Number>) val;
                                    sql.append(" and " + attributeName + " in (" + list.stream().map(item->item+"").collect(joining(",")) + ")");
                                } else {
                                    Collection<String> strList = (Collection<String>) val;
                                    sql.append(" and " + attributeName + " in ('" + strList.stream().collect(Collectors.joining("','")) + "')");
                                }
                            }
                            break;
                        case NOT_EQUAL:
                            checkStringAndNum(fieldType, sql, "!=", attributeName, val);
                            break;
                        case NOT_NULL:
                            sql.append(" and " + attributeName + " IS NOT NULL ");
                            break;
                        case NULL:
                            if((boolean)val){
                                sql.append(" and " + attributeName + " IS NULL ");
                            }
                            break;
                        case BETWEEN:
                            List<Object> between = new ArrayList<>((List<Object>) val);
                            if (ValueHandlerFactory.isNumeric(fieldType)) {
                                sql.append(" and " + attributeName + " BETWEEN " + between.get(0) + " AND " + between.get(1));
                                break;
                            }
                            sql.append(" and " + attributeName + " BETWEEN '" + between.get(0) + "' AND '" + between.get(1) + "' ");
                            break;
                        default:
                            break;
                    }
                }
                field.setAccessible(accessible);
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
        return sql;
    }


    /**
     * 仅仅构建where的部分
     * @param query
     * @param sql
     * @param <Q>
     * @return
     */
    public static <Q> StringBuffer buildSqlOnlyWhere(Q query) {
        StringBuffer sql=new StringBuffer();

        if (query == null) {
            return sql;
        }
        try {
            List<Field> fields = getAllFields(query.getClass(), new ArrayList<>());
            for (Field field : fields) {
                boolean accessible = field.isAccessible();
                field.setAccessible(true);
                Query q = field.getAnnotation(Query.class);
                if (q != null) {
                    String propName = q.propName();
                    String joinName = q.joinName();
                    String blurry = q.blurry();
                    String attributeName = isBlank(propName) ? field.getName() : propName;
                    Class<?> fieldType = field.getType();
                    Object val = field.get(query);
                    if (ObjectUtil.isNull(val) || "".equals(val)) {
                        continue;
                    }
                    // 模糊多字段
                    if (ObjectUtil.isNotEmpty(blurry)) {
                        String[] blurrys = blurry.split(",");
                        sql.append(" and (");
                        for (int i = 0; i < blurrys.length; i++) {
                            if (i != 0 && i < blurrys.length) {
                                sql.append(" or ");
                            }
                            sql.append(" " + blurrys[i] + " like '%" + val.toString() + "%'");
                        }
                        sql.append(")");
                        continue;
                    }
                    if (ObjectUtil.isNotEmpty(joinName)) {
                        String[] joinNames = joinName.split(">");
                        for (String name : joinNames) {
                            switch (q.join()) {
                                case LEFT:
                                    sql.append(" left join " + name);
                                    break;
                                case RIGHT:
                                    sql.append(" right join " + name);
                                    break;
                                default:
                                    break;
                            }
                        }
                    }
                    switch (q.type()) {
                        case EQUAL:
                            checkStringAndNum(fieldType, sql, "=", attributeName, val);
                            break;
                        case GREATER_THAN:
                            checkStringAndNum(fieldType, sql, ">=", attributeName, val);
                            break;
                        case LESS_THAN:
                            checkStringAndNum(fieldType, sql, "<=", attributeName, val);
                            break;
                        case LESS_THAN_NQ:
                            checkStringAndNum(fieldType, sql, "<", attributeName, val);
                            break;
                        case INNER_LIKE:
                            sql.append(" and " + attributeName + " like '%" + val.toString() + "%' ");
                            break;
                        case LEFT_LIKE:
                            sql.append(" and " + attributeName + " like '%" + val.toString() + "' ");
                            break;
                        case RIGHT_LIKE:
                            sql.append(" and " + attributeName + " like '" + val.toString() + "%' ");
                            break;
                        case IN:
                            if (CollUtil.isNotEmpty((Collection<String>) val)) {
                                if (ValueHandlerFactory.isNumeric(fieldType)) {
                                    sql.append(" and " + attributeName + " in (" + ((Collection<String>) val).stream().collect(joining(",")) + ")");
                                } else {
                                    sql.append(" and " + attributeName + " in ('" + ((Collection<String>) val).stream().collect(joining("','")) + "')");
                                }
                            }
                            break;
                        case NOT_EQUAL:
                            checkStringAndNum(fieldType, sql, "!=", attributeName, val);
                            break;
                        case NOT_NULL:
                            sql.append(" and " + attributeName + " IS NOT NULL ");
                            break;
                        case NULL:
                            if((boolean)val){
                                sql.append(" and " + attributeName + " IS NULL ");
                            }
                            break;
                        case BETWEEN:
                            List<Object> between = new ArrayList<>((List<Object>) val);
                            if (ValueHandlerFactory.isNumeric(fieldType)) {
                                sql.append(" and " + attributeName + " BETWEEN " + between.get(0) + " AND " + between.get(1));
                                break;
                            }
                            sql.append(" and " + attributeName + " BETWEEN '" + between.get(0) + "' AND '" + between.get(1) + "' ");
                            break;
                        default:
                            break;
                    }
                }
                field.setAccessible(accessible);
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
        return sql;
    }
    /**
     * 判断字段类型并加入引号
     *
     * @param fieldType     字段类型
     * @param sql           总体sql
     * @param type          = >= <=
     * @param attributeName 字段名称
     * @param val           字段值
     */
    private static void checkStringAndNum(Class<?> fieldType, StringBuffer sql, String type, String attributeName, Object val) {
        if (attributeName.equals("enabled")) {
            System.out.println(ValueHandlerFactory.isBoolean(val));
        }
        //如果是数字类型
        if (ValueHandlerFactory.isNumeric(fieldType)) {
            sql.append(" and " + attributeName + " " + type + " " + val);
        }//如果是boolean
        else if (ValueHandlerFactory.isBoolean(val)) {
            sql.append(" and " + attributeName + " " + type + " " + ((Boolean) val ? 1 : 0) + "");
        }//字符
        else {
            sql.append(" and " + attributeName + " " + type + " '" + val + "' ");
        }
    }

}

QueryHelp上面有三个方法:
第一个方法是jpa自身调用时引用
第二个方法是使用entityManager.createNativeQuery调用
第三个方法是只生成where条件,用于自己拼接

举例说明:
第一种:
在这里插入图片描述在这里插入图片描述
第二种:
在这里插入图片描述
在这里插入图片描述
第三种不出示例喽。

### Java 代码生成工具概述 在Java开发领域,代码生成工具能够显著提升开发效率并确保代码质量的一致性。这些工具通过自动生成常见的代码结构和逻辑,减少了手动编写重复代码的需求。 #### JHipster JHipster 是一种流行的全栈式开发平台,它不仅支持前端技术的选择(如 Angular, React 或 Vue.js),还提供了丰富的后端服务生成功能。对于 Spring Boot 和微服务架构的支持尤为突出[^1]。 #### MyBatis Generator 专门针对数据库操作层的代码生成解决方案之一就是 MyBatis Generator (MBG),它可以依据现有的数据库表结构来自动生成相应的 DAO 接口及其实现类、Model 类等基础组件,极大地方便了开发者快速搭建数据访问层。 #### OpenAPI Generator 当涉及到 RESTful API 的设计与实现时,OpenAPI Generator 成为了理想之选。该工具可以从定义好的 OpenAPI 规范文件中提取信息,并据此创建客户端 SDKs 及服务器 stubs,在前后端分离项目里特别有用处。 #### Spring Roo 作为早期出现的一个命令行驱动型 IDE 插件形式存在的 Rapid Application Development 工具——Spring Roo,则更侧重于简化企业级应用的整体构建过程。尽管其活跃度有所下降但仍保留了一定的应用场景价值。 #### Lombok 虽然严格意义上不属于传统意义上的“代码生成”,但 Project Lombok 提供了一些编译期注解处理器的功能,可以在不改变源码的情况下向程序注入额外的方法体或字段声明等内容,从而达到减少样板代码的目的。 #### 自动化全能型生成器 存在一些综合性的代码生成工具,这类工具可以根据不同的输入格式(JSON、SQL语句或是已有的实体类)自动推断目标语言特性,并批量生产出带有基本 CRUD 功能的完整业务对象模型及相关持久化方法[^2]。 #### ORM框架兼容的高级定制化生成器 某些特定场合下可能需要更加灵活可控的方式来进行代码合成工作,这时可以选择那些基于模板引擎打造而成的产品。它们允许用户按照个人喜好调整最终产出物的形式样式,同时兼顾主流ORM映射库之间的协作需求,比如 Hibernate/JPA、MyBatis/MyBatis Plus 等[^4]。 ```java // 示例:使用Template Engine生成DAO接口 public interface UserMapper { @Select("SELECT * FROM users WHERE id=#{id}") User selectById(Integer id); } ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值