Introduction
The Discrete Cosine Transform (DCT) converts an N-point real sequence in the time domain into an N-point real sequence in the frequency domain, much like the Discrete Fourier Transform but with real-valued output. Spark ML's DCT class implements DCT-II, with the leading coefficient multiplied by 1/√2 so that the transform matrix is unitary; the output vector has the same length as the input, and input and output elements correspond one to one.
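Concretely, the scaled DCT-II computes X_k = √(2/N) · c_k · Σ_{n=0}^{N−1} x_n · cos(π(n+0.5)k/N), with c_0 = 1/√2 and c_k = 1 for k > 0; the 1/√2 on the first coefficient is what makes the transform matrix unitary. A minimal pure-Scala sketch of this formula, independent of Spark (`dct2` is an illustrative helper name, not a library API), applied to the first feature row of the sample data used below:

```scala
// Naive scaled (unitary) DCT-II: the same normalization Spark ML's DCT uses.
// dct2 is an illustrative name, not a Spark API.
def dct2(xs: Array[Double]): Array[Double] = {
  val n = xs.length
  Array.tabulate(n) { k =>
    val ck = if (k == 0) 1.0 / math.sqrt(2.0) else 1.0
    val sum = xs.indices.map(i => xs(i) * math.cos(math.Pi * (i + 0.5) * k / n)).sum
    math.sqrt(2.0 / n) * ck * sum
  }
}

// First feature row from the sample data below
val row = Array(8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834)
println(dct2(row).mkString(", ")) // first coefficient ≈ 8.7029, matching fea_pca_vector below
```

This is an O(N²) reference implementation for illustration only; the library computes the same transform more efficiently, but the outputs agree up to floating-point error.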
Hands-on
1. Maven dependencies for the Spark project (pom.xml)
<properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <scala.version>2.11</scala.version>
    <spark.version>2.3.0</spark.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
</dependencies>
2. Prepare test data
Spark ML works on DataFrames, so in practice you could read one directly from Hive. For demonstration purposes we build a DataFrame by hand:
val spark: SparkSession = SparkSession.builder().appName("SparkSql").master("local[2]").getOrCreate()
// Prepare sample data and convert it to a DataFrame
import spark.implicits._
val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
(0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
(0, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
(0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
(0, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
(1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
(1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
(0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
(0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
(0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
(1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
(0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
(0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
(0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
(1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
(0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
(0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
(0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
(0, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
(0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
(0, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))
val inputDF: DataFrame = dataList.toDF("target", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
inputDF.show()
3. Assemble the feature columns into a vector
The DCT transformer operates on a vector column, so the columns to transform must first be assembled into one (the input columns must be numeric).
val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("fea_vector")
val vectorDf: DataFrame = assembler.transform(inputDF)
4. Apply the DCT
Invoke the DCT transformer from the ml.feature package:
val dct: DCT = new DCT().setInputCol("fea_vector").setOutputCol("fea_pca_vector").setInverse(false)
val dctDF: DataFrame = dct.transform(vectorDf).drop("fea_vector")
Output:
+------+--------+--------+--------+--------+--------+--------+--------------------+
|target|feature1|feature2|feature3|feature4|feature5|feature6| fea_pca_vector|
+------+--------+--------+--------+--------+--------+--------+--------------------+
| 0| 8.9255| -6.7863| 11.9081| 5.093| 11.4607| -9.2834|[8.70287375679244...|
| 0| 11.5006| -4.1473| 13.8588| 5.389| 12.3622| 7.0433|[18.7821158000547...|
| 0| 8.6093| -2.7457| 12.0805| 7.8928| 10.5825| -9.0837|[11.1597527936330...|
| 0| 11.0604| -2.1518| 8.9522| 7.1957| 12.5846| -1.8361|[14.6173300400586...|
| 1| 9.8369| -1.4834| 12.8746| 6.6375| 12.2772| 2.4486|[17.3878662384625...|
| 1| 11.4763| -2.3182| 12.608| 8.6264| 10.9621| 3.5609|[18.3366760903296...|
| 0| 11.8091| -0.0832| 9.3494| 4.2916| 11.1355| -8.0198|[11.6279727579660...|
| 0| 13.558| -7.9881| 13.8776| 7.5985| 8.6543| 0.831|[14.9138407734225...|
| 0| 16.1071| 2.4426| 13.9307| 5.6327| 8.8014| 6.163|[21.6687986370956...|
| 1| 12.5088| 1.9743| 8.896| 5.4508| 13.6043|-16.2859|[10.6749987735362...|
| 0| 5.0702| -0.5447| 9.59| 4.2987| 12.391|-18.8687|[4.87305571912190...|
| 0| 12.7188| -7.975| 10.3757| 9.0101| 12.857|-12.0852|[10.1659539801568...|
| 0| 8.7671| -4.6154| 9.7242| 7.4242| 9.0254| 1.4247|[12.9619648718857...|
| 1| 16.3699| 1.5934| 16.7395| 7.333| 12.145| 5.9004|[24.5280471890174...|
| 0| 13.808| 5.0514| 17.2611| 8.512| 12.8517| -9.1622|[19.7273738917947...|
| 0| 3.9416| 2.6562| 13.3633| 6.8895| 12.2806| -16.162|[9.37713663332256...|
| 0| 5.0615| 0.2689| 15.1325| 3.6587| 13.5276| -6.5477|[12.6971342058618...|
| 0| 8.4199| -1.8128| 8.1202| 5.3955| 9.7184| -17.839|[4.89987763180537...|
| 0| 4.875| 1.2646| 11.919| 8.465| 10.7203| -0.6707|[14.9309463767929...|
| 0| 4.409| -0.7863| 15.1828| 8.0631| 11.2831| -0.7356|[15.2750588608249...|
+------+--------+--------+--------+--------+--------+--------+--------------------+
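Because this scaling makes the transform orthogonal, setting `setInverse(true)` on the `DCT` transformer applies the inverse transform (DCT-III) and recovers the original features from the coefficients. A minimal pure-Scala sketch of that inverse under the same unitary scaling (`idct2` is an illustrative helper name, not a Spark API), fed with the first row of coefficients from the output above:

```scala
// Inverse of the unitary DCT-II (i.e. DCT-III): what DCT().setInverse(true) computes.
// idct2 is an illustrative name, not a Spark API.
def idct2(cs: Array[Double]): Array[Double] = {
  val n = cs.length
  Array.tabulate(n) { i =>
    val sum = cs(0) / math.sqrt(2.0) +
      (1 until n).map(k => cs(k) * math.cos(math.Pi * k * (i + 0.5) / n)).sum
    math.sqrt(2.0 / n) * sum
  }
}

// First row of fea_pca_vector from the output above
val coeffs = Array(8.702873756792446, 3.7237631760555563, -8.6795,
  12.100805927981273, 2.105711901788404, 13.970876916356612)
println(idct2(coeffs).mkString(", ")) // ≈ 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834
```

Since the forward matrix is orthogonal, its inverse is just its transpose, which is why a forward pass followed by `setInverse(true)` is a lossless roundtrip (up to floating-point error).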
5. Expand the vector column
The transformed values sit in a single vector column; to make them easier to store and inspect, expand the vector into separate columns.
// Expand the vector column into multiple columns
// Turn the names of the columns to keep into Array[Column]
val keepCols: Array[Column] = inputDF.schema.fieldNames.map(colName => $"$colName")
// UDF that converts a Vector into an Array
val vecToArray = udf((xs: DenseVector) => xs.toArray)
// Apply the UDF to produce an array column
val arrayCols: Array[Column] = Array(vecToArray($"fea_pca_vector").alias("fea_pca_array"))
val arrayDf: DataFrame = dctDF.select((keepCols ++ arrayCols): _*)
// Split the array into separate columns
val strings: Array[String] = Array.tabulate(6)(i => "dct" + (i + 1))
val extendExprs: Array[Column] = strings.zipWithIndex.map { case (newCols, index) =>
  $"fea_pca_array".getItem(index).alias(newCols)
}
val pcaTransDf: DataFrame = arrayDf.select((keepCols ++ extendExprs): _*)
pcaTransDf.show(100)
Output:
+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
|target|feature1|feature2|feature3|feature4|feature5|feature6| dct1| dct2| dct3| dct4| dct5| dct6|
+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
| 0| 8.9255| -6.7863| 11.9081| 5.093| 11.4607| -9.2834| 8.702873756792446| 3.7237631760555563| -8.679500000000003|12.100805927981273| 2.105711901788404|13.970876916356612|
| 0| 11.5006| -4.1473| 13.8588| 5.389| 12.3622| 7.0433|18.782115800054765| -2.9886032486179634|-0.3519500000000004| 5.101878885926897| 6.166649357700933|12.129443319694582|
| 0| 8.6093| -2.7457| 12.0805| 7.8928| 10.5825| -9.0837|11.159752793633022| 5.051538627194592|-10.223850000000004| 10.95473050216207| 1.104268992365538|10.420452719913904|
| 0| 11.0604| -2.1518| 8.9522| 7.1957| 12.5846| -1.8361|14.617330040058617| 1.4384507026552864|-3.4618000000000015|10.563996063359104| 1.3009433615649835| 8.922784962071194|
| 1| 9.8369| -1.4834| 12.8746| 6.6375| 12.2772| 2.4486| 17.38786623846258| -0.5654473415764422|-3.6132999999999997| 6.087716857739034| 2.9473731242130397|10.200060077955627|
| 1| 11.4763| -2.3182| 12.608| 8.6264| 10.9621| 3.5609|18.336676090329636|-0.41245152144568986| -3.098600000000001|7.0276268968739855| 5.48015102012101| 8.824900897222047|
| 0| 11.8091| -0.0832| 9.3494| 4.2916| 11.1355| -8.0198|11.627972757966024| 7.233900231086073|-4.9258500000000005| 10.61029141949771|-1.3493541816365344|10.363654100269182|
| 0| 13.558| -7.9881| 13.8776| 7.5985| 8.6543| 0.831|14.913840773422521| 1.2416118237799985|-3.5435500000000015| 9.426575501297737| 9.968731820422297|12.197730370129644|
| 0| 16.1071| 2.4426| 13.9307| 5.6327| 8.8014| 6.163| 21.66879863709569| 4.189595832606229| 1.3533499999999996|3.2679867403341762| 5.584564816303953| 8.709516780476891|
| 1| 12.5088| 1.9743| 8.896| 5.4508| 13.6043|-16.2859| 10.67499877353623| 11.82504335885021| -9.061950000000003|15.096817657208424|-5.9430993334707125|10.972008555925967|
| 0| 5.0702| -0.5447| 9.59| 4.2987| 12.391|-18.8687| 4.873055719121901| 8.859884317393087|-13.843600000000006|12.893828232207328| -6.813425996760614| 11.80898831804594|
| 0| 12.7188| -7.975| 10.3757| 9.0101| 12.857|-12.0852| 10.16595398015684| 5.532065789130647| -9.376100000000001|18.073315118151402| 2.960478975323644| 12.97263583180883|
| 0| 8.7671| -4.6154| 9.7242| 7.4242| 9.0254| 1.4247|12.961964871885746| -1.130454477185389| -3.478300000000001| 7.627384460394446| 5.346321227722854| 7.948660904098555|
| 1| 16.3699| 1.5934| 16.7395| 7.333| 12.145| 5.9004|24.528047189017446| 2.9365384961549568|-0.9010999999999998| 4.741640594421585| 5.44614508926574|11.117915886708863|
| 0| 13.808| 5.0514| 17.2611| 8.512| 12.8517| -9.1622| 19.72737389179479| 10.932876803897779| -10.56365| 8.990198903620916|-1.5551795526026362|11.496055320592678|
| 0| 3.9416| 2.6562| 13.3633| 6.8895| 12.2806| -16.162| 9.377136633322563| 8.249556311095454| -16.2366| 9.493487396104763|-6.3050113497122275|10.543503455211898|
| 0| 5.0615| 0.2689| 15.1325| 3.6587| 13.5276| -6.5477|12.697134205861836| 2.7758697155557206|-10.138700000000002| 5.468118427302028|-2.9698897847114356|13.546276114573358|
| 0| 8.4199| -1.8128| 8.1202| 5.3955| 9.7184| -17.839| 4.899877631805378| 10.343555823319633|-11.467400000000001|14.315389604431543|-3.3817137317243957|10.150944279878784|
| 0| 4.875| 1.2646| 11.919| 8.465| 10.7203| -0.6707|14.930946376792956|-0.25143243623523215| -8.089850000000002| 4.714206309302412|0.17854557074689198| 6.615181334879805|
| 0| 4.409| -0.7863| 15.1828| 8.0631| 11.2831| -0.7356|15.275058860824947| -0.9943926679190486| -9.786250000000004| 4.120980718429373| 1.710602245068483| 9.666562364425225|
+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
The complete code:
import org.apache.spark.ml.feature.{DCT, VectorAssembler}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
/**
 * author: sunyiyuan
 * date: created 2019/4/17 16:44
 * description: ${description}
 * modified by:
 */
object DCTDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("SparkSql").master("local[2]").getOrCreate()
    // Prepare sample data and convert it to a DataFrame
    import spark.implicits._
    val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
      (0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
      (0, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
      (0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
      (0, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
      (1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
      (1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
      (0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
      (0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
      (0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
      (1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
      (0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
      (0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
      (0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
      (1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
      (0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
      (0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
      (0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
      (0, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
      (0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
      (0, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))
    val inputDF: DataFrame = dataList.toDF("target", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    inputDF.show()
    // Assemble the columns to transform into a vector column
    val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("fea_vector")
    val vectorDf: DataFrame = assembler.transform(inputDF)
    // Apply the DCT transformer from the ml.feature package
    val dct: DCT = new DCT().setInputCol("fea_vector").setOutputCol("fea_pca_vector").setInverse(false)
    val dctDF: DataFrame = dct.transform(vectorDf).drop("fea_vector")
    dctDF.show(100)
    // Expand the vector column into multiple columns
    // Turn the names of the columns to keep into Array[Column]
    val keepCols: Array[Column] = inputDF.schema.fieldNames.map(colName => $"$colName")
    // UDF that converts a Vector into an Array
    val vecToArray = udf((xs: DenseVector) => xs.toArray)
    // Apply the UDF to produce an array column
    val arrayCols: Array[Column] = Array(vecToArray($"fea_pca_vector").alias("fea_pca_array"))
    val arrayDf: DataFrame = dctDF.select((keepCols ++ arrayCols): _*)
    // Split the array into separate columns
    val strings: Array[String] = Array.tabulate(6)(i => "dct" + (i + 1))
    val extendExprs: Array[Column] = strings.zipWithIndex.map { case (newCols, index) =>
      $"fea_pca_array".getItem(index).alias(newCols)
    }
    val pcaTransDf: DataFrame = arrayDf.select((keepCols ++ extendExprs): _*)
    pcaTransDf.show(100)
  }
}