基础单元测试（Unit Test）自动生成

本文介绍如何在 Java 项目中自动生成单元测试：下面的工具类扫描 src/main/java 下的源码，借助反射以及 JUnit 4、Mockito 等库为每个类生成测试骨架（构造器调用、方法调用、equals 测试），并能清理未被手工修改过的旧生成文件，旨在提升开发效率和代码质量。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import org.springframework.util.ResourceUtils;
import org.springframework.util.StringUtils;

import java.io.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.util.*;
import java.util.regex.Pattern;

public class GenerateUnitTestJava {

    /**
     * Regenerates the automatic tests. First, generated {@code *AutoTest.java} files that are still
     * byte-for-byte as generated are deleted. Then {@code src/main/java} is scanned and a test class
     * is written for every usable class whose test file does not already exist.
     *
     * @param args unused
     * @throws IOException if the project root cannot be resolved or a source/test file cannot be accessed
     */
    public static void main(String[] args) throws IOException {
        // Module directory below the working directory; "" means the working directory itself.
        Path.PROJECT_NAME = "";
        Delete.delete();
        Iterator<ClassTemplate> templates = Scan.scan().iterator();
        while (templates.hasNext()) {
            Write.write(templates.next());
        }
    }

    /** Resolves the on-disk root directory of the project whose sources are processed. */
    private static class Path {
        /** Sub-directory (module) below the base directory; the empty string selects the base itself. */
        static String PROJECT_NAME = "";

        /**
         * Returns {@code <absolute path of ResourceUtils.getFile("")>/<PROJECT_NAME>}. For a plain
         * (non-classpath, non-URL) location ResourceUtils presumably falls back to {@code new File("")},
         * i.e. the JVM working directory. TODO confirm.
         * An empty PROJECT_NAME leaves a trailing '/' in the result.
         *
         * @throws FileNotFoundException if ResourceUtils cannot resolve the location
         */
        static String find() throws FileNotFoundException {
            File base = ResourceUtils.getFile("");
            return base.getAbsolutePath() + "/" + PROJECT_NAME;
        }
    }

    /**
     * Finds candidate classes under {@code src/main/java}.
     *
     * <p>Class names are derived textually from each {@code .java} file: the package line, the
     * file name, and simple pattern matches on declaration lines. Each name is then loaded with
     * {@link Class#forName(String)}. Candidates are skipped when they fail to load or to be
     * templated, or when they are interfaces, private classes, or abstract non-enum classes.
     */
    private static class Scan {
        static final Pattern JAVA_END = Pattern.compile(".+\\.java$");
        static final Pattern PACKAGE_PT = Pattern.compile(".*package .*;.*");
        static final Pattern STATIC_CLASS_PT = Pattern.compile(".*[ ]+static[ ]+class[ ]+.*");
        static final Pattern CLASS_PT = Pattern.compile(".*[ ]+class[ ]+.*");
        static final Pattern ENUM_PT = Pattern.compile(".*[ ]+enum[ ]+.*");
        /** Tokens that terminate a type name on a declaration line. */
        static final String[] SP_ARR = {"{", "<", " implements ", " extends "};

        /** Scans {@code <project root>/src/main/java/}. */
        static List<ClassTemplate> scan() throws IOException {
            return scan(Path.find() + "/src/main/java/");
        }

        static List<ClassTemplate> scan(String filePath) throws IOException {
            return scan(new File(filePath));
        }

        /**
         * Recursively collects templates for every loadable class declared in {@code f}: a
         * directory is walked, a {@code .java} file is parsed, and anything else yields nothing.
         */
        static List<ClassTemplate> scan(File f) throws IOException {
            if (null == f) {
                return Collections.emptyList();
            }
            if (f.isDirectory()) {
                File[] files = f.listFiles();
                // listFiles() returns null on an I/O error or when the directory cannot be read.
                if (null == files || files.length < 1) {
                    return Collections.emptyList();
                }
                List<ClassTemplate> r = new ArrayList<>();
                for (File f0 : files) {
                    r.addAll(scan(f0));
                }
                return r;
            }
            if (!JAVA_END.matcher(f.getName()).matches()) {
                return Collections.emptyList();
            }
            List<ClassTemplate> r = new ArrayList<>();
            for (String cn : findClassName(f)) {
                try {
                    Class<?> tClass = Class.forName(cn);
                    int mds = tClass.getModifiers();
                    if (tClass.isInterface() || Modifier.isPrivate(mds)
                            || (Modifier.isAbstract(mds) && !tClass.isEnum())) {
                        continue;
                    }
                    r.add(Builder.buildTemplate(tClass));
                } catch (Throwable ignored) {
                    // Best effort: candidate names are guessed from source text, so many of them do not
                    // resolve or cannot be introspected. Such candidates are simply skipped.
                }
            }
            return r;
        }

