背景
在《Flink CDC整库同步_flink cdc mysql 全量同步》(CSDN博客)一文中已经实现了数据同步的基本功能,但如果源端的实时数据量比较大,数据就无法在秒级别内实时入湖,因此改用copy方式写入数据,以优化sink算子的实现逻辑。
实现
创建自定义数据写入器,使用List集合将数据缓存在内存中;在Flink任务触发checkpoint时,使用gp(Greenplum)的CopyManager将List中缓存的数据写入gp。
import util.JDBCURL;
import util.PropUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.jdbc.catalog.PostgresCatalog;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.JSONPObject;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.catalog.CatalogBaseTable;
import org.apache.flink.table.catalog.ObjectPath;
import org.apache.flink.table.catalog.exceptions.TableNotExistException;
import org.apache.flink.table.data.RawValueData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.data.TimestampData;
import org.apache.flink.table.data.util.DataFormatConverters;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.types.Row;
import org.postgresql.PGConnection;
import org.postgresql.copy.CopyManager;
import org.postgresql.jdbc.PgConnection;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringBufferInputStream;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.sql.*;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.stream.Collectors;
import static org.apache.flink.table.api.DataTypes.FIELD;
import static org.apache.flink.table.api.DataTypes.STRING;
@Slf4j
public class DwCopySink extends RichSinkFunction<RowData> implements CheckpointedFunction {
private static final String DEFAULT_NULL_VALUE = "\002";
private static final String DEFAULT_FIELD_DELIMITER = "\001";
// pg 字符串里含有\u0000 会报错 ERROR: invalid byte sequence for encoding "UTF8": 0x00
public static final String SPACE = "\u0000";
private static final String LINE_DELIMITER = "\n";
protected static final String COPY_SQL_TEMPL =
"copy %s(%s) from stdin DELIMITER '%s' NULL as '%s'";
private PgConnection connection;
private CopyManager copyManager;
private static PostgresCatalog postgresCatalog;
private static CatalogBaseTable table;
private List<RowData> batchList = new ArrayList<>();
private static DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
private String sql;
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
String url = PropUtils.getProperties().getProperty("dw_jdbc");
String username = PropUtils.getProperties().getProperty("dw_user");
String password = PropUtils.getProperties().getProperty("dw_password");
connection = (PgConnection) DriverManager.getConnection(url, username, password);
buildCopySql();
}
@Override
public void invoke(RowData value, Context context) throws Exception {
super.invoke(value, context);
batchList.add(value);
}
@Override
public void finish() throws Exception {
super.finish();
}
@Override
public void close() throws Exception {
flush();
super.close();
if (connection != null) {
connection.close();
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
//每次checkpoint时刷新数据,避免数据重复写入
flush();
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
//没有初始化动作,无需处理
}
public static PostgresCatalog getCatalog() {
if (postgresCatalog == null) {
String url = PropUtils.getProperties().getProperty("dw_jdbc");
String username = PropUtils.getProperties().getProperty("dw_user");
String password = PropUtils.getProperties().getProperty("dw_password");
final Matcher matcher1 = JDBCURL.getPattern("jdbc:postgresql://{host}[:{port}]/[{database}][\\?{params}]")
.matcher(url);
if (matcher1.matches()) {
String database = matcher1.group("database");
String host = matcher1.group("host");
String port = matcher1.group("port");
String source_url = String.format("jdbc:postgresql://%s:%d", host, Integer.parseInt(port));
postgresCatalog = new PostgresCatalog(DwCopySink.class.getClassLoader(), "postgresql",
database, username, password, source_url);
}
}
return postgresCatalog;
}
/**
* 根据要同步的表信息返回row类型
*
* @return
* @throws TableNotExistException
*/
public RowType getTableRowType() throws TableNotExistException {
Schema schema = getTable().getUnresolvedSchema();
List<LogicalType> logicalTypes = schema.getColumns().stream().map(col ->
{
Schema.UnresolvedPhysicalColumn column = (Schema.UnresolvedPhysicalColumn) col;
return ((DataType) column.getDataType()).getLogicalType();
}).collect(Collectors.toList());
List<String> nameList = schema.getColumns().stream().map(col -> col.getName()).collect(Collectors.toList());
return RowType.of(logicalTypes.toArray(new LogicalType[logicalTypes.size()]),
nameList.toArray(new String[nameList.size()]));
}
public static CatalogBaseTable getTable() throws TableNotExistException {
if (table == null) {
String schemaName = PropUtils.getProperties().getProperty("copy_dw_schema");
String tableName = PropUtils.getProperties().getProperty("copy_dw_table");
table = getCatalog().getTable(new ObjectPath(postgresCatalog.getDefaultDatabase(),
schemaName + "." + tableName));
}
return table;
}
private String buildCopySql() throws SQLException, TableNotExistException {
if(StringUtils.isNotEmpty(sql)){
return sql;
}
String schemaName = PropUtils.getProperties().getProperty("copy_dw_schema");
String tableName = PropUtils.getProperties().getProperty("copy_dw_table");
//获取表的基本信息
Schema schema = getTable().getUnresolvedSchema();
String fieldsExpression = schema.getColumns().stream().map(col -> this.quoteIdentifier(col.getName()))
.collect(Collectors.joining(", "));
String tableLocation;
if (schemaName != null && !"".equals(schemaName.trim())) {
tableLocation = quoteIdentifier(schemaName) + "." + quoteIdentifier(tableName);
} else {
tableLocation = quoteIdentifier(tableName);
}
sql = String.format(
COPY_SQL_TEMPL, tableLocation, fieldsExpression, DEFAULT_FIELD_DELIMITER, DEFAULT_NULL_VALUE);
return sql;
}
private String quoteIdentifier(String identifier) {
return "\"" + identifier + "\"";
}
/**
* 刷新数据到数据库中
*
* @throws SQLException
*/
private synchronized void flush() throws SQLException {
if (batchList.size() == 0) {
return;
}
long size = batchList.size();
copyManager = connection.getCopyAPI();
StopWatch stopWatch = new StopWatch();
StringBuilder rowsStrBuilder = new StringBuilder(128);
stopWatch.start();
for (RowData row : batchList) {
int lastIndex = row.getArity() - 1;
StringBuilder rowStr = new StringBuilder(128);
for (int index = 0; index < row.getArity(); index++) {
appendColumn(row, index, rowStr, index == lastIndex);
}
String tempData = rowStr.toString();
rowsStrBuilder.append(copyModeReplace(tempData)).append(LINE_DELIMITER);
}
StopWatch stopWatch2 = new StopWatch();
try (ByteArrayInputStream bi =
new ByteArrayInputStream(rowsStrBuilder.toString().getBytes(StandardCharsets.UTF_8))) {
stopWatch.stop();
stopWatch2.start();//开始记录写入数据时间
copyManager.copyIn(sql, bi);
} catch (Exception e) {
log.error(sql);
log.error("写入数据异常:",e);
throw new RuntimeException("写入数据异常");
}
stopWatch2.stop();
log.info(String.format("【%d】条数据,构建byte缓冲区花费时间为【%d】 ms",size,stopWatch.getTime()));
log.info("\n");
long watch2Time = stopWatch2.getTime();
//log.info(String.format("写入数据的时间%d",watch2Time));
//换算每秒写入数据条数
BigDecimal decimal = BigDecimal.valueOf(size).divide(BigDecimal.valueOf(watch2Time),3,BigDecimal.ROUND_HALF_EVEN).multiply(BigDecimal.valueOf(1000));
log.info(String.format("写入【%d】条数据,总花费时间为【%d】ms,平均速度%d条每秒",
size,watch2Time, decimal.intValue()));
batchList.clear();
}
private void appendColumn(
RowData rowData, int pos, StringBuilder rowStr, boolean isLast) {
Schema schema = null;
try {
schema = getTable().getUnresolvedSchema();
} catch (TableNotExistException e) {
throw new RuntimeException(e);
}
Schema.UnresolvedPhysicalColumn unresolvedColumn = (Schema.UnresolvedPhysicalColumn) schema.getColumns().get(pos);
LogicalType logicalType = ((DataType) unresolvedColumn.getDataType()).getLogicalType();
String col = DEFAULT_NULL_VALUE;
switch (logicalType.getTypeRoot()){
case CHAR:
case VARCHAR:
col = rowData.getString(pos).toString();
break;
case BOOLEAN:
col = String.valueOf(rowData.getBoolean(pos));
break;
case DECIMAL:
col = rowData.getDecimal(pos,32,10).toString();
break;
case TINYINT:
case SMALLINT:
case INTEGER:
col = String.valueOf(rowData.getInt(pos));
break;
case BIGINT:
col = String.valueOf(rowData.getLong(pos));
break;
case FLOAT:
case DOUBLE:
col = String.valueOf(rowData.getDouble(pos));
break;
case DATE:
LocalDateTime localDateTime = rowData.getTimestamp(pos, 6).toLocalDateTime();
col = localDateTime.format(formatter);
break;
case TIME_WITHOUT_TIME_ZONE:
case TIMESTAMP_WITHOUT_TIME_ZONE:
case TIMESTAMP_WITH_TIME_ZONE:
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
LocalDateTime dateTime = rowData.getTimestamp(pos, 6).toLocalDateTime();
col = dateTime.format(formatter);
break;
default:
throw new RuntimeException(String.format("未支持的数据类型 %s %s",unresolvedColumn.getName(),unresolvedColumn.getDataType().toString()));
}
if (col == null) {
rowStr.append(DEFAULT_NULL_VALUE);
} else {
rowStr.append(col);
}
if (!isLast) {
rowStr.append(DEFAULT_FIELD_DELIMITER);
}
}
private String copyModeReplace(String rowStr) {
if (rowStr.contains("\\")) {
rowStr = rowStr.replaceAll("\\\\", "\\\\\\\\");
}
if (rowStr.contains("\r")) {
rowStr = rowStr.replaceAll("\r", "\\\\r");
}
if (rowStr.contains("\n")) {
rowStr = rowStr.replaceAll("\n", "\\\\n");
}
// pg 字符串里含有\u0000 会报错 ERROR: invalid byte sequence for encoding "UTF8": 0x00
if (rowStr.contains(SPACE)) {
rowStr = rowStr.replaceAll(SPACE, "");
}
return rowStr;
}
public PostgresCatalog getPostgresCatalog() {
return postgresCatalog;
}
}