import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* 数据库操作JDBC工具类
*/
@Slf4j
public class DBUtil {
private final static String TYPE_ORACLE = "oracle";
private final static String DB = "jdbc.db.";
private final static String DB_USERNAME = "username";
private final static String DB_PWD = "password";
private final static String DB_DRIVER = "driver";
private final static String DB_URL = "url";
private final static String DB_IP = "ip";
private final static String DB_PORT = "port";
private final static String DB_TYPE = "type";
private final static String DB_DBNAME = "dbName";
private final static String REGEX_TYPE = "jdbc:(.*?):";
private final static String REGEX_ORACLE = "jdbc:oracle:thin:@(.*?):(.*):(.*)";
private final static String CONN_ORACLE = "jdbc:oracle:thin:@";
private final static String SELECT_FROM = "select * from ";
private static Map<String, String> dbConfig = new HashMap<String, String>();
private static String dbType;
@Resource
private DataSource dataSource;
public static synchronized Map<String, String> getDBConfig() {
if (dbConfig == null || dbConfig.size() == 0) {
if (dbConfig == null)
dbConfig = new HashMap<String, String>();
Map<String, String> propConfig = PropertyUtil.getPropertyMap(Const.DB_CONFIG);
String propUrl = propConfig.get(DB + DB_URL);
Matcher m = Pattern.compile(REGEX_TYPE).matcher(propUrl);
while (m.find()) {
dbType = m.group(1);
}
dbConfig.put(DB_TYPE, dbType);
if (TYPE_ORACLE.equals(dbType)) {
dbConfig.put(DB_DRIVER, propConfig.get(DB + DB_DRIVER));
dbConfig.put(DB_USERNAME, propConfig.get(DB + DB_USERNAME));
dbConfig.put(DB_PWD, propConfig.get(DB + DB_PWD));
dbConfig.put(DB_PWD, propConfig.get(DB + DB_PWD));
Matcher m_oracle = Pattern.compile(REGEX_ORACLE).matcher(propUrl);
while (m_oracle.find()) {
dbConfig.put(DB_IP, m_oracle.group(1));
dbConfig.put(DB_PORT, m_oracle.group(2));
dbConfig.put(DB_DBNAME, m_oracle.group(3));
}
}
}
return dbConfig;
}
public static Connection getDBConn() throws Exception {
Map<String, String> config = getDBConfig();
Connection conn = getConn(
config.get(DB_DRIVER),
dbType,
config.get(DB_USERNAME),
config.get(DB_PWD),
config.get(DB_IP),
config.get(DB_PORT),
config.get(DB_DBNAME)
);
return conn;
}
public static Connection getConn(String driver, String dbType, String username, String password, String ip,
String port, String databaseName) throws Exception {
String oracleDataSourceUrl = CONN_ORACLE + ip + ":" + port + ":" + databaseName;
if (TYPE_ORACLE.equals(dbType)) {
Class.forName(driver);
return DriverManager.getConnection(
oracleDataSourceUrl,
username,
password);
} else {
return null;
}
}
public List<String> getTables() throws Exception {
return getTables(getDBConn());
}
public List<String> getTables(Connection conn) throws Exception {
if (TYPE_ORACLE.equals(dbType)) {
return getTablesOracle(conn);
} else {
return null;
}
}
public List<String> getTablesOracle(Connection conn) {
try {
List<String> tableList = new ArrayList<String>();
DatabaseMetaData meta = conn.getMetaData();
ResultSet rs = meta.getTables(null, null, null, new String[]{"TABLE"});
while (rs.next()) {
tableList.add(rs.getString(3));
}
return tableList;
} catch (Exception e) {
log.error("getTablesOracle方法错误:" + e.getMessage());
log.error("getTablesOracle方法错误:", e);
} finally {
closeConn(conn);
}
return null;
}
public List<TColumn> getTableColumnsByTableName(String tableName) throws Exception {
return getTableColumnsByTableName(getDBConn(), tableName);
}
public List<TColumn> getTableColumnsByTableName(Connection conn, String tableName) throws Exception {
return getTableColumns(conn, SELECT_FROM + tableName);
}
public List<TColumn> getTableColumns(String sqlStr) throws Exception {
return getTableColumns(getDBConn(), sqlStr);
}
public static List<TColumn> getTableColumns(Connection conn, String sqlStr) throws Exception {
String sql = SELECT_FROM + "(" + sqlStr + ") tcolumns where 0!=0";
PreparedStatement pstmt = (PreparedStatement) conn
.prepareStatement(sql);
pstmt.execute();
List<TColumn> columns = new ArrayList<TColumn>();
ResultSetMetaData rsmd = (ResultSetMetaData) pstmt.getMetaData();
for (int i = 1; i < rsmd.getColumnCount() + 1; i++) {
columns.add(new TColumn(rsmd.getColumnName(i), rsmd.getColumnTypeName(i), rsmd.getPrecision(i),
rsmd.getScale(i), rsmd.isNullable(i)));
}
return columns;
}
public List<List<Object>> queryByTableName(String tableName) throws Exception {
return queryByTableName(getDBConn(), tableName);
}
public List<List<Object>> queryByTableName(Connection conn, String tableName) throws Exception {
return query(conn, SELECT_FROM + tableName);
}
public List<List<Object>> query(String sqlStr) throws Exception {
return query(getDBConn(), sqlStr);
}
public List<List<Object>> query(Connection conn, String sqlStr) throws Exception {
List<TColumn> columns = new ArrayList<TColumn>();
List<List<Object>> dataList = new ArrayList<List<Object>>();
Statement stmt = null;
ResultSet rs = null;
try {
conn = getDBConn();
stmt = conn.createStatement();
rs = stmt.executeQuery(sqlStr);
columns = getTableColumns(conn, sqlStr);
List<Object> columnList = new ArrayList<Object>();
for (TColumn tc : columns) {
columnList.add(tc.getName());
}
dataList.add(columnList);
while (rs.next()) {
List<Object> oneDataList = new ArrayList<Object>();
for (int i = 1; i < columns.size() + 1; i++) {
oneDataList.add(rs.getObject(i));
}
dataList.add(oneDataList);
}
conn.close();
} catch (Exception e) {
log.error("query方法错误:" + e.getMessage());
log.error("query方法错误:", e);
} finally {
closeConn(conn);
}
return dataList;
}
public Page<List<Object>> queryByTableName(Connection conn, String tableName, Page<List<Object>> page)
throws Exception {
return query(conn, SELECT_FROM + tableName, page);
}
public Page<List<Object>> queryByTableName(String tableName, Page<List<Object>> page) throws Exception {
return query(getDBConn(), SELECT_FROM + tableName, page);
}
public Page<List<Object>> query(String sqlStr, Page<List<Object>> page) throws Exception {
return query(getDBConn(), sqlStr, page);
}
public Page<List<Object>> query(Connection conn, String sqlStr, Page<List<Object>> page) throws Exception {
// 存放字段名
List<TColumn> columns = new ArrayList<TColumn>();
// 存放数据(从数据库读出来的一条条的数据)
List<List<Object>> dataList = new ArrayList<List<Object>>();
Statement stmt = null;
ResultSet rs = null;
String sqlPage = null;
try {
conn = getDBConn();
stmt = conn.createStatement();
rs = stmt.executeQuery(getCountSql(sqlStr));
while (rs.next()) {
page.setTotalRecord(rs.getInt(1));
break;
}
if (TYPE_ORACLE.equals(dbType)) {
sqlPage = getOraclePageSql(page, new StringBuffer(sqlStr));
}
rs = stmt.executeQuery(sqlPage);
columns = getTableColumns(conn, sqlStr);
List<Object> columnList = new ArrayList<Object>();
for (TColumn tc : columns) {
columnList.add(tc.getName());
}
dataList.add(columnList);
while (rs.next()) {
List<Object> oneDataList = new ArrayList<Object>();
for (int i = 1; i < columns.size() + 1; i++) {
oneDataList.add(rs.getObject(i));
}
dataList.add(oneDataList);
}
page.setResults(dataList);
conn.close();
} catch (Exception e) {
log.error("query方法错误:", e);
} finally {
closeConn(conn);
}
return page;
}
public int operate(String sqlStr) throws Exception {
return operate(getDBConn(), sqlStr);
}
public int operate(Connection conn, String sqlStr) throws Exception {
int res = 0;
Statement stmt = null;
try {
stmt = conn.createStatement();
res = stmt.executeUpdate(sqlStr);
} catch (Exception e) {
log.error("operate方法错误:", e);
} finally {
closeConn(conn);
}
return res;
}
private static String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {
// 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
int offset = (page.getPageNum() - 1) * page.getPageSize() + 1;
sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ")
.append(offset + page.getPageSize());
sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
return sqlBuffer.toString();
}
private static String getCountSql(String sql) {
return "select count(*) from (" + sql + ") countRecord";
}
public static void closeConn(Connection conn) {
try {
if (conn != null)
conn.close();
} catch (Exception e) {
log.error("closeConn方法错误:", e);
}
}
/**
* 执行Sql(分页)
*
* @param sqlStr sql语句
* @param page 分页
* @return
*/
public static Map<String, Object> executeSQL(String sqlStr, Page<List<Object>> page) {
Map<String, Object> result = new HashMap<String, Object>(10);
// 存放字段名
List<TColumn> columns = new ArrayList<TColumn>();
// 存放数据(从数据库读出来的一条条的数据)
List<List<Object>> dataList = new ArrayList<List<Object>>();
Statement stmt = null;
ResultSet rs = null;
String sqlPage = null;
Connection conn = null;
int lines = 0;
//请求起始时间_毫秒
long startTime = System.currentTimeMillis();
long rTime = 0;
try {
if (StringUtils.isNotBlank(sqlStr)) {
result.put("executeSQL", sqlStr);
conn = getDBConn();
stmt = conn.createStatement();
if (isQuerySql(sqlStr)) {
//判断是否为查询语句
rs = stmt.executeQuery(getCountSql(sqlStr));
while (rs.next()) {
page.setTotalRecord(rs.getInt(1));
break;
}
sqlPage = getOraclePageSql(page, new StringBuffer(sqlStr));
rs = stmt.executeQuery(sqlPage);
//请求结束时间_毫秒
long endTime = System.currentTimeMillis();
rTime = endTime - startTime;
result.put("rTime", rTime);
columns = getTableColumns(conn, sqlStr);
List<Object> columnList = new ArrayList<Object>();
for (TColumn tc : columns) {
columnList.add(tc.getName());
}
dataList.add(columnList);
while (rs.next()) {
List<Object> oneDataList = new ArrayList<Object>();
for (int i = 1; i < columns.size() + 1; i++) {
oneDataList.add(rs.getObject(i));
}
dataList.add(oneDataList);
}
page.setResults(dataList);
result.put("list", page);
result.put("type", "query");
} else {
lines = stmt.executeUpdate(sqlStr);
//请求结束时间_毫秒
long endTime = System.currentTimeMillis();
rTime = endTime - startTime;
result.put("type", "operate");
}
//成功返回 1
result.put("res", "1");
} else {
//失败返回 0
result.put("res", "0");
result.put("resMsg", "sql语句不能为空");
}
} catch (Exception e) {
//失败返回 0
result.put("res", "0");
result.put("resMsg", e.getMessage());
log.error("executeSQL方法错误:", e);
} finally {
closeStmt(stmt);
closeConn(conn);
}
result.put("rTime", rTime);
result.put("lines", lines);
return result;
}
/**
* 获取sql语句是否为查询
*/
public static boolean isQuerySql(String sqlStr) {
if (sqlStr.trim().toLowerCase().startsWith("select"))
return true;
return false;
}
public static void closeStmt(Statement stmt) {
try {
if (stmt != null)
stmt.close();
} catch (Exception e) {
log.error("closeStmt方法错误:" + e.getMessage());
log.error("closeStmt方法错误:", e);
}
}
/**
* 执行Sql(不分页)
*
* @param sqlStr sql语句
* @return
*/
public static Map<String, Object> executeSQL(String sqlStr) {
Map<String, Object> result = new HashMap<String, Object>();
// 存放字段名
List<TColumn> columns = new ArrayList<TColumn>();
// 存放数据(从数据库读出来的一条条的数据)
List<List<Object>> dataList = new ArrayList<List<Object>>();
Statement stmt = null;
ResultSet rs = null;
Connection conn = null;
int lines = 0;
//请求起始时间_毫秒
long startTime = System.currentTimeMillis();
long rTime = 0;
try {
if (StringUtils.isNotBlank(sqlStr)) {
result.put("executeSQL", sqlStr);
conn = getDBConn();
stmt = conn.createStatement();
//判断是否为查询语句
if (isQuerySql(sqlStr)) {
rs = stmt.executeQuery(getCountSql(sqlStr));
while (rs.next()) {
result.put("count", rs.getInt(1));
break;
}
rs = stmt.executeQuery(sqlStr);
//请求结束时间_毫秒
long endTime = System.currentTimeMillis();
rTime = endTime - startTime;
result.put("rTime", rTime);
columns = getTableColumns(conn, sqlStr);
List<Object> columnList = new ArrayList<Object>();
for (TColumn tc : columns) {
columnList.add(tc.getName());
}
dataList.add(columnList);
while (rs.next()) {
List<Object> oneDataList = new ArrayList<Object>();
for (int i = 1; i < columns.size() + 1; i++) {
oneDataList.add(rs.getObject(i));
}
dataList.add(oneDataList);
}
result.put("list", dataList);
result.put("type", "query");
} else {
lines = stmt.executeUpdate(sqlStr);
//请求结束时间_毫秒
long endTime = System.currentTimeMillis();
rTime = endTime - startTime;
result.put("type", "operate");
}
//成功返回 1
result.put("res", "1");
} else {
//失败返回 0
result.put("res", "0");
result.put("resMsg", "sql语句不能为空");
}
} catch (Exception e) {
//失败返回 0
result.put("res", "0");
result.put("resMsg", e.getMessage());
log.error("executeSQL方法错误:", e);
} finally {
closeStmt(stmt);
closeConn(conn);
}
result.put("rTime", rTime);
result.put("lines", lines);
return result;
}
public DataSource getDataSource() {
return dataSource;
}
public void setDataSource(DataSource dataSource) {
this.dataSource = dataSource;
}
/**
* 测试用例,不用关注
* @param args
*/
public static void main(String[] args) {
/**
* 1. 测试通过url获取数据库类型
*/
String propUrl = "jdbc:oracle:thin:@127.0.0.1:1521:helowin";
String REGEX_TYPE = "jdbc:(.*?):";
Matcher m = Pattern.compile(REGEX_TYPE).matcher(propUrl);
while (m.find()) {
dbType = m.group(1);
System.out.println(dbType);
}
/**
* 2. 测试通过sql查询
*/
Page<List<Object>> page = new Page<>();
page.setPageNum(1);
page.setPageSize(1);
Map<String, Object> map = executeSQL(
"select * from test11",
page);
System.out.println(map);
}
}
// 01-14 (stray trailing text from the original paste, commented out so the file compiles)