        /**
         * Returns the candidate binary class names declared in source file {@code f}.
         *
         * <p>The list starts with the primary type {@code <package>.<file name>} and continues with
         * guesses from later declaration lines. It is empty when no package line is found (the
         * default package is not supported).
         */
        static List<String> findClassName(File f) throws IOException {
            try (BufferedReader br = new BufferedReader(new FileReader(f))) {
                String t;
                // Skip ahead to the first line that looks like a package declaration.
                do {
                    t = br.readLine();
                } while (null != t && !PACKAGE_PT.matcher(t).matches());
                if (null == t) {
                    return Collections.emptyList();
                }
                int ps = t.indexOf("package ");
                // trim() tolerates extra blanks such as "package  a.b ;".
                String pn = t.substring(ps + 8, t.indexOf(";", ps)).trim();
                String fn = f.getName().substring(0, f.getName().length() - 5);
                List<String> r = new ArrayList<>();
                r.add(String.format("%s.%s", pn, fn));
                while (null != (t = br.readLine())) {
                    // Lines this short cannot contain a declaration matched below.
                    if (t.length() < 7) {
                        continue;
                    }
                    findNormalClass(r, t, pn, fn);
                    findStaticClass(r, t, pn, fn);
                    findNormalEnum(r, t, pn, fn);
                }
                return r;
            }
        }

        /**
         * Extracts the type name starting at index {@code f} of line {@code t}. The name ends at the
         * earliest of "{", "<", " implements ", " extends " (or end of line) and is trimmed.
         */
        static String getJName(String t, int f) {
            int end = t.length();
            for (String s : SP_ARR) {
                int si = t.indexOf(s, f);
                if (si >= 0 && si < end) {
                    end = si;
                }
            }
            return t.substring(f, end).trim();
        }

        /**
         * Handles an {@code enum} declaration other than the primary type.
         *
         * <p>Because it may be either a secondary top-level type or nested in the primary type, both
         * candidate names are added. {@code scan} drops whichever one fails to load.
         */
        private static void findNormalEnum(List<String> r, String t, String pn, String fn) {
            if (!ENUM_PT.matcher(t).matches()) {
                return;
            }
            String cn = getJName(t, t.indexOf(" enum ") + 6);
            if (cn.isEmpty() || cn.equals(fn)) {
                return;
            }
            r.add(String.format("%s.%s", pn, cn));
            r.add(String.format("%s.%s$%s", pn, fn, cn));
        }

        /** Handles a {@code static class} declaration, assumed to be nested in the file's primary type. */
        private static void findStaticClass(List<String> r, String t, String pn, String fn) {
            if (!STATIC_CLASS_PT.matcher(t).matches()) {
                return;
            }
            int ci = t.indexOf("class ", t.indexOf(" static ") + 8);
            if (ci < 0) {
                return;
            }
            String cn = getJName(t, ci + 6);
            if (!cn.isEmpty()) {
                r.add(String.format("%s.%s$%s", pn, fn, cn));
            }
        }

        /**
         * Handles a non-static {@code class} declaration other than the primary type. It is assumed
         * to be a secondary top-level type of the same package.
         */
        private static void findNormalClass(List<String> r, String t, String pn, String fn) {
            if (!CLASS_PT.matcher(t).matches() || STATIC_CLASS_PT.matcher(t).matches()) {
                return;
            }
            String cn = getJName(t, t.indexOf(" class ") + 7);
            if (cn.isEmpty() || cn.equals(fn)) {
                return;
            }
            r.add(String.format("%s.%s", pn, cn));
        }
    }

