Basic synchronization utilities

SourceUtil


import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.elasticsearch.spark.sql.EsSparkSQL

object SourceUtil {

  def sourceFromEsByDt(spark: SparkSession, es_index: String, dt: String): DataFrame = {

    val esQuery =
      s"""
         |{
         |  "query": {
         |    "bool": {
         |      "filter": {
         |        "range": {
         |          "dt": {
         |            "gte": "$dt",
         |            "lte": "$dt"
         |          }
         |        }
         |      }
         |    }
         |  }
         |}
         |""".stripMargin

    EsSparkSQL.esDF(spark, s"/${es_index}/_doc", esQuery)
  }

  def sourceFromEsByDtRange(spark: SparkSession, es_index: String, dt: String, dt1: String): DataFrame = {

    val esQuery =
      s"""
         |{
         |  "query": {
         |    "bool": {
         |      "filter": {
         |        "range": {
         |          "dt": {
         |            "gte": "$dt",
         |            "lte": "$dt1"
         |          }
         |        }
         |      }
         |    }
         |  }
         |}
         |""".stripMargin

    EsSparkSQL.esDF(spark, s"/${es_index}/_doc", esQuery)
  }


  def sourceAllFromMysql(spark: SparkSession, host: String, db: String, t_source: String, user_name: String, pwd: String): DataFrame = {

    sourceAllFromMysql(spark, host, "3306", db, t_source, user_name, pwd)
  }

  /**
   * Read MySQL with a single degree of parallelism.
   */
  def sourceAllFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String): DataFrame = {


    sourceAllFromMysql(spark, host, port, db, t_source, user_name, pwd, null)

  }


  /**
   * "Full import" with a filter condition.
   * For example, the youdianma table has about 6,000,000 rows in total, while the democratic-evaluation subset is only a few thousand.
   */
  def sourceAllFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String, condition: String): DataFrame = {

    // forward the filter condition to the full overload (no column pruning)
    sourceAllFromMysql(spark, host, port, db, t_source, user_name, pwd, condition, "")

  }


  /**
   * Rebuild the sync query from the source table's metadata plus the pruned columns, so that rows carrying image payloads are not synced.
   * e.g. the remark field of CRM's t_customer_report stores base64-encoded images.
   */
  def sourceAllFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String, condition: String, dropCols: String): DataFrame = {

    var condition1 = "1=1"

    if (StringUtils.isNotBlank(condition)) {
      condition1 = condition
    }

    val url = s"jdbc:mysql://${host}:${port}/${db}?tinyInt1isBit=false&connectionRequestTimout=300000&connectionTimeout=300000&socketTimeout=300000"
    val querySql = TableUtils.createQuerySql(url, user_name, pwd, t_source, dropCols) + s" where $condition1"


    //    val querySql = s"SELECT * FROM ${t_source} where ${condition1}"

    println("querySql:" + querySql)

    val jdbcDF = spark.read
      .format("jdbc")
      .option("url", url)
      .option("user", user_name)
      .option("password", pwd)
      .option("dbtable", s"($querySql) t")
      .load()

    jdbcDF.repartition(3)

  }


  /**
   * without condition
   */
  def sourceNewFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String, createTime: String, dt: String): DataFrame = {

    sourceNewFromMysql(spark, host, port, db, t_source, user_name, pwd, createTime, dt, null)

  }

  def sourceNewFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String, createTime: String, dt: String, condition: String): DataFrame = {


    var condition1 = "1=1"

    if (StringUtils.isNotBlank(condition)) {
      condition1 = condition
    }


    val expression = s"date_format($createTime,'%Y-%m-%d') = '$dt' and $condition1"
    val querySql = s"SELECT * FROM ${t_source} where ${expression}"

    println("querySql:" + querySql)


    val url = s"jdbc:mysql://${host}:${port}/${db}?tinyInt1isBit=false&user=${user_name}&password=${pwd}&connectionRequestTimout=300000&connectionTimeout=300000&socketTimeout=300000"

    //    val expression = s"date_format($createTime,'%Y-%m-%d') = '$dt' and $condition1"

    val jdbcDF: DataFrame = spark.read.format("jdbc")
      //      .options(Map("url" -> url, "dbtable" -> s"(SELECT * FROM ${t_source} where $expression) t"))
      .options(Map("url" -> url, "dbtable" -> s"($querySql) t"))
      .load()

    jdbcDF.repartition(3)

  }

  def sourceNewAndChangeFromMysql(spark: SparkSession, host: String, port: String, db: String, t_source: String, user_name: String, pwd: String, createTime: String, updateTime: String, dt: String): DataFrame = {

    val url = s"jdbc:mysql://${host}:${port}/${db}?tinyInt1isBit=false&user=${user_name}&password=${pwd}&connectionRequestTimout=300000&connectionTimeout=300000&socketTimeout=300000"

    val expression = s"date_format($createTime,'%Y-%m-%d') = '$dt' or date_format($updateTime,'%Y-%m-%d') = '$dt'"

    println("expression:" + expression)

    val jdbcDF: DataFrame = spark.read.format("jdbc")
      .options(Map("url" -> url, "dbtable" -> s"(SELECT * FROM ${t_source} where $expression) t"))
      .load()

    jdbcDF.repartition(3)

  }


  def sourceAllFromEs(spark: SparkSession, index: String, _doc: String): DataFrame = {
    val esQuery =
      s"""
         |{"query":{"match_all":{}}}
       """.stripMargin
    val resDF: DataFrame = EsSparkSQL.esDF(spark, s"/${index}/${_doc}", esQuery)
    resDF.repartition(3)
  }


  def sourceNewFromEs(spark: SparkSession, index: String, _doc: String, createTime: String, dt: String): DataFrame = {

    val esQuery =
      s"""
         |{
         |  "query": {
         |    "bool": {
         |      "filter": {
         |        "range": {
         |          "${createTime}": {
         |            "gte": "${dt} 00:00:00",
         |            "lte": "${dt} 23:59:59"
         |          }
         |        }
         |      }
         |    }
         |  }
         |}
         |""".stripMargin

    EsSparkSQL.esDF(spark, s"/${index}/${_doc}", esQuery)
//      .repartition(3)
  }

  def sourceNewAndChangeFromEs(spark: SparkSession, index: String, _doc: String, createTime: String, updateTime: String, dt: String): DataFrame = {

    val esQuery =
      s"""
         |{
         |  "query":
         |  {
         |    "bool":
         |    {
         |      "should":
         |      [
         |        {
         |        "range": {
         |          "$createTime":
         |          {
         |            "gte": "$dt 00:00:00",
         |            "lte": "$dt 23:59:59"
         |          }
         |        }
         |      },
         |      {
         |        "range": {
         |          "$updateTime":
         |          {
         |            "gte": "$dt 00:00:00",
         |            "lte": "$dt 23:59:59"
         |          }
         |        }
         |      }
         |    ]
         |    }
         |  }
         |}
         |""".stripMargin

    EsSparkSQL.esDF(spark, s"/${index}/${_doc}", esQuery)
      .repartition(3)
  }
}
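A minimal usage sketch of SourceUtil, assuming it is compiled into the same package; the host, credentials, table and column names below are placeholders, not values from this post:

import org.apache.spark.sql.SparkSession

object SourceUtilDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("source_util_demo").master("local[*]").getOrCreate()

    // Full load of one MySQL table; a null condition falls back to "1=1", an empty dropCols keeps every column.
    val dfAll = SourceUtil.sourceAllFromMysql(spark, "127.0.0.1", "3306", "demo_db", "t_order", "user", "pwd", null, "")

    // Incremental load: only rows whose create_time falls on the given dt.
    val dfNew = SourceUtil.sourceNewFromMysql(spark, "127.0.0.1", "3306", "demo_db", "t_order", "user", "pwd", "create_time", "2023-09-01")

    println(s"full=${dfAll.count()}, new=${dfNew.count()}")
    spark.stop()
  }
}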

SinkUtil

import org.apache.commons.lang3.StringUtils
import org.apache.spark.SparkException
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.slf4j.LoggerFactory

import java.sql.{Connection, DriverManager, PreparedStatement, Statement}
import java.util
import java.util.Random
import java.util.function.BiConsumer

object SinkUtil {
  private val logger = LoggerFactory.getLogger(SinkUtil.getClass)

  def sink_to_mysql(host0: String, db: String, target_table: String, userName: String, pwd: String, frame_result: DataFrame, partitions: Int, mySaveMode: MySaveMode, dt: String
                    , isRollback: Boolean): Unit = {

    /**
     * Skip the import when the dataset is empty.
     */
    if (frame_result.isEmpty) {
      println(s"the dataset of dt ${dt} sink to mysql is empty,return...")
      return
    }

    var host = host0
    var port = "3306"

    if (host0.split(":").length >= 2) {
      port = host0.split(":")(1)
      host = host0.split(":")(0)
    }

    val url = s"jdbc:mysql://${host}:${port}/${db}?rewriteBatchedStatements=true&useUnicode=true&characterEncoding=UTF-8&connectionRequestTimout=300000&connectionTimeout=300000&socketTimeout=300000&useSSL=false"

    var conn: Connection = null
    var prepareStatement: PreparedStatement = null

    try {
      conn = DriverManager.getConnection(url, userName, pwd)
      val statement = conn.createStatement()

      /**
       * Create the MySQL table automatically if it does not exist.
       */
      val ddl = TableUtils.createMysqlTableDDL(target_table, frame_result)
      println(ddl)
      try {
        statement.execute(ddl)
      } catch {
        case e: Exception => {
          /**
           * com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException: Column length too big for column 'reworknum_0_ids' (max = 21845); use BLOB or TEXT instead
           */
          logger.error("MySQL ddl exception {}:", e)
          e.printStackTrace()
        }
      }

      println(s"go on to sink $db.$target_table")

      DruidConnection.release(null, statement, null)

      // keep autocommit off when a rollback may be needed, otherwise rollback() would fail
      conn.setAutoCommit(!isRollback)
      if (mySaveMode == MySaveMode.OverWriteAllTable) {

        /**
         * in this case the dt argument may be null
         */
        prepareStatement = conn.prepareStatement(s"DELETE FROM $target_table WHERE 1=1")

      } else if (mySaveMode == MySaveMode.OverWriteByDt) {

        prepareStatement = conn.prepareStatement(s"DELETE FROM $target_table WHERE dt='$dt'")

      } else if (mySaveMode == MySaveMode.OverWriteByQuarter) {

        val quarter1 = DateUtil.getQuarterStart(dt)
        val quarter2 = DateUtil.getQuarterEnd(dt)

        prepareStatement = conn.prepareStatement(s"DELETE FROM $target_table WHERE dt between '$quarter1' and '$quarter2'")

      } else {
        throw new IllegalArgumentException("未知的SaveMode参数:" + mySaveMode)
      }
      prepareStatement.execute()

    } catch {
      case e: Exception => logger.error("MySQL connection exception {}:", e)
        if (conn != null) {
          if (isRollback) {
            conn.rollback()
          }
        }
        //        DbUtils.release(conn, prepareStatement)
        DruidConnection.release(null, prepareStatement, conn)
        System.exit(-1)
    } finally {
      if (null != conn && !conn.isClosed) {
        /**
         * User class threw exception: java.sql.SQLException: Can't call commit when autocommit=true
         */
        val commit: Boolean = conn.getAutoCommit
        if (!commit) {
          conn.commit()
        }
      }
      //      DbUtils.release(conn, prepareStatement)
      DruidConnection.release(null, prepareStatement, conn)
    }

    try {
      frame_result
        .repartition(partitions)
        .write
        .format("jdbc")
        .option("url", url)
        .option("user", userName)
        .option("password", pwd)

        /**
         * Note that the table columns must match the Spark SQL schema, otherwise this fails with:
         * Exception in thread "main" org.apache.spark.sql.AnalysisException:
         * Column "xxx" not found in schema Some(StructType(StructField(id,StringType,true),
         *
         * In practice it is enough for the DataFrame columns to be a subset of the target table columns.
         */
        .option("dbtable", target_table)
        .mode(SaveMode.Append)
        .save()
    } catch {
      case e: SparkException => logger.error("Exception in writing data to MySQL database by spark task {}:", e)
        System.exit(-1)
    }


    val count = frame_result.count()

    mySaveMode match {

      case MySaveMode.OverWriteByQuarter =>

        /**
         * dt is 2023-09-18 and cnts is 8808
         * dt is 2023-09-17 and cnts is 1052
         * dt is 2023-09-03 and cnts is 367
         * dt is 2023-09-02 and cnts is 270
         * dt is 2023-09-01 and cnts is 858
         *
         */
        val map: util.Map[String, Integer] = new MysqlDao().getResultQuarter(host, port, db, userName, pwd, target_table, dt)

        var sum = 0

        map.forEach(new BiConsumer[String, Integer] {
          override def accept(t: String, u: Integer): Unit = {

            sum += u
          }
        })


        if (count != sum) {

          throw new DataSyncException(s"sync to mysql ${db}.${target_table} missed ${count - sum} rows...")

        } else {
          println(s"sync to mysql ${db}.${target_table} end")
        }


      case _ =>

        val map: util.Map[String, Integer] = new MysqlDao().getResultMap(host, port, db, userName, pwd, target_table, dt, dt)

        val count2 = map.getOrDefault(if (dt != null) dt else "1970-01-01", 0) - 0

        if (count != count2) {

          throw new DataSyncException(s"sync to mysql ${db}.${target_table} missed ${count - count2} rows...")

        } else {
          println(s"sync to mysql ${db}.${target_table} end")
        }


    }


  }

  def sink_to_mysql(host: String,
                    port: String,
                    db: String,
                    target_table: String,
                    userName: String,
                    pwd: String,
                    frame_result: DataFrame,
                    partitions: Int,
                    mySaveMode: MySaveMode,
                    dateStrs: String
                   ): Unit = {
    if (frame_result.isEmpty) {
      println(s"the dataset of dt ${dateStrs} sink to mysql is empty,return...")
      return
    }

    val url = s"jdbc:mysql://${host}:${port}/${db}?rewriteBatchedStatements=true&useUnicode=true&characterEncoding=UTF-8&connectionRequestTimout=300000&connectionTimeout=300000&socketTimeout=300000&useSSL=false"
    val conn: Connection = DriverManager.getConnection(url, userName, pwd)
    val statement: Statement = conn.createStatement()
    var prepareStatement: PreparedStatement = null

    try {
      if (mySaveMode == MySaveMode.OverWriteAllTable) {
        prepareStatement = conn.prepareStatement(s"DELETE FROM $target_table WHERE 1=1")
      } else if (mySaveMode == MySaveMode.OverWriteByDt) {
        prepareStatement = conn.prepareStatement(s"DELETE FROM $target_table WHERE dt in ($dateStrs)")
      } else throw new IllegalArgumentException("Unknown SaveMode: " + mySaveMode)
      prepareStatement.execute()

    } catch {
      case e: Exception => logger.error("MySQL connection exception {}:", e)
        System.exit(-1)
    } finally {
      //      DbUtils.release(conn, prepareStatement)
      DruidConnection.release(null, prepareStatement, conn)
    }

    try {
      frame_result
        .filter(x => {

          var flag = false

          if (mySaveMode == MySaveMode.OverWriteAllTable) {

            flag = true

          } else if (mySaveMode == MySaveMode.OverWriteByDt) {

            flag = dateStrs.contains(x.getAs[String]("dt"))
          }

          flag

        })
        .repartition(partitions)
        .write
        .format("jdbc")
        .option("url", url)
        .option("user", userName)
        .option("password", pwd)
        .option("dbtable", target_table)
        .mode(SaveMode.Append)
        .save()
    } catch {
      case e: SparkException => logger.error("Exception in writing data to MySQL database by spark task {}:", e)
        System.exit(-1)
    }
  }

  def sink_to_mysql(host0: String, db: String, target_table: String, userName: String, pwd: String, frame_result: DataFrame, partitions: Int, mySaveMode: MySaveMode
                    , dt: String): Unit = {

    sink_to_mysql(host0, db, target_table, userName, pwd, frame_result, partitions, mySaveMode, dt, isRollback = true)


  }


  def sink_to_hive(dt: String, spark: SparkSession, frame_result: DataFrame, hive_db: String, hive_table: String): Unit = {
    sink_to_hive(dt, spark, frame_result, hive_db, hive_table, "parquet")
  }

  def sink_to_hive(dt: String, spark: SparkSession, frame_result: DataFrame, hive_db: String, hive_table: String, format: String): Unit = {
    sink_to_hive(dt, spark, frame_result, hive_db, hive_table, format, MySaveMode.OverWriteByDt)
  }

  def sink_to_hive(dt: String, spark: SparkSession, frame_result: DataFrame, hive_db: String, hive_table: String, format: String, mySaveMode: MySaveMode): Unit = {
    sink_to_hive(dt, spark, frame_result, hive_db, hive_table, format, mySaveMode, 1)
  }

  def sink_to_hive(dt: String, spark: SparkSession, frame_result: DataFrame, hive_db: String, hive_table: String, format: String, mySaveMode: MySaveMode, partitions: Int): Unit = {
    sink_to_hive(dt, spark, frame_result, hive_db, hive_table, format, mySaveMode, partitions, null)
  }


  /**
   * Only load the data belonging to the given dt.
   * dt can be null; in that case the data is not filtered by dt.
   *
   */
  def sink_to_hive(dt: String, spark: SparkSession, frame_result: DataFrame, hive_db: String, hive_table: String
                   , format: String, mySaveMode: MySaveMode, partitions: Int, comment: java.util.Map[String, String]): Unit = {


    val f1 = frame_result


    if (StringUtils.isBlank(dt) || mySaveMode == MySaveMode.OverWriteAllTable) {

      var f2 = f1

      /**
       * Apply the column comments to the table fields.
       */
      if (null != comment && !comment.isEmpty) {
        f2 = TableUtils.addComment(spark, f1, comment)
      }

      /**
       * Overwrite the whole table.
       */

      sink_to_hive_OverWriteAllTable(f2, hive_db, hive_table, format, partitions)


      return
    }


    try {
      var f3 = TableUtils.formattedData(spark, f1.filter(x => {

        dt.equalsIgnoreCase(x.getAs[String]("dt"))

      }))


      var diff: Set[String] = Set()
      var diff1: Set[String] = Set()

      if (TableUtils.tableExists(spark, hive_db, hive_table)) {

        TableUtils.delPartitions(spark, dt, dt, hive_db, hive_table)

        println(s"${hive_db}.${hive_table}表存在,删除${dt}分区的数据")

        diff = f3.columns.toSet &~ spark.sql(
          s"""
            select * from $hive_db.$hive_table where dt='9999-99-99'
            """.stripMargin).columns.toSet

        diff1 = spark.sql(
          s"""
            select * from $hive_db.$hive_table where dt='9999-99-99'
            """.stripMargin).columns.toSet diff f3.columns.toSet


      } else {
        println(s"${hive_table}表不存在,write directly")
      }


      /**
       * Columns newly added in the source table: add them to the Hive table.
       */
      diff.foreach(x => {


        println("add col in hive table:" + x)
        spark.sql(
          s"""
             |ALTER TABLE ${hive_db}.${hive_table} ADD COLUMNS(${x} STRING)
             |""".stripMargin)

      }
      )

      /**
       * Columns dropped from the source table: fill them with empty strings as placeholders; this does not guarantee a consistent Hive schema.
       */
      for (elem <- diff1) {
        println("add col in dataFrame:" + elem)
        f3 = f3.withColumn(elem, lit(""))
      }


      /**
       * Apply the column comments to the table fields.
       */
      if (null != comment && !comment.isEmpty) {
        f3 = TableUtils.addComment(spark, f3, comment)
      }

      /**
       * Partition handling:
       * when partitions is 0, partition by data volume at roughly 30,000 rows per partition;
       * when it is greater than 0, use the given number of partitions.
       */
      var parti = 3
      if (partitions == 0) {
        parti = (f3.count() / 30000).toInt
      } else {
        parti = partitions
      }

      println("frame.rdd.getNumPartitions:" + f3.rdd.getNumPartitions)

      if (parti > f3.rdd.getNumPartitions) {

        println(s"repartition ${f3.rdd.getNumPartitions} to ${parti}")

        f3
          .repartition(parti)
          .write
          .format(format)
          .mode(if (mySaveMode == MySaveMode.OverWriteByDt) SaveMode.Append else SaveMode.Overwrite)
          .partitionBy("dt")
          .saveAsTable(s"$hive_db.$hive_table")

      } else if (parti == f3.rdd.getNumPartitions) {

        println(s"not coalesce ${f3.rdd.getNumPartitions} to ${parti}")

        f3
          .write
          .format(format)
          .mode(if (mySaveMode == MySaveMode.OverWriteByDt) SaveMode.Append else SaveMode.Overwrite)
          .partitionBy("dt")
          .saveAsTable(s"$hive_db.$hive_table")


      } else {

        println(s"coalesce ${f3.rdd.getNumPartitions} to ${parti}")

        f3
          .coalesce(parti)
          .write
          .format(format)
          .mode(if (mySaveMode == MySaveMode.OverWriteByDt) SaveMode.Append else SaveMode.Overwrite)
          .partitionBy("dt")
          .saveAsTable(s"$hive_db.$hive_table")

      }
    } catch {
      case e: SparkException =>
        logger.error("The spark task failed to write data to hive database: {}", e)
      case e2: Exception =>
        logger.error(s"The spark task failed to write data to hive database: ${e2}", e2)
        System.exit(-1)
    }
  }

  /**
   * org.apache.spark.sql.AnalysisException:
   * The format of the existing table paascloud.dwd_order_info_abi is `HiveFileFormat`.
   * It doesn't match the specified format `ParquetFileFormat`
   * Compatible with native Hive tables (format HiveFileFormat).
   * When using this API, make sure the DataFrame and the Hive table agree on the number and order of columns.
   *
   *
   */
  def sink_to_hive_HiveFileFormat(spark: SparkSession, f: DataFrame, hive_db: String, hive_table: String, conf: util.HashMap[String, Object]): Unit = {

    val view = "v_" + new Random().nextInt(100) + 1

    f.createOrReplaceTempView(view)

    if (null != conf) {

      val keys: Array[AnyRef] = conf.keySet().toArray()

      keys.foreach(x => {

        val k = x
        val v = conf.getOrDefault(k, "").toString

        /**
         * Only set values that are meaningful.
         */
        if (StringUtils.isNotBlank(v)) {

          println(s"set ${k}=${v}")
          spark.sql(s"set ${k}=${v}")
        }

      })
    }

    spark.sql(
      s"""
         |
         |insert overwrite table ${hive_db}.${hive_table} partition(dt)
         |select * from ${view}
         |
         |""".stripMargin)
  }


  /**
   * Full extract, full overwrite.
   * For ES indexes that have no timestamp and therefore cannot be synced incrementally.
   */
  def sink_to_hive_OverWriteAllTable(frame_result: DataFrame, hive_db: String, hive_table: String, format: String, partitions: Int): Unit = {


    frame_result
      .repartition(partitions)
      .write
      .format(format)
      .mode(SaveMode.Overwrite)
      .partitionBy("dt")
      .saveAsTable(s"$hive_db.$hive_table")


  }

}
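A minimal sketch chaining SourceUtil and SinkUtil for one day of data, under the same assumptions as above; databases, tables and credentials are placeholders:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object SinkUtilDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sink_util_demo").enableHiveSupport().getOrCreate()
    val dt = "2023-09-01"

    // Pull the source table and tag every row with the partition value.
    val df = SourceUtil
      .sourceAllFromMysql(spark, "127.0.0.1", "3306", "demo_db", "t_order", "user", "pwd")
      .withColumn("dt", lit(dt))

    // Overwrite the dt partition of an ORC Hive table, coalescing to a single output partition.
    SinkUtil.sink_to_hive(dt, spark, df, "demo_dw", "ods_t_order", "orc", MySaveMode.OverWriteByDt, 1)

    // Push the same day to MySQL: the dt partition is deleted first, then the data is appended via JDBC.
    SinkUtil.sink_to_mysql("127.0.0.1:3306", "mz_olap", "ads_t_order", "user", "pwd", df, 3, MySaveMode.OverWriteByDt, dt)

    spark.stop()
  }
}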

common_mysql_to_hive

import com.mingzhi.ConfigUtils
import com.mingzhi.common.interf.{IDate, LoadStrategy, MySaveMode}
import com.mingzhi.common.utils._
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.storage.StorageLevel


/**
 * Works around parquet not supporting the DATE type:
 * UnsupportedOperationException Parquet does not support date. See HIVE-6384
 * This class therefore adds a Hive format parameter (parquet, orc or json),
 * adds support for a MySQL port parameter,
 * does not support custom Hive table names (the Hive table name should in principle match the source MySQL table name), but a prefix can distinguish same-named tables from different business databases,
 * unifies the load strategies,
 * and supports batch import of multiple tables under the full-load strategy.
 */
object common_mysql_to_hive {

  //===========================params define=====================================
  /**
   * ip or "ip:port"
   */
  private var mysql_host: String = ""
  private var mysql_db = ""
  private var mysql_tables = ""
  private var mysql_username = ""
  private var mysql_pwd = ""


  private var hive_db = ""
  private var hive_table_prefix = "" //usually is empty
  private var format = "parquet"

  private var dt: String = ""
  private var dt1: String = ""

  /**
   * "ct_or_ut_time" or ""
   */
  private var create_time = ""
  private var update_time = ""

  /**
   * Purpose of this parameter:
   * shield the warehouse table names from frequent renames of tables in the business database.
   */
  private var hive_tables = "";

  /**
   * Purpose of this parameter: avoid importing the ~99% of rows that are useless (applicable to full-load tables).
   */
  private var condition = "1=1"


  /**
   * Global column pruning, applied to all tables.
   * Used to drop useless columns such as images.
   */
  private var global_dropCols = ""


  /**
   * Partition coalescing.
   */
  private var partitions = 1

  /**
   * Whether the data needs to be formatted.
   */
  var formattedData = false // data is not formatted by default


  //================================================================

  private var mysql_port = "3306"


  /**
   * The default load strategy is NEW (incremental inserts);
   * the actual strategy is determined by the program arguments.
   */
  var load_strategy: LoadStrategy.Value = LoadStrategy.NEW

  var saveMode: MySaveMode.Value = MySaveMode.Ignore

  private def parseSaveMode(sMode: String): Unit = {

    saveMode = sMode match {
      case "OverWriteByDt" => MySaveMode.OverWriteByDt
      case "OverWriteByMonth" => MySaveMode.OverWriteByMonth
      case "OverWriteByQuarter" => MySaveMode.OverWriteByQuarter
      case "OverWriteAllTable" => MySaveMode.OverWriteAllTable

      case _ => MySaveMode.Ignore
    }
  }


  def main(args: Array[String]): Unit = {

    System.setProperty("HADOOP_USER_NAME", "root")

    val builder = SparkUtils.getBuilder

    // local or cluster
    if (ConfigUtils.isWindowEvn) {

      builder.master("local[*]")


      mysql_host = s"${ConfigUtils.SEGI_MYSQL_HOST}:3308"
      //      mysql_host = s"${ConfigUtils.UAC_MYSQL_HOST}:3306"
      //      mysql_db = "vwork"
      mysql_db = "contract"
      //      mysql_db = "paascloud_uac_203"
      //      mysql_tables = "si_main,hr_staff_base,hr_staff_base_ext,hr_staff_certificate_management,hr_staff_edu,hr_staff_contract" +
      //        ",pt_post_genre_sp_mingzhe,pt_code_type_value,pt_post_level" +
      //        ",pf_cost_seal,si_cost_seal" +
      //        ",pt_staff_stru,hr_staff_leaveoffice"

      //      mysql_tables = "t_record_detailed,administrative_organization,t_uac_user,staff,organization_position"

      //      mysql_tables = "km_asset_card,km_asset_address,km_asset_category,km_provider_main,sys_org_element"

      mysql_tables = "ctt_contract,ctt_contract_attr_inst,ctt_target,ctt_contract_user_info,segi_organ,segi_tb_uhome_region,ctt_template_attr_value,segi_organ_com_rel,segi_tb_uhome_community"
      //      mysql_table = "t_uac_organization,t_uac_building,t_uac_building_room,t_uac_organization_auth,t_uac_organization_expand,t_uac_subsystem,t_uac_user"
      mysql_username = ConfigUtils.SEGI_MYSQL_USERNAME
      mysql_pwd = ConfigUtils.SEGI_MYSQL_PWD


      //      hive_db = "hr"
      hive_db = "contract"
      hive_table_prefix = ""
      format = "orc"

      /**
       * Full load: "", ""
       * Incremental (new rows): create_time, ""
       * Incremental (new and changed rows): create_time, update_time
       */
      create_time = ""
      update_time = ""

      dt = "2022-01-01"
      dt1 = "2022-01-01"

      //      hive_tables = "t_wms_warehouse_inventory_details,t_wms_material,t_wms_material_classify,t_uac_organization_expand"
      //      hive_tables = "km_asset_card,km_asset_address,km_asset_category,km_provider_main,sys_org_element"
      hive_tables = "";

      /**
       * Youdianma democratic-evaluation data.
       */
      //      condition = "record_id IN ('7011560341845074035170451296007','70115366179389031651655254075925')"


    } else {
      mysql_host = args(0)
      mysql_db = args(1)
      mysql_tables = args(2)
      mysql_username = args(3)
      mysql_pwd = args(4)
      hive_db = args(5)
      hive_table_prefix = args(6)
      format = if (StringUtil.isBlank(args(7))) "parquet" else args(7)

      create_time = args(8)
      update_time = args(9)

      dt = args(10)
      dt1 = args(11)

      if (args.length == 13) {
        hive_tables = args(12)
      }
      else if (args.length == 14) {
        hive_tables = args(12)
        condition = args(13)
      }
      else if (args.length == 15) {
        hive_tables = args(12)
        condition = args(13)
        global_dropCols = args(14)
      } else if (args.length == 16) {
        hive_tables = args(12)
        condition = args(13)
        global_dropCols = args(14)
        partitions = if (StringUtils.isBlank(args(15))) 1 else Integer.parseInt(args(15))
      } else if (args.length == 17) {
        hive_tables = args(12)
        condition = args(13)
        global_dropCols = args(14)
        partitions = if (StringUtil.isBlank(args(15))) 1 else Integer.parseInt(args(15)) // default partition count is 1
        formattedData = if (StringUtil.isBlank(args(16))) false else args(16).toBoolean // whether formatting is needed; defaults to false, set to true to enable conversion
      } else if (args.length == 18) {
        hive_tables = args(12)
        condition = args(13)
        global_dropCols = args(14)
        partitions = if (StringUtil.isBlank(args(15))) 1 else Integer.parseInt(args(15))
        formattedData = if (StringUtil.isBlank(args(16))) false else args(16).toBoolean
        val saveMode = args(17)

        parseSaveMode(saveMode)

      }
    }

    StringUtil.assertNotBlank(mysql_host, "mysql_host can not be null")
    StringUtil.assertNotBlank(mysql_db, "mysql_db can not be null")
    StringUtil.assertNotBlank(mysql_tables, "mysql_tables can not be null")
    StringUtil.assertNotBlank(mysql_username, "mysql_username can not be null")
    StringUtil.assertNotBlank(mysql_pwd, "mysql_pwd can not be null")
    StringUtil.assertNotBlank(hive_db, "hive_db can not be null")

    if (dt.length != "yyyy-MM-dd".length || dt1.length != "yyyy-MM-dd".length) {
      throw new IllegalArgumentException("dt must be in yyyy-MM-dd format")
    }

    /**
     * Parse the arguments to determine the load strategy.
     */
    if (mysql_host.split(":").length >= 2) {
      mysql_port = mysql_host.split(":")(1)
      mysql_host = mysql_host.split(":")(0)
    }

    if (StringUtil.isBlank(create_time) && StringUtil.isBlank(update_time)) {

      load_strategy = LoadStrategy.ALL
      println("1 load_strategy:" + load_strategy)

    } else if (StringUtil.isNotBlank(create_time) && StringUtil.isBlank(update_time)) {

      load_strategy = LoadStrategy.NEW
      println("2 load_strategy:" + load_strategy)

    } else if (StringUtil.isNotBlank(create_time) && StringUtil.isNotBlank(update_time)) {

      load_strategy = LoadStrategy.NEW_AND_CHANGE
      println("3 load_strategy:" + load_strategy)
    } else {
      throw new IllegalArgumentException("create_time 和 update_time参数组合异常")
    }

    println("load_strategy:" + load_strategy)

    var tables: Array[(String, String)] = mysql_tables.split(",").zip(mysql_tables.split(","))

    if (StringUtil.isNotBlank(hive_tables)) {

      if (mysql_tables.split(",").length != hive_tables.split(",").length) {
        throw new IllegalArgumentException("the number of source tables and target hive tables does not match...")
      }

      tables = mysql_tables.split(",").zip(hive_tables.split(","))

    }


    if (tables.length > 1 && load_strategy != LoadStrategy.ALL) {
      throw new IllegalArgumentException("non-full-load imports do not support multi-table batch operations...")
    }

    if (tables.length > 1 && StringUtil.isNotBlank(condition) && (!"1=1".equalsIgnoreCase(condition))) {
      throw new IllegalArgumentException("conditional imports do not support multi-table batch operations...")
    }


    val spark: SparkSession = builder
      .appName("mysql_to_hive")
      .getOrCreate()

    new IDate {
      override def onDate(dt: String): Unit = {

        tables
          .foreach(t => {


            var mysql_table = t._1 //like t_customer_report:remark or t_customer_report
            var dropCols = global_dropCols

            var hive_table = t._2 // maybe t_a:x-y-z or t_a

            if (mysql_table.split(":").length >= 2) {
              dropCols = mysql_table.split(":")(1)
              mysql_table = mysql_table.split(":")(0)
            }


            if (hive_table.split(":").length >= 2) {
              hive_table = hive_table.split(":")(0)
            }
            var frame_result = processSource(spark, dt, mysql_table, load_strategy, condition, dropCols)
              .persist(StorageLevel.MEMORY_ONLY_SER_2)

            if (formattedData) {
              frame_result = TableUtils.formattedData(spark, frame_result)
            }

            println(s"frame_result of ${t._1}===>")
            TableUtils.show(frame_result)

            /**
             * Full-load tables are not partitioned by default: the whole table is overwritten, keeping a single copy.
             */
            var mySaveMode = if (load_strategy == LoadStrategy.ALL) MySaveMode.OverWriteAllTable else MySaveMode.OverWriteByDt


            mySaveMode = saveMode match {

              case MySaveMode.OverWriteByDt => MySaveMode.OverWriteByDt
              case MySaveMode.OverWriteByMonth => MySaveMode.OverWriteByMonth
              case MySaveMode.OverWriteByQuarter => MySaveMode.OverWriteByQuarter
              case MySaveMode.OverWriteAllTable => MySaveMode.OverWriteAllTable
              case _ => mySaveMode

            }


            SinkUtil.sink_to_hive(dt, spark, frame_result, hive_db, s"$hive_table_prefix$hive_table", format, mySaveMode, partitions)

            frame_result.unpersist()
          })

      }
    }.invoke(dt, dt1)

    spark.stop()

  }

  def processSource(spark: SparkSession, dt: String, t_source: String, load_strategy: LoadStrategy.strategy, condition: String, dropCols: String): DataFrame = {

    var frame_mysql: DataFrame = null
    var frame_result: DataFrame = null

    println("process load_strategy:" + load_strategy)

    if (load_strategy == LoadStrategy.NEW) {

      frame_result = SourceUtil.sourceNewFromMysql(spark, mysql_host, mysql_port, mysql_db, t_source, mysql_username, mysql_pwd, create_time, dt, condition)
        .withColumn("dt", lit(dt))

    } else if (load_strategy == LoadStrategy.ALL) {

      frame_mysql = SourceUtil.sourceAllFromMysql(spark, mysql_host, mysql_port, mysql_db, t_source, mysql_username, mysql_pwd, condition, dropCols)

      frame_result = frame_mysql.withColumn("dt", lit(dt))

    } else if (load_strategy == LoadStrategy.NEW_AND_CHANGE) {

      frame_result = SourceUtil.sourceNewAndChangeFromMysql(spark, mysql_host, mysql_port, mysql_db, t_source, mysql_username, mysql_pwd, create_time, update_time, dt)
        .withColumn("dt", lit(dt))
    }

    frame_result
  }
}
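For reference, a hypothetical argument list for a cluster run of common_mysql_to_hive, matching the positional parsing in main above (all values are placeholders):

val args = Array(
  "192.168.0.1:3306",        // args(0)  mysql_host, ip or ip:port
  "contract",                // args(1)  mysql_db
  "ctt_contract,ctt_target", // args(2)  mysql_tables, comma separated; an optional ":dropCols" suffix prunes columns per table
  "user",                    // args(3)  mysql_username
  "pwd",                     // args(4)  mysql_pwd
  "contract",                // args(5)  hive_db
  "",                        // args(6)  hive_table_prefix
  "orc",                     // args(7)  format, blank falls back to parquet
  "",                        // args(8)  create_time (blank together with a blank update_time selects the full-load strategy)
  "",                        // args(9)  update_time
  "2022-01-01",              // args(10) dt
  "2022-01-01"               // args(11) dt1
  // optional trailing args: args(12) hive_tables, args(13) condition, args(14) global_dropCols,
  // args(15) partitions, args(16) formattedData, args(17) saveMode
)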


hive_to_mysql

import org.apache.commons.lang3.StringUtils
import org.apache.commons.lang3.time.DateFormatUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.slf4j.LoggerFactory

import java.util.Date

/**
 * Generic class for syncing data from the warehouse to MySQL.
 * Supports building the MySQL table DDL automatically,
 * overwriting by dt or overwriting the whole table,
 * and column pruning.
 */
object hive_to_mysql {
  private val logger = LoggerFactory.getLogger(hive_to_mysql.getClass)

  def main(args: Array[String]): Unit = {


    var mysql_host = "192.168.xx.xx:3306"
    var from_db = "xx"
    var from_tables = "ads_user_analyse_month:a-b-c"
    var to_db = "mz_olap"
    var to_tables = "wfs_ads_user_analyse_month"
    var userName = ""
    var passWd = ""

    var dt = "2023-09-01"
    var dt1 = "2023-09-01"

    var saveMode = MySaveMode.OverWriteByDt
    var global_dropCols = ""

    System.setProperty("HADOOP_USER_NAME", "root")
    val builder = SparkUtils.getBuilder


    if (System.getProperties.getProperty("os.name").contains("Windows")) {

      builder.master("local[*]")
    } else {


      mysql_host = args(0)
      from_db = args(1)
      from_tables = args(2)
      to_db = args(3)
      to_tables = args(4)
      userName = args(5)
      passWd = args(6)
      dt = args(7)
      dt1 = args(8)

      if (args.length == 10) {

        saveMode = defineSaveMode(args(9))

      }
      else if (args.length == 11) {

        saveMode = defineSaveMode(args(9))
        global_dropCols = args(10)
      }

      StringUtil.assertNotBlank(mysql_host, "mysql_host can not be null")
      StringUtil.assertNotBlank(from_db, "from_db can not be null")
      StringUtil.assertNotBlank(from_tables, "from_tables can not be null")
      StringUtil.assertNotBlank(to_db, "to_db can not be null")
      StringUtil.assertNotBlank(to_tables, "to_tables can not be null")
      StringUtil.assertNotBlank(userName, "userName can not be null")
      StringUtil.assertNotBlank(passWd, "passWd can not be null")

      if (from_tables.split(",").length != to_tables.split(",").length) {
        throw new IllegalArgumentException("the number of source tables and target tables does not match...")
      }

    }


    /**
     * Exceptions during the sync are propagated.
     */
    val spark = builder
      .appName("hive_to_mysql")
      .getOrCreate()

    from_tables.split(",").zip(to_tables.split(",")).foreach(r => {

      var hive_table = r._1 //like t_a:x-y-z or t_a
      var dropCols = global_dropCols

      if (hive_table.split(":").length >= 2) {
        dropCols = hive_table.split(":")(1)
        hive_table = hive_table.split(":")(0)
      }


      process(spark, mysql_host, from_db, hive_table, to_db, r._2, userName, passWd, dt, dt1, saveMode, dropCols)

    })

    spark.stop()
    logger.info("The task execution completed ......")
  }

  def process(spark: SparkSession, mysql_host: String, from_db: String, from_table: String
              , to_db: String, to_table: String, userName: String, passWd: String, dt: String, dt1: String, saveMode: MySaveMode, dropCols: String): Unit = {

    if (MySaveMode.OverWriteByDt == saveMode || MySaveMode.OverWriteAllTable == saveMode) {

      processByDt(spark, mysql_host, from_db, from_table, to_db, to_table, userName, passWd, dt, dt1, saveMode, dropCols)

    } else if (MySaveMode.OverWriteByQuarter == saveMode) {

      processByQuarter(spark, mysql_host, from_db, from_table, to_db, to_table, userName, passWd, dt, dt1, saveMode, dropCols)

    } else if (MySaveMode.OverWriteByMonth == saveMode) {

      new IMonth {
        override def onMonth(dMonth: String): Unit = {

          /**
           * dMonth points to the first day of each month (yyyy-MM-01),
           * so it is actually processed like a dt.
           */
          processByDt(spark, mysql_host, from_db, from_table, to_db, to_table, userName, passWd, dMonth, dMonth, MySaveMode.OverWriteByDt, dropCols)

        }
      }.invoke(dt, dt1)

    } else {
      throw new IllegalArgumentException("未知的SaveMode参数:" + saveMode)
    }
  }

  private def processByQuarter(spark: SparkSession
                               , mysql_host: String, from_db: String, from_table: String, to_db: String, to_table: String, userName: String, passWd: String
                               , dt: String, dt1: String
                               , saveMode: MySaveMode
                               , dropCols: String): Unit = {

    val dt_start = DateUtil.getQuarterStart(dt)
    val dt_end = DateUtil.getQuarterEnd(dt)

    var frame = spark.sql(
      s"""
         |select *,'${DateFormatUtils.format(new Date(), "yyyy-MM-dd HH:mm:ss")}' as last_update_time from $from_db.$from_table where dt between '$dt_start' and '$dt_end'
         |""".stripMargin)
      .persist(StorageLevel.MEMORY_ONLY_SER_2)

    if (StringUtils.isNotBlank(dropCols)) {

      dropCols.split(",").foreach(col => {
        frame = frame.drop(col)
      })

    }

    frame.show(true)

    SinkUtil.sink_to_mysql(mysql_host, to_db, to_table, userName, passWd, frame, 5, saveMode, dt)
  }


  private def processByDt(spark: SparkSession
                          , mysql_host: String, from_db: String, from_table: String, to_db: String, to_table: String, userName: String, passWd: String
                          , dt: String, dt1: String
                          , saveMode: MySaveMode
                          , dropCols: String): Unit = {
    new IDate {
      override def onDate(dt: String): Unit = {

        var frame = spark.sql(
          s"""
             |select *,'${DateFormatUtils.format(new Date(), "yyyy-MM-dd HH:mm:ss")}' as last_update_time from $from_db.$from_table where dt='$dt'
             |""".stripMargin)
          .persist(StorageLevel.MEMORY_ONLY_SER_2)

        if (StringUtils.isNotBlank(dropCols)) {

          dropCols.split(",").foreach(col => {
            frame = frame.drop(col)
          })

        }

        frame.show(true)

        SinkUtil.sink_to_mysql(mysql_host, to_db, to_table, userName, passWd, frame, 5, saveMode, dt)

        frame.unpersist()
      }
    }.invoke(dt, dt1)
  }

  def defineSaveMode(saveMode: String): interf.MySaveMode.Value = {
    var mode = MySaveMode.Ignore

    if (MySaveMode.OverWriteByDt.toString.equalsIgnoreCase(saveMode)) {

      mode = MySaveMode.OverWriteByDt

    } else if (MySaveMode.OverWriteAllTable.toString.equalsIgnoreCase(saveMode)) {

      mode = MySaveMode.OverWriteAllTable

    } else if (MySaveMode.OverWriteByQuarter.toString.equalsIgnoreCase(saveMode)) {

      mode = MySaveMode.OverWriteByQuarter

    } else if (MySaveMode.OverWriteByMonth.toString.equalsIgnoreCase(saveMode)) {

      mode = MySaveMode.OverWriteByMonth

    } else {
      throw new IllegalArgumentException("未知的SaveMode参数:" + saveMode)
    }

    mode
  }
}
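Likewise, a hypothetical argument list for hive_to_mysql, mirroring the positional parsing in its main (placeholder values):

val args = Array(
  "192.168.0.1:3306",             // args(0) mysql_host
  "ads",                          // args(1) from_db (Hive database)
  "ads_user_analyse_month:a-b-c", // args(2) from_tables; an optional ":dropCols" suffix prunes columns
  "mz_olap",                      // args(3) to_db (MySQL database)
  "wfs_ads_user_analyse_month",   // args(4) to_tables
  "user",                         // args(5) userName
  "pwd",                          // args(6) passWd
  "2023-09-01",                   // args(7) dt
  "2023-09-30"                    // args(8) dt1
  // optional: args(9) saveMode (OverWriteByDt / OverWriteAllTable / OverWriteByQuarter / OverWriteByMonth),
  // args(10) global_dropCols
)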

common_es_to_hive

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.elasticsearch.hadoop.cfg.ConfigurationOptions

/**
 * Full extraction supports multi-table batch import.
 * Supports a port other than the default 9200,
 * a non-default _doc type,
 * batch extraction of historical data,
 * three load strategies,
 * a configurable storage format (orc or parquet, parquet by default),
 * and a configurable number of partitions.
 */
object common_es_to_hive {

  //===========================params define start=====================================
  private var es_host: String = ""
  private var es_indexes = ""

  private var hive_db = ""
  private var hive_tables = ""

  /**
   * "ct_or_ut_time" or ""
   */
  private var create_time = ""
  private var update_time = ""

  private var dt: String = ""
  private var dt1: String = ""
  //===========================params define end=====================================

  private var es_port = "9200"
  private var es_doc = "_doc"
  private var format = "parquet"
  private var partitions = 1


  /**
   * The default load strategy is NEW (incremental inserts);
   * the actual strategy is determined by the program arguments.
   */
  var load_strategy: LoadStrategy.Value = LoadStrategy.NEW

  def main(args: Array[String]): Unit = {

    System.setProperty("HADOOP_USER_NAME", "root")

    val builder = SparkUtils.getBuilder

    // local or cluster
    if (System.getProperties.getProperty("os.name").contains("Windows")) {

      builder.master("local[*]")

      es_host = s"${ConfigUtils.IOT_ES_HOST}"
      es_indexes = "wfs_order_material_index_2021"

      hive_db = "paascloud"
      hive_tables = "wfs_order_material_index"

      /**
       * Full load: "", ""
       * Incremental (new rows): create_time, ""
       * Incremental (new and changed rows): create_time, update_time
       */
      //      create_time = "orderCreateTime"
      create_time = ""
      //      update_time = "orderUpdateTime"
      update_time = ""

      dt = "2020-12-15"
      dt1 = "2020-12-15"


    } else {
      es_host = args(0)
      es_indexes = args(1) //may be xxx_index:_docx,yyy_index:_docy
      hive_db = args(2)
      hive_tables = args(3)

      create_time = args(4)
      update_time = args(5)

      dt = args(6)
      dt1 = args(7)

      if (args.length == 9) {

        format = processFormat(args(8))

      } else if (args.length == 10) {
        format = processFormat(args(8))

        partitions = if (StringUtils.isEmpty(args(9))) partitions else args(9).toInt
      }
    }

    StringUtil.assertNotBlank(es_host, "es_host can not be null")
    StringUtil.assertNotBlank(es_indexes, "es_indexes can not be null")
    StringUtil.assertNotBlank(hive_db, "hive_db can not be null")
    StringUtil.assertNotBlank(hive_tables, "hive_tables can not be null")


    /**
     * Parse the port.
     */
    if (es_host.split(":").length >= 2) {
      es_port = es_host.split(":")(1)
      es_host = es_host.split(":")(0)
    }


    /**
     * Parse the arguments to determine the load strategy.
     */
    if (StringUtils.isBlank(create_time) && StringUtils.isBlank(update_time)) {

      load_strategy = LoadStrategy.ALL
      println("1 load_strategy:" + load_strategy)

    } else if (StringUtils.isNotBlank(create_time) && StringUtils.isBlank(update_time)) {

      load_strategy = LoadStrategy.NEW
      println("2 load_strategy:" + load_strategy)

    } else if (StringUtils.isNotBlank(create_time) && StringUtils.isNotBlank(update_time)) {

      load_strategy = LoadStrategy.NEW_AND_CHANGE
      println("3 load_strategy:" + load_strategy)
    } else {
      throw new IllegalArgumentException("create_time 和 update_time参数组合异常")
    }

    val spark: SparkSession = builder
      .config(ConfigurationOptions.ES_NODES, es_host)
      .config(ConfigurationOptions.ES_PORT, es_port)
      .config(ConfigurationOptions.ES_SCROLL_SIZE, 10000)
      .config(ConfigurationOptions.ES_INPUT_USE_SLICED_PARTITIONS, false)
      .config(ConfigurationOptions.ES_BATCH_SIZE_ENTRIES, 20000)
      .config(ConfigurationOptions.ES_BATCH_WRITE_REFRESH, false)
      .config(ConfigurationOptions.ES_MAX_DOCS_PER_PARTITION, 200000)
      .config(ConfigurationOptions.ES_HTTP_TIMEOUT, "5m")
      .config(ConfigurationOptions.ES_SCROLL_KEEPALIVE, "10m")
      .appName("es_to_hive")
      .getOrCreate()

    println("load_strategy:" + load_strategy)

    if (es_indexes.split(",").length > 1 && load_strategy != LoadStrategy.ALL) {
      throw new IllegalArgumentException("non-full-load imports do not support multi-table batch operations...")
    }

    new IDate {
      override def onDate(dt: String): Unit = {

        val indexes = es_indexes.split(",")
        val tables = hive_tables.split(",")

        if (indexes.length != tables.length) {
          throw new IllegalArgumentException("the number of indexes and hive tables does not match...")
        }

        indexes.zip(tables).foreach(x => {

          var index = x._1 // may be xxx_index:_doc1
          val table = x._2

          println(s"index is $index and table is $table")

          /**
           * Parse a non-standard _doc type.
           */
          if (index.split(":").length >= 2) {
            es_doc = index.split(":")(1)
            index = index.split(":")(0) // must be xxx_index
          }

          println("es_doc:" + es_doc)

          val frame_result = source(spark, dt, index, es_doc, load_strategy).persist(StorageLevel.MEMORY_ONLY_SER_2)

          println("frame_result===>")
          frame_result.show(3, true)

          /**
           * Full-load tables are not partitioned: the whole table is overwritten, keeping a single copy.
           */
          val mySaveMode = if (load_strategy == LoadStrategy.ALL) MySaveMode.OverWriteAllTable else MySaveMode.OverWriteByDt

          if (!frame_result.isEmpty) {
            SinkUtil.sink_to_hive(dt, spark, frame_result, hive_db, table, format, mySaveMode, partitions)

          }

          frame_result.unpersist()
        })

      }
    }.invoke(dt, dt1)

    spark.stop()

  }

  private def processFormat(format: String): String = {

    val result = if (StringUtils.isBlank(format)) "parquet" else format

    result match {
      case "orc" | "parquet" =>
      case _ => throw new IllegalArgumentException("format must be one of orc or parquet,not support json csv text")
    }

    result
  }

  def source(spark: SparkSession, dt: String, t_source: String, _doc: String, load_strategy: LoadStrategy.strategy): DataFrame = {

    var frame_result: DataFrame = null

    if (load_strategy == LoadStrategy.NEW) {

      frame_result = SourceUtil.sourceNewFromEs(spark, index = t_source, _doc, create_time, dt)

    } else if (load_strategy == LoadStrategy.ALL) {

      frame_result = SourceUtil.sourceAllFromEs(spark, index = t_source, _doc)

    } else if (load_strategy == LoadStrategy.NEW_AND_CHANGE) {

      frame_result = SourceUtil.sourceNewAndChangeFromEs(spark, index = t_source, _doc, create_time, update_time, dt)

    }

    frame_result.persist(StorageLevel.DISK_ONLY)

    val f1: DataFrame = frame_result
      .withColumn("dt", lit(dt))

    f1
  }
}
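A hypothetical argument list for common_es_to_hive, mirroring the positional parsing above (placeholder values):

val args = Array(
  "192.168.0.2:9200",              // args(0) es_host, ip or ip:port
  "wfs_order_material_index:_doc", // args(1) es_indexes, comma separated; an optional ":type" suffix overrides the default _doc
  "paascloud",                     // args(2) hive_db
  "wfs_order_material_index",      // args(3) hive_tables
  "",                              // args(4) create_time (blank with a blank update_time selects the full-load strategy)
  "",                              // args(5) update_time
  "2020-12-15",                    // args(6) dt
  "2020-12-15"                     // args(7) dt1
  // optional: args(8) format (orc or parquet), args(9) partitions
)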

hive_to_es

import org.apache.commons.lang3.time.DateFormatUtils
import org.apache.spark.storage.StorageLevel
import org.elasticsearch.hadoop.cfg.ConfigurationOptions
import org.elasticsearch.spark.sql.EsSparkSQL
import org.slf4j.LoggerFactory

import java.util.Date

/**
 * Generic class for syncing data from the warehouse to Elasticsearch.
 *
 * Added: support for a user-specified pk on top of the default one.
 * Added: column pruning support.
 */
object hive_to_es {

  private var es_port = "9200"


  private val logger = LoggerFactory.getLogger(hive_to_es.getClass)

  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "root")
    var from_db = "temp"
    //    var from_table = "wfs_order_list_index_map"
    //    var from_table = "wfs_order_track_index_map"
    var from_table = "wfs_correlation_info_index_map"

    var es_host = "192.168.33.163"
    //    var to_index = "wfs_order_list_index_temp_2023"
    //    var to_index = "wfs_order_track_index_temp_2023"
    var to_index = "wfs_correlation_info_index_temp_2023"

    var dt = "2023-10-01"
    var dt1 = "2023-11-07"

    var pk = "id"
    var dropCols = "dt,last_update_time"

    //    val builder = SparkSession.builder()

    val builder = SparkUtils.getBuilder

    if (System.getProperties.getProperty("os.name").contains("Windows")) {
      builder.master("local[*]")
    } else {


      from_db = args(0)
      from_table = args(1)
      es_host = args(2)
      to_index = args(3)
      dt = args(4)
      dt1 = args(5)

      pk = "pk"

      if (args.length == 7) {
        pk = args(6)
      } else if (args.length == 8) {
        pk = args(6)
        dropCols = args(7)
      }


    }

    println("mappingId is " + pk)

    /**
     * Parse the port.
     */
    if (es_host.split(":").length >= 2) {
      es_port = es_host.split(":")(1)
      es_host = es_host.split(":")(0)
    }


    builder.config("spark.sql.parquet.writeLegacyFormat", "true")
      .enableHiveSupport()
      .appName(from_table + "_to_es")
    val spark = builder.getOrCreate()

    new IDate {
      override def onDate(dt: String): Unit = {
        var frame = spark.sql(
          s"""
             |select *,'${DateFormatUtils.format(new Date(), "yyyy-MM-dd HH:mm:ss")}' as last_update_time from $from_db.$from_table where dt='$dt'
             |""".stripMargin)
          .repartition(3)
          .persist(StorageLevel.MEMORY_ONLY_SER_2)
        frame.show(false)

        if (StringUtil.isNotBlank(dropCols)) {

          dropCols.split(",").foreach(c => {

            frame = frame.drop(c)

          })

        }

        val map = new scala.collection.mutable.HashMap[String, String]
        map += ConfigurationOptions.ES_NODES -> es_host
        map += ConfigurationOptions.ES_PORT -> es_port
        map += ConfigurationOptions.ES_MAPPING_ID -> pk

        EsSparkSQL.saveToEs(frame, s"/${to_index}/_doc", map)

        frame.unpersist()
      }
    }.invoke(dt, dt1)

    spark.stop()
    logger.info("The task execution completed ......")

  }
}
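And a hypothetical argument list for hive_to_es (placeholder values):

val args = Array(
  "temp",                                 // args(0) from_db (Hive)
  "wfs_correlation_info_index_map",       // args(1) from_table
  "192.168.33.163:9200",                  // args(2) es_host, ip or ip:port
  "wfs_correlation_info_index_temp_2023", // args(3) to_index
  "2023-10-01",                           // args(4) dt
  "2023-11-07"                            // args(5) dt1
  // optional: args(6) pk used as es.mapping.id, args(7) dropCols (comma separated)
)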

doris_to_hive

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.slf4j.LoggerFactory


/**
 * Used to restore data from Doris.
 */
object doris_to_hive {
  private val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {

    var doris_host = "192.168.20.179:8030"
    var from_db = "test_db" //
    var from_tables = "wfs_dwd_order_info_doris"
    var to_db = "paascloud" //
    var to_tables = "dwd_order_info_doris"
    var userName = "root"
    var passWd = "123456"

    var dt = "2021-11-15"
    var dt1 = "2021-11-15"


    System.setProperty("HADOOP_USER_NAME", "root")
    val builder = SparkUtils.getBuilder


    if (System.getProperties.getProperty("os.name").contains("Windows")) {

      builder.master("local[*]")
    } else {
      try {
        doris_host = args(0)
        from_db = args(1)
        from_tables = args(2)
        to_db = args(3)
        to_tables = args(4)
        userName = args(5)
        passWd = args(6)
        dt = args(7)
        dt1 = args(8)

        StringUtil.assertNotBlank(doris_host, "doris_host can not be null")
        StringUtil.assertNotBlank(from_db, "from_db can not be null")
        StringUtil.assertNotBlank(from_tables, "from_tables can not be null")
        StringUtil.assertNotBlank(to_db, "to_db can not be null")
        StringUtil.assertNotBlank(to_tables, "to_tables can not be null")
        StringUtil.assertNotBlank(userName, "userName can not be null")
        StringUtil.assertNotBlank(passWd, "passWd can not be null")

        if (from_tables.split(",").length != to_tables.split(",").length) {
          throw new IllegalArgumentException("the number of source tables and target tables does not match...")
        }

      } catch {
        case e: Exception =>
          logger.error("Parameter exception: ", e)
          logger.error("doris_to_hive error: {}", e.toString)
          System.exit(-1)
      }
    }


    try {
      val spark = builder
        .appName(this.getClass.getSimpleName)
        .getOrCreate()

      from_tables.split(",").zip(to_tables.split(",")).foreach(r => {

        process(spark, doris_host, from_db, r._1, to_db, r._2, userName, passWd, dt, dt1)

      })

      spark.stop()
      logger.info("The task execution completed ......")
    } catch {
      case e: SparkException =>
        logger.error("SparkException: DatabasesName={}, TableName={}, exception info {}", to_db, to_tables, e.getMessage)
      case e: Exception =>
        logger.error("Exception: DatabasesName={}, TableName={}, exception info {}", to_db, to_tables, e.getMessage)
        System.exit(-1)
    }
  }

  def process(spark: SparkSession, doris_host: String, from_db: String, from_table: String
              , to_db: String, to_table: String, userName: String, passWd: String, dt: String, dt1: String): Unit = {

    new IDate {
      override def onDate(dt: String): Unit = {


        // read data
        //        val dorisSparkDF = spark.read.format("doris")
        //          .option("doris.table.identifier", "test_db.wfs_dwd_order_info_doris")
        //          .option("doris.fenodes", s"$doris_host")
        //          .option("user", "root")
        //          .option("password", "123456")
        //          .load()
        //        dorisSparkDF.show()


        val v = s"view_$from_table"

        spark.sql(
          s"""
             |CREATE TEMPORARY VIEW $v
             |USING doris
             |OPTIONS(
             | "table.identifier"='$from_db.$from_table',
             | "fenodes"='$doris_host',
             | "user"='$userName',
             | "password"='$passWd'
             |)
             |""".stripMargin)

        val f1 = spark.sql(
          s"""
             |select * from $v where dt='$dt'
             |
             |""".stripMargin)


        println("f1 show")
        f1.show(false)

        import spark.implicits._
        val f2 = f1.
          as[X]
          .map(e => {

            val life = e.operatelifecycle

            println("life:" + life)

            e

          })

        println("f2 show")
        f2.show()

      }
    }.invoke(dt, dt1)

  }

  private case class X(
                        orderId: String
                        , dt: String
                        , operatelifecycle: String

                      )

}
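The process method above stops after inspecting the Doris data. A hedged sketch of what the missing restore step could look like, placed inside onDate after the f2.show() call and reusing the to_db/to_table parameters already passed to process together with the SinkUtil defined earlier:

// query the temporary Doris view created above for the current day
val restored = spark.sql(s"select * from $v where dt='$dt'")

// write the day back into the target Hive table, overwriting that dt partition
SinkUtil.sink_to_hive(dt, spark, restored, to_db, to_table, "orc", MySaveMode.OverWriteByDt, 1)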

hive_to_doris

import org.apache.commons.lang3.time.DateFormatUtils
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel
import org.slf4j.LoggerFactory

import java.util.Date

object hive_to_doris {
  private val logger = LoggerFactory.getLogger(this.getClass)

  def main(args: Array[String]): Unit = {

    //    var doris_host = "192.168.20.179:8030"
    var doris_host = "192.168.20.179:9030"
    var from_db = "paascloud"
    var from_tables = "dwd_order_info_doris"
    var to_db = "test_db"
    var to_tables = "wfs_dwd_order_info_doris"
    var userName = "root"
    var passWd = "123456"

    var dt = "2021-11-15"
    var dt1 = "2021-11-15"


    System.setProperty("HADOOP_USER_NAME", "root")
    val builder = SparkUtils.getBuilder


    if (System.getProperties.getProperty("os.name").contains("Windows")) {

      builder.master("local[*]")
    } else {

      doris_host = args(0)
      from_db = args(1)
      from_tables = args(2)
      to_db = args(3)
      to_tables = args(4)
      userName = args(5)
      passWd = args(6)
      dt = args(7)
      dt1 = args(8)


      StringUtil.assertNotBlank(doris_host, "doris_host can not be null")
      StringUtil.assertNotBlank(from_db, "from_db can not be null")
      StringUtil.assertNotBlank(from_tables, "from_tables can not be null")
      StringUtil.assertNotBlank(to_db, "to_db can not be null")
      StringUtil.assertNotBlank(to_tables, "to_tables can not be null")
      StringUtil.assertNotBlank(userName, "userName can not be null")
      StringUtil.assertNotBlank(passWd, "passWd can not be null")

      if (from_tables.split(",").length != to_tables.split(",").length) {
        throw new IllegalArgumentException("the number of source tables and target tables does not match...")
      }


    }

    val spark = builder
      .appName(this.getClass.getSimpleName)
      .getOrCreate()

    from_tables.split(",").zip(to_tables.split(",")).foreach(r => {

      var hive_table = r._1 //like t_a:x-y-z or t_a
      var dropCols = ""

      if (hive_table.split(":").length >= 2) {
        dropCols = hive_table.split(":")(1)
        hive_table = hive_table.split(":")(0)
      }

      process(spark, doris_host, from_db, hive_table, to_db, r._2, userName, passWd, dt, dt1, dropCols)

    })

    spark.stop()
    logger.info("The task execution completed ......")

  }


  def process(spark: SparkSession, doris_host: String, from_db: String, from_table: String
              , to_db: String, to_table: String, userName: String, passWd: String, dt: String, dt1: String, dropCols: String): Unit = {

    println("dropCols====================>" + dropCols)

    new IDate {
      override def onDate(dt: String): Unit = {

        val frame0: DataFrame = spark.sql(
          s"""
             |select *,'${DateFormatUtils.format(new Date(), "yyyy-MM-dd HH:mm:ss")}' as last_update_time from ${from_db}.$from_table where dt='$dt'
             |""".stripMargin)
          .persist(StorageLevel.MEMORY_ONLY)

        var frame1 = frame0

        dropCols.split("-").foreach(col => {

          println("drop col===================>" + col)

          frame1 = frame1.drop(col)

        })


        /**
         * This seems to be required, but it is not known how the value should be set:
         * java.sql.BatchUpdateException: errCode = 2, detailMessage = all partitions have no load data
         * maybe the reason is that the doris cluster has only one node!!!
         */
        val r = frame1.coalesce(1)
        

        println("source count:" + r.count())

        r
          .write
          .format("jdbc")
          .mode(SaveMode.Append)
          .option("url", s"jdbc:mysql://$doris_host/$to_db?rewriteBatchedStatements=true")
          .option("user", userName)
          .option("password", passWd)
          .option("dbtable", to_table)
          .option("batchsize", "50000")
          .save()

      }
    }.invoke(dt, dt1)

  }
}

