Spark ML Feature Engineering: Discrete Cosine Transform (DCT)

This article shows how to use the DCT transformer in the Apache Spark ML library to perform a discrete cosine transform, walking through a complete example: data preparation, vector assembly, the DCT transform itself, and displaying the results.


Introduction

The discrete cosine transform (Discrete Cosine Transform, DCT) converts an N-dimensional real sequence in the time domain into an N-dimensional real sequence in the frequency domain (somewhat like the discrete Fourier transform). The DCT class in Spark ML implements the scaled DCT-II: the result is scaled (the zeroth coefficient picks up an extra factor of 1/√2) so that the transform matrix is orthonormal, and the output vector has the same length as the input. No shift is applied, so input and output elements correspond one-to-one.
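Written out, the scaled (orthonormal) DCT-II computes the following; this standard form is stated here for reference and is consistent with the numbers Spark produces below:

    X_k = c_k \sqrt{\frac{2}{N}} \sum_{n=0}^{N-1} x_n \cos\!\left(\frac{\pi (2n+1) k}{2N}\right),
    \qquad c_0 = \frac{1}{\sqrt{2}}, \quad c_k = 1 \ (k \ge 1)

For k = 0 this reduces to X_0 = (x_0 + … + x_{N-1}) / √N. Plugging in the first data row below gives (8.9255 − 6.7863 + 11.9081 + 5.093 + 11.4607 − 9.2834) / √6 ≈ 8.7029, which matches the first value of fea_pca_vector in the results.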

Walkthrough

1. Maven dependencies (pom.xml) for the Spark project

  <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <scala.version>2.11</scala.version>
        <spark.version>2.3.0</spark.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
    </dependencies>
 

2. Preparing the test data

        Spark ML fundamentally processes data as DataFrames, so you can fetch a DataFrame straight out of Hive for the computation (a sketch of the Hive route follows the example code below). For ease of demonstration, we create the DataFrame by hand.

val spark: SparkSession = SparkSession.builder().appName("SparkSql").master("local[2]").getOrCreate()
    //Prepare sample data and convert it to a DataFrame
    import spark.implicits._
    val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
      (0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
      (0, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
      (0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
      (0, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
      (1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
      (1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
      (0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
      (0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
      (0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
      (1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
      (0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
      (0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
      (0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
      (1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
      (0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
      (0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
      (0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
      (0, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
      (0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
      (0, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))

    val inputDF: DataFrame = dataList.toDF("target", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    inputDF.show()
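For reference, here is a minimal sketch of the Hive route mentioned above. It assumes a Hive-enabled build (the spark-hive_2.11 dependency plus a hive-site.xml on the classpath), and my_db.my_table is a made-up table name:

    // Hedged sketch: reading the input DataFrame from Hive instead of building it by hand
    val hiveSpark: SparkSession = SparkSession.builder()
      .appName("SparkSql")
      .master("local[2]")
      .enableHiveSupport() // requires Hive support on the classpath
      .getOrCreate()
    // my_db.my_table is a hypothetical table with the same schema as inputDF
    val hiveDF: DataFrame = hiveSpark.table("my_db.my_table")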

3. Assembling the feature fields into a vector

     Like PCA, the DCT transformer operates on a feature-vector column, so the fields to transform must first be assembled into a single vector column (the fields being assembled must be numeric).

    val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("fea_vector")
    val vectorDf: DataFrame = assembler.transform(inputDF)
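If a feature column arrives as a string (common when reading from Hive or CSV), cast it to a numeric type before assembling. A minimal sketch, where feature_str is a hypothetical string column:

    // Hedged sketch: cast a (hypothetical) string column to Double so VectorAssembler accepts it
    import org.apache.spark.sql.functions.col
    val castedDF: DataFrame = inputDF.withColumn("feature_str", col("feature_str").cast("double"))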

4. The DCT transform

    Call the DCT class under org.apache.spark.ml.feature to perform the transform.

    val dct: DCT = new DCT().setInputCol("fea_vector").setOutputCol("fea_pca_vector").setInverse(false)
    val dctDF: DataFrame = dct.transform(vectorDf).drop("fea_vector")
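setInverse(false) requests the forward transform; setting it to true applies the inverse DCT instead, which recovers the original vector from the coefficients. A minimal round-trip sketch (idct and fea_restored are illustrative names):

    // Hedged sketch: invert the transform to recover the original feature vector
    val idct: DCT = new DCT().setInputCol("fea_pca_vector").setOutputCol("fea_restored").setInverse(true)
    val restoredDF: DataFrame = idct.transform(dctDF)
    restoredDF.select("fea_pca_vector", "fea_restored").show(5, truncate = false)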

Output

+------+--------+--------+--------+--------+--------+--------+--------------------+
|target|feature1|feature2|feature3|feature4|feature5|feature6|      fea_pca_vector|
+------+--------+--------+--------+--------+--------+--------+--------------------+
|     0|  8.9255| -6.7863| 11.9081|   5.093| 11.4607| -9.2834|[8.70287375679244...|
|     0| 11.5006| -4.1473| 13.8588|   5.389| 12.3622|  7.0433|[18.7821158000547...|
|     0|  8.6093| -2.7457| 12.0805|  7.8928| 10.5825| -9.0837|[11.1597527936330...|
|     0| 11.0604| -2.1518|  8.9522|  7.1957| 12.5846| -1.8361|[14.6173300400586...|
|     1|  9.8369| -1.4834| 12.8746|  6.6375| 12.2772|  2.4486|[17.3878662384625...|
|     1| 11.4763| -2.3182|  12.608|  8.6264| 10.9621|  3.5609|[18.3366760903296...|
|     0| 11.8091| -0.0832|  9.3494|  4.2916| 11.1355| -8.0198|[11.6279727579660...|
|     0|  13.558| -7.9881| 13.8776|  7.5985|  8.6543|   0.831|[14.9138407734225...|
|     0| 16.1071|  2.4426| 13.9307|  5.6327|  8.8014|   6.163|[21.6687986370956...|
|     1| 12.5088|  1.9743|   8.896|  5.4508| 13.6043|-16.2859|[10.6749987735362...|
|     0|  5.0702| -0.5447|    9.59|  4.2987|  12.391|-18.8687|[4.87305571912190...|
|     0| 12.7188|  -7.975| 10.3757|  9.0101|  12.857|-12.0852|[10.1659539801568...|
|     0|  8.7671| -4.6154|  9.7242|  7.4242|  9.0254|  1.4247|[12.9619648718857...|
|     1| 16.3699|  1.5934| 16.7395|   7.333|  12.145|  5.9004|[24.5280471890174...|
|     0|  13.808|  5.0514| 17.2611|   8.512| 12.8517| -9.1622|[19.7273738917947...|
|     0|  3.9416|  2.6562| 13.3633|  6.8895| 12.2806| -16.162|[9.37713663332256...|
|     0|  5.0615|  0.2689| 15.1325|  3.6587| 13.5276| -6.5477|[12.6971342058618...|
|     0|  8.4199| -1.8128|  8.1202|  5.3955|  9.7184| -17.839|[4.89987763180537...|
|     0|   4.875|  1.2646|  11.919|   8.465| 10.7203| -0.6707|[14.9309463767929...|
|     0|   4.409| -0.7863| 15.1828|  8.0631| 11.2831| -0.7356|[15.2750588608249...|
+------+--------+--------+--------+--------+--------+--------+--------------------+
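Note that show() truncates the vector column to 20 characters by default; to inspect the full coefficients, disable truncation:

    dctDF.show(100, truncate = false)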

5. Expanding the vector column

      The transformed values live in a single vector column; to make them easier to persist and inspect, we expand the vector into individual columns (see the persistence sketch after the output below).

    //Expand the vector into multiple columns
    //Turn the names of the columns we keep into Array[Column]
    val keepCols: Array[Column] = inputDF.schema.fieldNames.map(colName => $"$colName")
    //UDF converting a vector to an array (the DCT output is always a DenseVector)
    val vecToArray = udf((xs: DenseVector) => xs.toArray)
    //Apply the UDF
    val arrayCols: Array[Column] = Array(vecToArray($"fea_pca_vector").alias("fea_pca_array"))
    val arrayDf: DataFrame = dctDF.select((keepCols ++ arrayCols): _*)
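This example targets Spark 2.3, where a custom UDF is needed for the conversion. On Spark 3.0 and later the same step is built in as vector_to_array, so the UDF can be dropped:

    // Spark 3.0+ alternative: built-in vector-to-array conversion instead of a custom UDF
    import org.apache.spark.ml.functions.vector_to_array
    val arrayDf3: DataFrame = dctDF.select((keepCols :+ vector_to_array($"fea_pca_vector").alias("fea_pca_array")): _*)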

    //Split the array into one column per DCT coefficient
    val strings: Array[String] = Array.tabulate(6)(i => "dct" + (i + 1))
    val extendExprs: Array[Column] = strings.zipWithIndex.map { case (newCols, index) =>
      $"fea_pca_array".getItem(index).alias(newCols)
    }
    val pcaTransDf: DataFrame = arrayDf.select((keepCols ++ extendExprs): _*)
    pcaTransDf.show(100)

Output

+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
|target|feature1|feature2|feature3|feature4|feature5|feature6|              dct1|                dct2|               dct3|              dct4|               dct5|              dct6|
+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
|     0|  8.9255| -6.7863| 11.9081|   5.093| 11.4607| -9.2834| 8.702873756792446|  3.7237631760555563| -8.679500000000003|12.100805927981273|  2.105711901788404|13.970876916356612|
|     0| 11.5006| -4.1473| 13.8588|   5.389| 12.3622|  7.0433|18.782115800054765| -2.9886032486179634|-0.3519500000000004| 5.101878885926897|  6.166649357700933|12.129443319694582|
|     0|  8.6093| -2.7457| 12.0805|  7.8928| 10.5825| -9.0837|11.159752793633022|   5.051538627194592|-10.223850000000004| 10.95473050216207|  1.104268992365538|10.420452719913904|
|     0| 11.0604| -2.1518|  8.9522|  7.1957| 12.5846| -1.8361|14.617330040058617|  1.4384507026552864|-3.4618000000000015|10.563996063359104| 1.3009433615649835| 8.922784962071194|
|     1|  9.8369| -1.4834| 12.8746|  6.6375| 12.2772|  2.4486| 17.38786623846258| -0.5654473415764422|-3.6132999999999997| 6.087716857739034| 2.9473731242130397|10.200060077955627|
|     1| 11.4763| -2.3182|  12.608|  8.6264| 10.9621|  3.5609|18.336676090329636|-0.41245152144568986| -3.098600000000001|7.0276268968739855|   5.48015102012101| 8.824900897222047|
|     0| 11.8091| -0.0832|  9.3494|  4.2916| 11.1355| -8.0198|11.627972757966024|   7.233900231086073|-4.9258500000000005| 10.61029141949771|-1.3493541816365344|10.363654100269182|
|     0|  13.558| -7.9881| 13.8776|  7.5985|  8.6543|   0.831|14.913840773422521|  1.2416118237799985|-3.5435500000000015| 9.426575501297737|  9.968731820422297|12.197730370129644|
|     0| 16.1071|  2.4426| 13.9307|  5.6327|  8.8014|   6.163| 21.66879863709569|   4.189595832606229| 1.3533499999999996|3.2679867403341762|  5.584564816303953| 8.709516780476891|
|     1| 12.5088|  1.9743|   8.896|  5.4508| 13.6043|-16.2859| 10.67499877353623|   11.82504335885021| -9.061950000000003|15.096817657208424|-5.9430993334707125|10.972008555925967|
|     0|  5.0702| -0.5447|    9.59|  4.2987|  12.391|-18.8687| 4.873055719121901|   8.859884317393087|-13.843600000000006|12.893828232207328| -6.813425996760614| 11.80898831804594|
|     0| 12.7188|  -7.975| 10.3757|  9.0101|  12.857|-12.0852| 10.16595398015684|   5.532065789130647| -9.376100000000001|18.073315118151402|  2.960478975323644| 12.97263583180883|
|     0|  8.7671| -4.6154|  9.7242|  7.4242|  9.0254|  1.4247|12.961964871885746|  -1.130454477185389| -3.478300000000001| 7.627384460394446|  5.346321227722854| 7.948660904098555|
|     1| 16.3699|  1.5934| 16.7395|   7.333|  12.145|  5.9004|24.528047189017446|  2.9365384961549568|-0.9010999999999998| 4.741640594421585|   5.44614508926574|11.117915886708863|
|     0|  13.808|  5.0514| 17.2611|   8.512| 12.8517| -9.1622| 19.72737389179479|  10.932876803897779|          -10.56365| 8.990198903620916|-1.5551795526026362|11.496055320592678|
|     0|  3.9416|  2.6562| 13.3633|  6.8895| 12.2806| -16.162| 9.377136633322563|   8.249556311095454|           -16.2366| 9.493487396104763|-6.3050113497122275|10.543503455211898|
|     0|  5.0615|  0.2689| 15.1325|  3.6587| 13.5276| -6.5477|12.697134205861836|  2.7758697155557206|-10.138700000000002| 5.468118427302028|-2.9698897847114356|13.546276114573358|
|     0|  8.4199| -1.8128|  8.1202|  5.3955|  9.7184| -17.839| 4.899877631805378|  10.343555823319633|-11.467400000000001|14.315389604431543|-3.3817137317243957|10.150944279878784|
|     0|   4.875|  1.2646|  11.919|   8.465| 10.7203| -0.6707|14.930946376792956|-0.25143243623523215| -8.089850000000002| 4.714206309302412|0.17854557074689198| 6.615181334879805|
|     0|   4.409| -0.7863| 15.1828|  8.0631| 11.2831| -0.7356|15.275058860824947| -0.9943926679190486| -9.786250000000004| 4.120980718429373|  1.710602245068483| 9.666562364425225|
+------+--------+--------+--------+--------+--------+--------+------------------+--------------------+-------------------+------------------+-------------------+------------------+
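With the coefficients expanded into plain numeric columns, the result can be persisted like any other DataFrame. A minimal sketch, where the output path and table name are made-up placeholders:

    // Hedged sketch: persist the expanded result (path and table name are hypothetical)
    pcaTransDf.write.mode("overwrite").parquet("/tmp/dct_features")
    // or save back to Hive: pcaTransDf.write.mode("overwrite").saveAsTable("my_db.dct_features")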

The complete code:

import org.apache.spark.ml.feature.{DCT, VectorAssembler}
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

/**
  * author     :sunyiyuan
  * date       :Created in 2019/4/17 16:44
  * description:${description}
  * modified By:
  */
object DCTDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("SparkSql").master("local[2]").getOrCreate()

    //Prepare sample data and convert it to a DataFrame
    import spark.implicits._
    val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
      (0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
      (0, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
      (0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
      (0, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
      (1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
      (1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
      (0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
      (0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
      (0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
      (1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
      (0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
      (0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
      (0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
      (1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
      (0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
      (0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
      (0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
      (0, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
      (0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
      (0, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))

    val inputDF: DataFrame = dataList.toDF("target", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    inputDF.show()
    //Assemble the columns to transform into a vector column
    val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
    val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("fea_vector")
    val vectorDf: DataFrame = assembler.transform(inputDF)
    //Apply the DCT transformer from the ml package
    val dct: DCT = new DCT().setInputCol("fea_vector").setOutputCol("fea_pca_vector").setInverse(false)
    val dctDF: DataFrame = dct.transform(vectorDf).drop("fea_vector")
    dctDF.show(100)
    //Expand the vector into multiple columns
    //Turn the names of the columns we keep into Array[Column]
    val keepCols: Array[Column] = inputDF.schema.fieldNames.map(colName => $"$colName")
    //UDF converting a vector to an array (the DCT output is always a DenseVector)
    val vecToArray = udf((xs: DenseVector) => xs.toArray)
    //Apply the UDF
    val arrayCols: Array[Column] = Array(vecToArray($"fea_pca_vector").alias("fea_pca_array"))
    val arrayDf: DataFrame = dctDF.select((keepCols ++ arrayCols): _*)

    //Split the array into one column per DCT coefficient
    val strings: Array[String] = Array.tabulate(6)(i => "dct" + (i + 1))
    val extendExprs: Array[Column] = strings.zipWithIndex.map { case (newCols, index) =>
      $"fea_pca_array".getItem(index).alias(newCols)
    }
    val pcaTransDf: DataFrame = arrayDf.select((keepCols ++ extendExprs): _*)
    pcaTransDf.show(100)
  }
}