    /**
     * Removes previously generated test files ({@code *AutoTest.java}) that are still exactly as
     * generated.
     *
     * <p>A file qualifies only if both of the following hold:
     * <ul>
     *   <li>its first line matches {@code Builder.START_PATTERN};</li>
     *   <li>the hash recorded there equals {@code String.hashCode()} of the remaining lines joined
     *       with "\r\n".</li>
     * </ul>
     * Hand-edited files therefore survive.
     */
    private static class Delete {
        /** Processes everything under {@code <project root>/src/test/}. */
        static boolean delete() throws IOException {
            return delete(new File(Path.find() + "/src/test/"));
        }

        /**
         * Recursively deletes unmodified generated tests at or below {@code file}. A path under
         * {@code /src/main/} is first mapped to the matching path under {@code /src/test/}.
         *
         * @return true if a non-empty directory was traversed or the file was deleted, false otherwise
         */
        static boolean delete(File file) throws IOException {
            if (null == file) {
                return false;
            }
            File f = new File(file.getAbsolutePath().replace("/src/main/", "/src/test/"));
            if (!f.exists()) {
                return false;
            }
            if (f.isDirectory()) {
                File[] files = f.listFiles();
                // listFiles() returns null on an I/O error or when the directory cannot be read.
                if (null == files || files.length < 1) {
                    return false;
                }
                for (File f0 : files) {
                    delete(f0);
                }
                return true;
            }
            if (!Builder.END_PATTERN.matcher(f.getName()).matches()) {
                return false;
            }
            if (!isUnmodifiedGeneratedFile(f)) {
                return false;
            }
            System.out.println(String.format("delete %s", f.getName()));
            return f.delete();
        }

        /**
         * Returns true if {@code f} starts with the generator header and the hash stored in that
         * header matches the rest of the file. Returns false for an empty file or a foreign header.
         */
        private static boolean isUnmodifiedGeneratedFile(File f) throws IOException {
            String header;
            List<String> body = new ArrayList<>();
            try (BufferedReader br = new BufferedReader(new FileReader(f))) {
                header = br.readLine();
                if (null == header || !Builder.START_PATTERN.matcher(header).matches()) {
                    return false;
                }
                String line;
                while ((line = br.readLine()) != null) {
                    body.add(line);
                }
            }
            String hashCode = header.substring(Builder.START.length());
            // The generated body separates its lines with "\r\n", so rejoin with the same separator.
            return hashCode.equals(String.valueOf(String.join("\r\n", body).hashCode()));
        }
    }

    /** Writes a generated test class under {@code src/test/java}; existing files are never overwritten. */
    private static class Write {
        /**
         * Renders {@code ct} and writes it to {@code src/test/java/<package>/<testClassName>.java}.
         * The file is prefixed with a header line of {@code Builder.START} plus the body's hash.
         *
         * @return true if a new file was written; false if {@code ct} is null, has no test
         *         methods/constructors, or the target file already exists
         * @throws IOException if the directory cannot be created or the file cannot be written
         */
        static boolean write(ClassTemplate ct) throws IOException {
            if (null == ct || (ct.testMethods.isEmpty() && ct.testConstructors.isEmpty())) {
                return false;
            }
            String fdp = String.format("%s/src/test/java/%s/", Path.find(),
                    ct.packageName.replace(".", "/"));
            File fd = new File(fdp);
            if (!fd.exists() && !fd.mkdirs()) {
                throw new IOException("cannot create directory " + fd.getAbsolutePath());
            }
            File f = new File(String.format("%s%s.java", fdp, ct.testClassName));
            if (f.exists()) {
                return false;
            }
            String s = Builder.build(ct);
            try (BufferedWriter bw = new BufferedWriter(new FileWriter(f))) {
                // The header stores the body's hash so Delete can later tell whether the file was edited.
                bw.write(String.format("%s%s", Builder.START, s.hashCode()));
                bw.newLine();
                bw.write(s);
            }
            System.out.println(String.format("write for %s", ct.tClass.getName()));
            return true;
        }
    }

