/*
 * 这只是生成单元测试（UT）代码的一部分。本应使用模板实现，但为了尽快可用，暂时先这样写。
 * 生成的代码格式不美观，请使用 IDE 自动格式化。
 * 使用方法演示视频（bilibili）：https://www.bilibili.com/video/BV1Aa411g7v7
 */
package com.areyoo.lok.controller;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.AnnotatedType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.util.ObjectUtils;
/**
* WwGenTest
*
* @author xusong
*/
public class WwGenTest {
/**
 * Root directory of the Java sources that is searched for the class under test; its source
 * text decides which when(...) stubs are generated. May be set explicitly, e.g.
 * D:/java/lok/target/test-classes/../../src/ . When left empty, genTest() derives it from
 * the test-classes directory.
 */
private static String filePath = "";
/**
 * Variable name of the generated @InjectMocks field. When empty, genCode uses the
 * lower-camel-case simple name of the class under test.
 */
private String serviceName = "";
// When true, generated @Mock fields use answer = Answers.RETURNS_DEEP_STUBS.
// Off by default: the author found deep stubs of little use.
private Boolean useAnswers = false;
// Whether the class being generated is a superclass; overwritten on every genCode(...) pass.
private Boolean isSuperclass = false;
// Whether to generate unit tests for private methods.
private Boolean genPrivateMethod = false;
// Whether to initialize objects from JSON.
private Boolean useJson = false;
// NOTE(review): not read in the visible code; presumably used together with useJson — confirm.
private String jsonFn = "";
// When non-empty, a Javadoc header with this @author and the generation date is emitted.
private String author = "";
// Whether to emit JUnit 5 (jupiter) imports/annotations instead of JUnit 4 ones.
private Boolean junit5 = false;
/**
 * Commonly used Exception type.
 * NOTE(review): not read in the visible code — presumably the exception imported/declared by
 * generated tests; confirm.
 */
private Class importException = Exception.class;
// Source text of the class under test (lines joined with '\n'), set by genCode. When empty,
// stubs are generated for every method of each mock; otherwise only for methods whose call
// ("field.method(") appears in this text.
private String fileContent = "";
// NOTE(review): not read in the visible code; presumably the static-import prefix for
// ArgumentMatchers used in generated stubs — confirm.
private String importAny = "static org.mockito.ArgumentMatchers";
// Optional output file, e.g. "F:/test.txt". When non-empty, genCode writes the generated code to it.
private String outputFile = "";
// Concrete types for generic type variables that cannot be resolved, keyed by
// "<fully.qualified.ClassName>|<TypeVariable>", e.g. 'com.areyoo.lok.service.api.WwService|T': String.class
private Map<String, Class> genericMap = new HashMap<>(15);
// Optional base class for the generated test class; null means no extends clause is emitted.
private Class baseTest = null;
/**
 * Entry point: generates unit-test code for the configured class.
 * <p>
 * If {@link #filePath} is empty, the source root is derived from the compiled test classes
 * directory as {@code <test-classes>/../../src/}.
 *
 * @throws Exception if resolving the classpath root or generating the code fails
 */
@Test
public void genTest() throws Exception {
    if ("".equals(filePath)) {
        // Derive the source root from the compiled test classes:
        // <module>/target/test-classes/../../src/
        // Converting URL -> URI -> File (instead of cutting "file:/" off the URL string with
        // substring(6)) keeps the leading '/' of absolute paths on Linux/macOS and decodes
        // escaped characters such as %20 in directory names.
        File testClassesDir = new File(this.getClass().getResource("/").toURI());
        filePath = testClassesDir.getPath() + "/../../src/";
    }
    // Concrete types for unresolved type variables, keyed "<class>|<type variable>", e.g.:
    // genericMap.put("com.areyoo.lok.repository.AuthorRepository|S", Author.class);
    // genericMap.put("com.areyoo.lok.repository.AuthorRepository|T", Author.class);
    // genericMap.put("com.areyoo.lok.service.api.WwService|T", List.class);
    // Generate the unit test for the current class.
    genCode(WwController.class, false);
    // To generate the unit test for the superclass instead:
    // genCode(WwController.class.getSuperclass(), true);
}
/**
 * Locates the {@code .java} source file of {@code myClass} under {@link #filePath}.
 * The source text is later used to decide which when(...).thenReturn(...) stubs to emit.
 *
 * @param myClass class whose source file is wanted
 * @return absolute path of the source file, or "" if it was not found
 */
private static String getAbsolutePath(Class myClass) {
    // e.g. com.areyoo.lok.controller.WwController -> com/areyoo/lok/controller/WwController.java
    final String relativeSource = myClass.getTypeName().replace('.', File.separatorChar) + ".java";
    final File sourceRoot = new File(filePath);
    return getAbsolutePath(sourceRoot, relativeSource);
}
/**
 * Recursively searches {@code file} for a regular file whose absolute path ends with the
 * relative source path {@code filePath}, e.g. com/areyoo/lok/controller/WwController.java.
 * Subdirectories are searched depth-first in the order returned by listFiles(); the first
 * match wins.
 *
 * @param file     directory to search
 * @param filePath relative path to look for, using File.separator (shadows the static field)
 * @return absolute path of the first match, or "" if nothing matches or {@code file} is not a
 *         readable directory
 */
private static String getAbsolutePath(File file, String filePath) {
    // listFiles() returns null when `file` is not a directory or an I/O error occurs;
    // treat that as "not found" instead of failing with a NullPointerException.
    File[] children = file.listFiles();
    if (children == null) {
        return "";
    }
    for (File child : children) {
        if (child.isDirectory()) {
            String found = getAbsolutePath(child, filePath);
            if (!found.isEmpty()) {
                return found;
            }
        } else if (child.isFile() && isSourceMatch(child.getAbsolutePath(), filePath)) {
            return child.getAbsolutePath();
        }
    }
    return "";
}

/** Whether {@code absolutePath} ends with {@code relativePath} on a path-segment boundary. */
private static boolean isSourceMatch(String absolutePath, String relativePath) {
    int start = absolutePath.length() - relativePath.length();
    // A plain substring test would also accept e.g. "WwController.java.bak" or ".../xcom/areyoo/...".
    return start > 0
            && absolutePath.endsWith(relativePath)
            && (relativePath.startsWith(File.separator)
                || absolutePath.charAt(start - 1) == File.separatorChar);
}
// Set by genCode(Class, Boolean, Boolean) to that pass's `init` flag: true during the first
// pass (which resets importSet/defaultMap), false during the second pass.
// NOTE(review): read outside the visible code (presumably by println) — confirm how it gates output.
private Boolean isInit = true;
/**
 * Generates the unit test for {@code myClass}.
 * <p>
 * Generation runs in two passes over genCode(Class, Boolean, Boolean): first with
 * init = true (which resets importSet and defaultMap), then with init = false. Afterwards the
 * accumulated output is written to {@link #outputFile} when one is configured.
 *
 * @param myClass      class to generate the test for
 * @param isSuperclass whether {@code myClass} is being generated as a superclass
 * @throws Exception if generation or writing the output file fails
 */
private void genCode(Class myClass, Boolean isSuperclass) throws Exception {
    final boolean[] passes = {true, false};
    for (boolean init : passes) {
        genCode(myClass, isSuperclass, init);
    }
    final boolean hasOutputFile = !"".equals(outputFile);
    if (hasOutputFile) {
        writeFileWithBufferedWriter();
    }
}
/**
 * Writes the accumulated generated code ({@code stringBuffer.toString()}) to
 * {@link #outputFile}, replacing any existing content.
 *
 * @throws IOException if the file cannot be opened or written
 */
private void writeFileWithBufferedWriter() throws IOException {
    // try-with-resources closes (and flushes) the writer even when write() throws,
    // so a failed write no longer leaks the file handle.
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile))) {
        writer.write(stringBuffer.toString());
    }
}
private void genCode(Class myClass, Boolean isSuperclass, Boolean init) throws Exception {
// 生成测试代码
if ("java.lang.Object".equals(myClass.getTypeName())) {
return;
}
if (init) {
importSet = new HashSet<>(16);
defaultMap = new HashMap<>(16);
}
isInit = init;
this.isSuperclass = isSuperclass;
String name = getType(myClass.getName());
if ("".equals(serviceName)) {
serviceName = name.substring(0, 1).toLowerCase() + name.substring(1);
}
println("package " + myClass.getName().substring(0, myClass.getName().length() - myClass.getSimpleName().length() - 1) + ";");
println("");
setImport("org.mockito.InjectMocks");
setImport("org.mockito.Mock");
setImport("org.junit.Assert");
if (junit5) {
setImport("org.junit.jupiter.api.Test");
} else {
setImport("org.junit.Test");
}
setImport("static org.mockito.Mockito.when");
List<String> importList = new ArrayList<>(importSet);
Collections.sort(importList);
for (String importStr : importList) {
println(importStr);
}
println("");
Set<Field> fields = getDeclaredFields(myClass);
List<String> lineList = readFileContent(myClass);
fileContent = String.join("\n", lineList);
Map<String, List<String>> map = new HashMap<>(16);
if (!"".equals(author)) {
println("/**\n" +
" * " + name + " UT\n" +
" *" + "\n" +
" * @author " + author + "\n" +
" * @date " + new Date() + "\n" +
" */");
}
if (baseTest == null) {
println("public class " + name + "Test {");
} else {
setImport(baseTest.getName());
println("public class " + name + "Test extends " + baseTest.getSimpleName() + " {");
}
println("@InjectMocks");
println("private " + myClass.getSimpleName() + " " + serviceName + ";");
println("");
int number = 0;
List<String> valueList = new ArrayList<>();
for (Field service : fields) {
if (service.getAnnotations().length > 0 && !service.getType().getName().contains("java.") && service.getType().getName().contains(".")) {
setImport("static org.mockito.Mockito.mock");
if (useAnswers) {
setImport("org.mockito.Answers");
println("@Mock(answer = Answers.RETURNS_DEEP_STUBS)");
} else {
println("@Mock");
}
setImport(service.getType().getName());
println("private " + service.getType().getSimpleName() + " " + service.getName() + ";");
println("");
for (Method serviceMethod : getDeclaredMethods(service.getType(), true)) {
String methodStr = service.getName() + "." + serviceMethod.getName() + "(";
Type t = serviceMethod.getAnnotatedReturnType().getType();
if (!"void".equals(t.getTypeName()) && ("".equals(fileContent) || fileContent.indexOf(methodStr) > 0)) {
if (!map.containsKey(methodStr)) {
map.put(methodStr, new ArrayList<>(10));
}
map.get(methodStr).add(getWhen(serviceMethod, number, service));
number++;
} else if ("void".equals(t.getTypeName()) && ("".equals(fileContent) || fileContent.indexOf(methodStr) > 0)) {
if (!map.containsKey(methodStr)) {
map.put(methodStr, new ArrayList<>(10));
}
map.get(methodStr).add(getVoidWhen(serviceMethod, service));
number++;
}
}
} else if (service.getAnnotations().length > 0 && (service.getType().getName().contains("java.") || !service.getType().getName().contains("."))) {
// 如果有注解及类型是标量
String setFieldStr = "ReflectionTestUtils.setField(" + serviceName + ", \"" + service.getName() + "\", " + getDefaultVal(service.getType()) + ");";
valueList.add(setFieldStr);
}
}
if (valueList.size() > 0) {
// 生成反射给成员变量赋值的代码
if (junit5) {
setImport("org.junit.jupiter.api.BeforeAll");
println("@BeforeAll");
} else {
setImport("org.junit.BeforeClass");
println("@BeforeClass");
}
println("public void beforeInit() {");
valueList.forEach((value) -> {
println(value);
});
println("}");
println("");
}
Map<String, Set<List<String>>> whenMap = new HashMap<>(16);
// 函数之间的关系
Map<String, Set<String>> whenMethod = new HashMap<>(16);
Map<String, Set<String>> putString = new HashMap<>(16);
Set<Method> methods = getDeclaredMethods(myClass, true);
for (Method method : methods) {
whenMap.put(method.getName(), new HashSet<>(15));
whenMethod.put(method.getName(), new HashSet<>(15));
putString.put(method.getName(), new HashSet<>(15));
}
String methodName = "";
if (!"".equals(fileContent)) {
for (String line : lineList) {
if (line.trim().length() <= 1) {
continue;
}
boolean maybeFunction = (line.indexOf("(") != -1);
if (maybeFunction && (line.indexOf("private") > 0 || line.ind