Two ways for Spark SQL to read RDS with multiple partitions (MySQL as the example)

This post shows how to read a MySQL (RDS) table in parallel with Apache Spark, choosing the number of partitions yourself and splitting the table on a chosen column, with one code path for numeric (Long) columns and one for Date columns.

import java.text.SimpleDateFormat
import java.util.Properties
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
object MultiplePartitionsMysql {

  val spark = SparkSession
    .builder()
    .appName(this.getClass.getSimpleName.filter(!_.equals("$")))
    //.enableHiveSupport()
    .master("local[*]")
    .config("hive.exec.dynamic.partition", "true")
    .config("hive.exec.dynamic.partition.mode", "nonstrict")
    .config("hive.exec.max.dynamic.partitions", "10000")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()

  // JDBC connection properties for the RDS/MySQL instance
  val youqun = new Properties()
  youqun.put("user", "root")
  youqun.put("password", "123456")
  youqun.put("driver", "com.mysql.jdbc.Driver")
  youqun.put("url", "jdbc:mysql://localhost:3306/wang?characterEncoding=utf8&useSSL=false")

  val partitions = 3
  def main(args: Array[String]): Unit = {
    spark.sparkContext.setLogLevel("ERROR")
    // Read the table "csv" in 3 partitions, splitting on the Date column load_date.
    val resultframe = readTables(youqun.getProperty("url"), "csv", youqun, 3, "load_date", "date").cache()
    // Alternative: split on a numeric (Long) column instead.
    //val resultframe = readTables(youqun.getProperty("url"), "csv", youqun, 4, "sequence", "long")
    //resultframe.show(10, false)
    println(resultframe.count())
    println(resultframe.rdd.getNumPartitions)
    // Print the size of each partition (with local[*] the executor output lands in the same console).
    resultframe.rdd.glom().foreach(part => println(part.getClass.getName + "=======" + part.length))
  }
  /**
   * @param url        JDBC URL to connect to
   * @param table      table name to read
   * @param properties connection properties for the RDS instance
   * @param partitions desired number of partitions
   * @param column     column used to split the read into partitions
   * @param columntype type of the split column ("long" or "date")
   * @return the resulting DataFrame
   */
  def readTables(url: String, table: String, properties: Properties, partitions: Int, column: String, columntype: String) = {
    var dataFrame = spark.emptyDataFrame
    // Choose how to read in multiple partitions based on the type of the split column.
    if (columntype.toLowerCase() == "long") {
      // Numeric column: let Spark split the [min, max] range into `partitions` ranges.
      val array = getMaxMin(table, properties, column)
      val minNum = array(0).toLong
      val maxNum = array(1).toLong
      dataFrame = spark.read.jdbc(url, table, column, minNum, maxNum, partitions, properties)
    } else if (columntype.toLowerCase() == "date") {
      // Date column: build the partition intervals ourselves from the min/max timestamps.
      val array = getMaxMin(table, properties, column)
      val arraypartition = generateArray(array, partitions, column)
      dataFrame = spark.read.jdbc(url, table, arraypartition, properties)
    }
    dataFrame
  }
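
  // For illustration only (not called above): a sketch of what the numeric overload used
  // in the "long" branch does. Spark splits the [lower, upper] range into numPartitions
  // stripes and issues one query per stripe; building equivalent predicates by hand and
  // passing them to the predicates overload gives roughly the same result. The helper
  // name and the exact boundary handling here are assumptions, not Spark's internal code.
  def numericPredicatesSketch(column: String, lower: Long, upper: Long, numPartitions: Int): Array[String] = {
    val step = math.max((upper - lower) / numPartitions, 1L)
    (0 until numPartitions).map { i =>
      val start = lower + i * step
      if (i == numPartitions - 1) s"${start} <= ${column}"                       // last stripe is open-ended
      else s"${start} <= ${column} and ${column} < ${start + step}"
    }.toArray
  }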
  // Query the min and max values of the split column.
  def getMaxMin(table: String, properties: Properties, column: String) = {
    val arrays = ArrayBuffer[String]()
    val array = spark.read.jdbc(properties.getProperty("url"), table, properties)
      .selectExpr(s"min(${column}) as minNum", s"max(${column}) as maxNum")
      .collect()
    if (array.length == 1) {
      arrays.append(array(0)(0).toString)
      arrays.append(array(0)(1).toString)
    }
    arrays.toArray
  }
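
  // Optional alternative (a sketch, not part of the original post): push the min/max
  // aggregation down to MySQL as a subquery, so the bounds are computed on the database
  // side instead of pulling the table through Spark first. The alias "bounds" is arbitrary.
  def getMaxMinPushdown(table: String, properties: Properties, column: String): Array[String] = {
    val boundsQuery = s"(select min(${column}) as minNum, max(${column}) as maxNum from ${table}) as bounds"
    val row = spark.read.jdbc(properties.getProperty("url"), boundsQuery, properties).collect()(0)
    Array(row(0).toString, row(1).toString)
  }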
  // Split the [min, max] time range into `partition` intervals of equal length.
  def generateArray(minmaxNum: Array[String], partition: Int, colum: String): Array[String] = {
    val array = ArrayBuffer[(String, String)]()
    // Pick the date pattern from the common formats the column may use.
    val pattern =
      if (minmaxNum(0).contains("-")) "yyyy-MM-dd HH:mm:ss"
      else if (minmaxNum(0).contains("/")) "yyyy/MM/dd HH:mm:ss"
      else "yyyyMMdd HH:mm:ss"
    val dateFormat = new SimpleDateFormat(pattern)
    var minTime = dateFormat.parse(minmaxNum(0)).getTime()
    val maxTime = dateFormat.parse(minmaxNum(1)).getTime()
    val subNum = (maxTime - minTime) / partition.toLong
    var midNum = minTime
    for (i <- 0 until partition) {
      minTime = midNum
      midNum = midNum + subNum
      // The last interval ends at the true max so rounding of subNum never drops rows.
      if (i == partition - 1) {
        array.append(dateFormat.format(minTime) -> dateFormat.format(maxTime))
      } else {
        array.append(dateFormat.format(minTime) -> dateFormat.format(midNum))
      }
    }
    // Turn the intervals into predicates that are closed on the left and open on the
    // right, so boundary timestamps are not read twice by adjacent partitions.
    val resultArray = array.toArray.map {
      case (start, end) => s"'${start}' <= ${colum} and ${colum} < '${end}'"
    }
    // Make the last interval closed on the right so the max timestamp is included.
    resultArray.update(resultArray.size - 1, resultArray.last.replaceAll(s"${colum} <", s"${colum} <="))
    println(resultArray.mkString("\n"))
    resultArray
  }
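
  // Worked example (hypothetical bounds, for illustration only): with
  // min = 2022-01-01 00:00:00, max = 2022-01-04 00:00:00 and 3 partitions,
  // generateArray returns one predicate per partition:
  //   '2022-01-01 00:00:00' <= load_date and load_date < '2022-01-02 00:00:00'
  //   '2022-01-02 00:00:00' <= load_date and load_date < '2022-01-03 00:00:00'
  //   '2022-01-03 00:00:00' <= load_date and load_date <= '2022-01-04 00:00:00'
  // Each predicate becomes the WHERE clause of one JDBC query, i.e. one Spark partition.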
}
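
To run the example, the MySQL JDBC driver referenced by "com.mysql.jdbc.Driver" must be on the classpath along with Spark SQL. A minimal build.sbt sketch; the versions are assumptions, pick ones matching your cluster (Connector/J 8.x would instead use the com.mysql.cj.jdbc.Driver class):

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql" % "2.4.8" % "provided",
  "mysql" % "mysql-connector-java" % "5.1.49"
)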