    /**
     * Builds a {@link ClassTemplate} for a class by reflection and renders it as JUnit 4 + Mockito
     * test source. Every generated test body is wrapped in {@code try/catch (Throwable)}, so the
     * generated tests only exercise code paths; they contain no assertions.
     */
    private static class Builder {
        /** File-name pattern of generated test sources; Delete only considers files matching it. */
        static final Pattern END_PATTERN = Pattern.compile(".+AutoTest\\.java$");
        /** Suffix appended to the simple class name to form the test class name. */
        static final String END = "AutoTest";
        /** Matches the header line Write puts first: START followed by the body's hashCode. */
        static final Pattern START_PATTERN = Pattern.compile("^// auto generate @sometwo\\.fun hashCode=[-]?\\d+$");
        static final String START = "// auto generate @sometwo.fun hashCode=";
        /**
         * Annotation names, matched as substrings of {@code Annotation.toString()}. Fields carrying
         * one of them become @Mock fields.
         */
        static final List<String> AUTO_SET_ANN = Arrays.asList(
                "org.springframework.beans.factory.annotation.Autowired",
                "com.sun.jersey.api.core.InjectParam",
                "javax.annotation.Resource"
        );
        static final String RUN_WITH_STR = "@RunWith(MockitoJUnitRunner.class)";
        /** Primitive type names; none of them may appear in an import statement. */
        private static final Set<String> PRIMITIVE_NAMES = new HashSet<>(Arrays.asList(
                "boolean", "byte", "char", "short", "int", "long", "float", "double", "void"));

        /** Renders {@code ct} as the complete source text of a test class, using "\r\n" line breaks. */
        static String build(ClassTemplate ct) {
            StringBuilder sb = new StringBuilder();
            sb.append(String.format("package %s;\r\n\r\n", ct.packageName));
            for (String s : ct.imports) {
                sb.append(String.format("import %s;\r\n", s));
            }
            sb.append("\r\n");
            sb.append(ct.runWithStr).append("\r\n");
            sb.append(String.format("public class %s {\r\n", ct.testClassName));
            if (!ct.testMethods.isEmpty()) {
                if (ct.tClass.isEnum()) {
                    // Enum under test: hold one of its constants rather than an @InjectMocks instance.
                    sb.append(String.format("\tprivate %s %s = %s;\r\n", ct.typeName, ct.injectPropertyName,
                            DefaultValue.value(ct.tClass)));
                } else {
                    sb.append("\t@InjectMocks\r\n");
                    sb.append(String.format("\tprivate %s %s;\r\n", ct.typeName, ct.injectPropertyName));
                }
            }
            for (Map.Entry<String, Class<?>> p : ct.mockProperties.entrySet()) {
                sb.append("\t@Mock\r\n");
                sb.append(String.format("\tprivate %s %s;\r\n", p.getValue().getSimpleName(), p.getKey()));
            }
            sb.append("\r\n\t@Before\r\n");
            sb.append(String.format("\tpublic void before%s () {\r\n", System.currentTimeMillis()));
            for (String s : ct.beforeMethodLines) {
                sb.append(String.format("\t\t%s\r\n", s));
            }
            sb.append("\t}");
            appendTestMethods(sb, ct.testConstructors);
            appendTestMethods(sb, ct.testMethods);
            sb.append("\r\n}");
            return sb.toString();
        }

        /**
         * Appends one {@code @Test public void <key>AutoTest ()} method per entry. The entry's lines
         * are wrapped in {@code try/catch (Throwable)}.
         */
        private static void appendTestMethods(StringBuilder sb, Map<String, List<String>> tests) {
            for (Map.Entry<String, List<String>> m : tests.entrySet()) {
                sb.append("\r\n\r\n\t@Test\r\n");
                sb.append(String.format("\tpublic void %sAutoTest () {\r\n", m.getKey()));
                sb.append("\t\ttry {\r\n");
                for (String s : m.getValue()) {
                    sb.append("\t\t\t").append(s).append("\r\n");
                }
                sb.append("\r\n\t\t} catch (Throwable throwable) {\r\n\t\t\t// default\r\n\t\t}\r\n");
                sb.append("\t}");
            }
        }

        /**
         * Collects everything needed to render a test for {@code tClass}:
         * <ul>
         *   <li>names and imports;</li>
         *   <li>@Mock fields;</li>
         *   <li>constructor calls;</li>
         *   <li>calls to its own methods and to inherited project methods;</li>
         *   <li>an equals test when it declares {@code equals(Object)}.</li>
         * </ul>
         */
        static ClassTemplate buildTemplate(Class<?> tClass) {
            ClassTemplate ct = new ClassTemplate();
            ct.tClass = tClass;
            ct.injectPropertyName = StringUtils.uncapitalize(tClass.getSimpleName());
            ct.packageName = tClass.getPackage().getName();
            // Nested classes are referenced as "Outer.Inner" (package prefix removed); others by simple name.
            String typeName = tClass.getName().contains("$") ? tClass.getName().replace("$", ".") : tClass.getSimpleName();
            String packagePrefix = ct.packageName + ".";
            ct.typeName = typeName.startsWith(packagePrefix) ? typeName.substring(packagePrefix.length()) : typeName;
            ct.runWithStr = RUN_WITH_STR;
            ct.testClassName = (tClass.getSimpleName() + END).replace(".", "");
            ct.imports.add("org.junit.*");
            ct.imports.add("org.junit.runner.RunWith");
            ct.imports.add("org.mockito.runners.MockitoJUnitRunner");
            ct.imports.add("org.mockito.*");
            ct.beforeMethodLines.add("MockitoAnnotations.initMocks(this);");
            buildMockProperty(ct);
            buildConstructMethod(ct);
            buildMethod(ct);
            buildSuperMethod(ct);
            filterImport(ct);
            return ct;
        }

        /**
         * Registers the import needed to reference {@code type} by simple name. Arrays contribute
         * their innermost component type; primitives contribute nothing.
         */
        private static void addImport(ClassTemplate ct, Class<?> type) {
            Class<?> c = type;
            while (c.isArray()) {
                c = c.getComponentType();
            }
            if (!c.isPrimitive()) {
                ct.imports.add(c.getName());
            }
        }

        /**
         * Cleans up the collected imports:
         * <ul>
         *   <li>drops the class under test itself, blank names, primitives, and direct
         *       {@code java.lang} types;</li>
         *   <li>rewrites nested binary names ("a.B$C") to source form ("a.B.C").</li>
         * </ul>
         */
        private static void filterImport(ClassTemplate ct) {
            Set<String> nested = new HashSet<>();
            ct.imports.remove(ct.tClass.getName());
            Iterator<String> it = ct.imports.iterator();
            while (it.hasNext()) {
                String s = it.next();
                if (!StringUtils.hasText(s) || PRIMITIVE_NAMES.contains(s)
                        || (s.startsWith("java.lang.") && s.split("\\.").length == 3)) {
                    it.remove();
                } else if (s.contains("$")) {
                    nested.add(s.replace("$", "."));
                    it.remove();
                }
            }
            ct.imports.addAll(nested);
        }

        /** Appends comma-separated default argument expressions for {@code pts} and records their imports. */
        private static void appendArguments(ClassTemplate ct, StringBuilder s, Class<?>[] pts) {
            for (int i = 0; i < pts.length; i++) {
                if (i > 0) {
                    s.append(",");
                }
                addImport(ct, pts[i]);
                s.append(DefaultValue.value(pts[i]));
            }
        }

        /**
         * Adds a test that calls {@code dm} on receiver expression {@code p} with default arguments.
         * The test key is the method name, suffixed with "0"s until unique among existing keys.
         */
        static void buildCommonMethod(ClassTemplate ct, Method dm, String p) {
            StringBuilder s = new StringBuilder(p).append(".").append(dm.getName()).append("(");
            appendArguments(ct, s, dm.getParameterTypes());
            s.append(");");
            String n = dm.getName();
            while (ct.testMethods.containsKey(n)) {
                n += "0";
            }
            ct.testMethods.put(n, Collections.singletonList(s.toString()));
        }

        /**
         * Adds calls to concrete, accessible methods declared by abstract superclasses. The walk stops
         * at the first abstract superclass whose source is not under this project's src/main/java.
         * Enums are skipped.
         */
        static void buildSuperMethod(ClassTemplate ct) {
            if (ct.tClass.isEnum()) {
                return;
            }
            for (Class<?> st = ct.tClass.getSuperclass();
                 st != null && !Object.class.equals(st);
                 st = st.getSuperclass()) {
                // Only abstract superclasses contribute methods; concrete ones are passed over.
                if (!Modifier.isAbstract(st.getModifiers())) {
                    continue;
                }
                try {
                    String jp = st.getPackage().getName().replace(".", "/");
                    String fp = String.format("%s/src/main/java/%s/%s.java", Path.find(), jp, st.getSimpleName());
                    if (!new File(fp).exists()) {
                        break;
                    }
                } catch (FileNotFoundException ignored) {
                    // Project root could not be resolved: skip the source check and keep going.
                }
                for (Method dm : st.getDeclaredMethods()) {
                    int m = dm.getModifiers();
                    if (Modifier.isPrivate(m) || dm.isBridge() || dm.getName().contains("$") || Modifier.isAbstract(m)
                            || (!Modifier.isPublic(m) && !st.getPackage().equals(ct.tClass.getPackage()))) {
                        continue;
                    }
                    if (!st.getName().contains("$")) {
                        ct.imports.add(st.getName());
                    }
                    buildCommonMethod(ct, dm, String.format("((%s) %s)", ct.typeName, ct.injectPropertyName));
                }
            }
        }

        /**
         * Adds calls for the class's own non-private, non-bridge, non-synthetic-named methods.
         *
         * <p>Enums also get a valueOf round-trip test. Abstract enum methods are called once per
         * constant. A declared {@code equals(Object)} adds a property-based equals test.
         */
        static void buildMethod(ClassTemplate ct) {
            if (ct.tClass.isEnum()) {
                ct.testMethods.put("valueOfTrueAutoTest", Collections.singletonList(
                        String.format("%s.valueOf(%s.name());", ct.typeName, DefaultValue.value(ct.tClass))
                ));
            }
            for (Method dm : ct.tClass.getDeclaredMethods()) {
                int m = dm.getModifiers();
                if (Modifier.isPrivate(m) || dm.isBridge() || dm.getName().contains("$")) {
                    continue;
                }
                if (ct.tClass.isEnum() && Modifier.isAbstract(m)) {
                    for (Field df : ct.tClass.getDeclaredFields()) {
                        if (df.isEnumConstant()) {
                            buildCommonMethod(ct, dm, String.format("%s.%s", ct.typeName, df.getName()));
                        }
                    }
                    continue;
                }
                buildCommonMethod(ct, dm, ct.injectPropertyName);
            }
            try {
                ct.tClass.getDeclaredMethod("equals", Object.class);
                ct.testMethods.put("equals" + System.currentTimeMillis(), MethodEquals.build(ct));
            } catch (NoSuchMethodException ignored) {
                // equals(Object) is not declared by this class: no equals test.
            }
        }

        /** Adds a {@code new Type(defaults...)} test for every non-private constructor (enums excluded). */
        static void buildConstructMethod(ClassTemplate ct) {
            if (ct.tClass.isEnum()) {
                return;
            }
            for (Constructor<?> dc : ct.tClass.getDeclaredConstructors()) {
                if (Modifier.isPrivate(dc.getModifiers())) {
                    continue;
                }
                StringBuilder s = new StringBuilder("new ").append(ct.typeName).append("(");
                appendArguments(ct, s, dc.getParameterTypes());
                s.append(");");
                String n = "constructor" + System.currentTimeMillis();
                while (ct.testConstructors.containsKey(n)) {
                    n += "1";
                }
                ct.testConstructors.put(n, Collections.singletonList(s.toString()));
            }
        }

        /** Registers every field annotated with one of {@link #AUTO_SET_ANN} as a @Mock field. */
        static void buildMockProperty(ClassTemplate ct) {
            for (Field fd : ct.tClass.getDeclaredFields()) {
                for (Annotation a : fd.getAnnotations()) {
                    for (String s : AUTO_SET_ANN) {
                        if (a.toString().contains(s)) {
                            ct.mockProperties.put(fd.getName(), fd.getType());
                            addImport(ct, fd.getType());
                            break;
                        }
                    }
                }
            }
        }
    }

    /**
     * Produces the body of a generated equals test.
     *
     * <p>The generated code performs these steps:
     * <ol>
     *   <li>creates two instances with {@code BeanUtils.instantiateClass};</li>
     *   <li>fills every public single-argument setter of the first instance from a table of sample
     *       values keyed by parameter type;</li>
     *   <li>copies the properties into the second instance and calls {@code equals}.</li>
     * </ol>
     */
    static class MethodEquals {
        /** Registers the imports the generated lines need and returns those lines in order. */
        static List<String> build(ClassTemplate ct) {
            Collections.addAll(ct.imports,
                    "java.beans.PropertyDescriptor",
                    "org.springframework.beans.BeanUtils",
                    "java.lang.reflect.Method",
                    "java.lang.reflect.Modifier",
                    BigDecimal.class.getName(),
                    Date.class.getName(),
                    Map.class.getName(),
                    HashMap.class.getName());

            String type = ct.typeName;
            List<String> lines = new ArrayList<>();
            // Sample value per setter parameter type.
            Collections.addAll(lines,
                    "Map<Class<?>, Object> vm = new HashMap<>();",
                    "vm.put(String.class, \"1\");",
                    "vm.put(byte.class, (byte) 1);",
                    "vm.put(Byte.class, (byte) 1);",
                    "vm.put(short.class, (short) 2);",
                    "vm.put(Short.class, (short) 2);",
                    "vm.put(float.class, (float) 3);",
                    "vm.put(Float.class, (float) 3);",
                    "vm.put(int.class, 4);",
                    "vm.put(Integer.class, 4);",
                    "vm.put(double.class, 5D);",
                    "vm.put(Double.class, 5D);",
                    "vm.put(long.class, 6L);",
                    "vm.put(Long.class, 6L);",
                    "vm.put(boolean.class, true);",
                    "vm.put(Boolean.class, false);",
                    "vm.put(BigDecimal.class, new BigDecimal(7));",
                    "vm.put(Date.class, new Date());");
            lines.add(type + " bean0 = BeanUtils.instantiateClass(" + type + ".class);");
            lines.add(type + " bean1 = BeanUtils.instantiateClass(" + type + ".class);");
            lines.add("PropertyDescriptor[] pds = BeanUtils.getPropertyDescriptors(" + type + ".class);");
            // Populate bean0 through its setters, then copy into bean1 and compare.
            Collections.addAll(lines,
                    "for (PropertyDescriptor pd : pds) {",
                    "\tMethod wm = pd.getWriteMethod();",
                    "\tif(null == wm || !Modifier.isPublic(wm.getDeclaringClass().getModifiers())) {",
                    "\t\tcontinue;",
                    "\t}",
                    "\tClass<?>[] pts = wm.getParameterTypes();",
                    "\tif (pts.length != 1) {",
                    "\t\tcontinue;",
                    "\t}",
                    "\tObject v = vm.get(pts[0]);",
                    "\tif (null == v) {",
                    "\t\tcontinue;",
                    "\t}",
                    "\twm.invoke(bean0, v);",
                    "}",
                    "BeanUtils.copyProperties(bean0, bean1);",
                    "bean0.equals(bean1);");
            return lines;
        }
    }

    /** Mutable description of one test class to generate, filled in by Builder and rendered by Builder.build. */
    static class ClassTemplate {
        // ----- class under test -----
        /** The class the test is generated for. */
        Class<?> tClass;
        /** Package of the class under test, e.g. {@code fun.sometwo.go}; the test is generated into it too. */
        String packageName;
        /** How the class under test is spelled in test source (simple name, or "Outer.Inner" when nested). */
        String typeName;
        /** Name of the test's field holding the instance under test. */
        String injectPropertyName;

        // ----- generated test class -----
        /** Simple name of the generated test class, e.g. {@code XXXAutoTest}. */
        String testClassName;
        /** Class-level runner annotation line. */
        String runWithStr;
        /** Fully qualified names to import, e.g. {@code fun.sometwo.service.Test}. */
        Set<String> imports = new HashSet<>();
        /** @Mock fields: field name to field type. */
        Map<String, Class<?>> mockProperties = new HashMap<>();
        /** Statements placed in the @Before method. */
        List<String> beforeMethodLines = new ArrayList<>();
        /** Constructor tests: test-method key to statement lines. */
        Map<String, List<String>> testConstructors = new HashMap<>();
        /** Method tests: test-method key to statement lines. */
        Map<String, List<String>> testMethods = new HashMap<>();
    }

    /**
     * Maps a parameter type to a Java source expression that yields a placeholder value of that
     * type. The expression is used as an argument in generated constructor and method calls.
     */
    private static class DefaultValue {
        /** Fixed expressions for common JDK types; the {@code null} key maps to the literal "null". */
        static final Map<Class<?>, String> valMap;
        static {
            valMap = new HashMap<>();
            valMap.put(String.class, "\"1\"");
            valMap.put(null, "null");
            valMap.put(Integer.class, "1");
            valMap.put(int.class, "11");
            valMap.put(Long.class, "2L");
            valMap.put(long.class, "22");
            valMap.put(Short.class, "(short) 3");
            valMap.put(short.class, "(short) 33");
            valMap.put(Double.class, "4D");
            valMap.put(double.class, "44D");
            valMap.put(Float.class, "5f");
            valMap.put(float.class, "55f");
            valMap.put(BigDecimal.class, "new BigDecimal(6)");
            valMap.put(Boolean.class, "true");
            valMap.put(boolean.class, "true");
            valMap.put(Byte.class, "(byte) 7");
            valMap.put(byte.class, "(byte) 77");
            // Without these, char/Character fell through to instantiateClass, which always fails at runtime.
            valMap.put(Character.class, "'c'");
            valMap.put(char.class, "'c'");
            valMap.put(Class.class, "Object.class");
            valMap.put(Date.class, "new Date()");
            valMap.put(Object.class, "new Object()");
            valMap.put(Object[].class, "new Object[]{}");
            valMap.put(Set.class, "new java.util.HashSet()");
            valMap.put(List.class, "new java.util.ArrayList()");
            valMap.put(Map.class, "new java.util.HashMap()");
        }

        /**
         * Returns a source expression producing a value of {@code clzss}, chosen as follows:
         * <ul>
         *   <li>known types: the fixed literal from {@link #valMap};</li>
         *   <li>arrays: an empty array initializer;</li>
         *   <li>enums: the first declared constant;</li>
         *   <li>other final classes: a {@code BeanUtils.instantiateClass} call;</li>
         *   <li>anything else: {@code Mockito.mock}.</li>
         * </ul>
         */
        static String value(Class<?> clzss) {
            if (valMap.containsKey(clzss)) {
                return valMap.get(clzss);
            }
            String n = typeName(clzss);
            if (clzss.isArray()) {
                return String.format("new %s{}", n);
            }
            if (clzss.isEnum()) {
                for (Field f : clzss.getDeclaredFields()) {
                    if (f.isEnumConstant()) {
                        return String.format("%s.%s", n, f.getName());
                    }
                }
            }
            // Presumably because final classes cannot be mocked by Mockito's default mock maker.
            if (Modifier.isFinal(clzss.getModifiers())) {
                return String.format("org.springframework.beans.BeanUtils.instantiateClass(%s.class)", n);
            }
            return String.format("Mockito.mock(%s.class)", n);
        }

        /**
         * Returns the source-level spelling of {@code c}:
         * <ul>
         *   <li>arrays: component spelling plus "[]";</li>
         *   <li>nested types: the dotted binary name;</li>
         *   <li>everything else: the simple name.</li>
         * </ul>
         */
        private static String typeName(Class<?> c) {
            if (c.isArray()) {
                return typeName(c.getComponentType()) + "[]";
            }
            return c.getName().contains("$") ? c.getName().replace("$", ".") : c.getSimpleName();
        }
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值