Spark ML Regularization, Standardization, Normalization ---- Standardization in Spark

This article covers standardization in Spark: how the StandardScaler class standardizes features by removing the mean and scaling to unit variance, using column summary statistics computed over the training samples. The source code and key parameter settings are walked through in detail as an aid to understanding Spark's feature-preprocessing machinery.

Standardization in Spark

Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.

The “unit std” is computed using the corrected sample standard deviation, which is computed as the square root of the unbiased sample variance.
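
Concretely, for a feature column with $n$ training samples $x_1, \dots, x_n$ and column mean $\bar{x}$, the scaling factor is the corrected (unbiased) sample standard deviation

$$ s = \sqrt{\frac{1}{n-1} \sum_{i=1}^{n} \left(x_i - \bar{x}\right)^2 }, $$

and, when both withMean and withStd are enabled, each value is mapped to $(x_i - \bar{x}) / s$.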

Code:

  • https://github.com/apache/spark/blob/v3.1.2/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
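
Before going through the full source below, here is a minimal usage sketch. It assumes a running local SparkSession and uses the sample libsvm file shipped with the Spark distribution; the column names features and scaledFeatures are only illustrative.

import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("StandardScalerExample")
  .master("local[*]")
  .getOrCreate()

// Sample dataset that ships with the Spark source/binary distribution (path is an assumption).
val dataFrame = spark.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)   // scale each column to unit standard deviation (the default)
  .setWithMean(false) // centering densifies sparse vectors, so it is off by default

// fit() computes the per-column mean/std summary statistics on the training data.
val scalerModel = scaler.fit(dataFrame)

// transform() applies the (x - mean) * (1 / std) mapping according to the flags.
val scaledData = scalerModel.transform(dataFrame)
scaledData.select("features", "scaledFeatures").show(5, truncate = false)

With withMean left at false, sparse input stays sparse; enabling it produces dense output, as noted in the withMean parameter doc in the source below.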

Source code


package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * Params for [[StandardScaler]] and [[StandardScalerModel]].
 */
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {

  /**
   * Whether to center the data with mean before scaling.
   * It will build a dense output, so take care when applying to sparse input.
   * Default: false
   * @group param
   */
  val withMean: BooleanParam = new BooleanParam(this, "withMean",
    "Whether to center data with mean")

  /** @group getParam */
  def getWithMean: Boolean = $(withMean)

  /**
   * Whether to scale the data to unit standard deviation.
   * Default: true
   * @group param
   */
  val withStd: BooleanParam = new BooleanParam(this, "withStd",
    "Whether to scale the data to unit standard deviation")

  /** @group getParam */
  def getWithStd: Boolean = $(withStd)

  /** Validates and transforms the input schema. */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
    StructType(outputFields)
  }

  setDefault(withMean -> false, withStd -> true)
}

/**
 * Standardizes features by removing the mean and scaling to unit variance using column summary
 * statistics on the samples in the training set.
 *
 * The "unit std" is computed using the
 * <a href="https://en.wikipedia.org/wiki/Standard_deviation#Corrected_sample_standard_deviation">
 * corrected sample standard deviation</a>,
 * which is computed as the square root of the unbiased sample variance.
 */
@Since("1.2.0")
class StandardScaler @Since("1.4.0") (
    @Since("1.4.0") override val uid: String)
  extends Estimator[StandardScalerModel] with StandardScalerParams with DefaultParamsWritable {

  @Since("1.2.0")
  def this() = this(Identifiable.randomUID("stdScal"))

  /** @group setParam */
  @Since("1.2.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.2.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setWithMean(value: Boolean): this.type = set(withMean, value)

  /** @group setParam */
  @Since("1.4.0")
  def setWithStd(value: Boolean): this.type = set(withStd, value)

  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): StandardScalerModel = {
    transformSchema(dataset.schema, logging = true)

    val Row(mean: Vector, std: Vector) = dataset
      .select(Summarizer.metrics("mean", "std").summary(col($(inputCol))).as("summary"))
      .select("summary.mean", "summary.std")
      .first()

    copyValues(new StandardScalerModel(uid, std.compressed, mean.compressed).setParent(this))
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}

@Since("1.6.0")
object StandardScaler extends DefaultParamsReadable[StandardScaler] {

  @Since("1.6.0")
  override def load(path: String): StandardScaler = super.load(path)
}

/**
 * Model fitted by [[StandardScaler]].
 *
 * @param std Standard deviation of the StandardScalerModel
 * @param mean Mean of the StandardScalerModel
 */
@Since("1.2.0")
class StandardScalerModel private[ml] (
    @Since("1.4.0") override val uid: String,
    @Since("2.0.0") val std: Vector,
    @Since("2.0.0") val mean: Vector)
  extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {

  import StandardScalerModel._

  /** @group setParam */
  @Since("1.2.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.2.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val shift = if ($(withMean)) mean.toArray else Array.emptyDoubleArray
    val scale = if ($(withStd)) {
      std.toArray.map { v => if (v == 0) 0.0 else 1.0 / v }
    } else Array.emptyDoubleArray

    val func = getTransformFunc(shift, scale, $(withMean), $(withStd))
    val transformer = udf(func)

    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
      outputSchema($(outputCol)).metadata)
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    var outputSchema = validateAndTransformSchema(schema)
    if ($(outputCol).nonEmpty) {
      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
        $(outputCol), mean.size)
    }
    outputSchema
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): StandardScalerModel = {
    val copied = new StandardScalerModel(uid, std, mean)
    copyValues(copied, extra).setParent(parent)
  }

  @Since("1.6.0")
  override def write: MLWriter = new StandardScalerModelWriter(this)

  @Since("3.0.0")
  override def toString: String = {
    s"StandardScalerModel: uid=$uid, numFeatures=${mean.size}, withMean=${$(withMean)}, " +
      s"withStd=${$(withStd)}"
  }
}

@Since("1.6.0")
object StandardScalerModel extends MLReadable[StandardScalerModel] {

  private[StandardScalerModel]
  class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {

    private case class Data(std: Vector, mean: Vector)

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.std, instance.mean)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class StandardScalerModelReader extends MLReader[StandardScalerModel] {

    private val className = classOf[StandardScalerModel].getName

    override def load(path: String): StandardScalerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.parquet(dataPath)
      val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean")
        .select("std", "mean")
        .head()
      val model = new StandardScalerModel(metadata.uid, std, mean)
      metadata.getAndSetParams(model)
      model
    }
  }

  @Since("1.6.0")
  override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader

  @Since("1.6.0")
  override def load(path: String): StandardScalerModel = super.load(path)

  private[spark] def transformWithBoth(
      shift: Array[Double],
      scale: Array[Double],
      values: Array[Double]): Array[Double] = {
    var i = 0
    while (i < values.length) {
      values(i) = (values(i) - shift(i)) * scale(i)
      i += 1
    }
    values
  }

  private[spark] def transformWithShift(
      shift: Array[Double],
      values: Array[Double]): Array[Double] = {
    var i = 0
    while (i < values.length) {
      values(i) -= shift(i)
      i += 1
    }
    values
  }

  private[spark] def transformDenseWithScale(
      scale: Array[Double],
      values: Array[Double]): Array[Double] = {
    var i = 0
    while (i < values.length) {
      values(i) *= scale(i)
      i += 1
    }
    values
  }

  private[spark] def transformSparseWithScale(
      scale: Array[Double],
      indices: Array[Int],
      values: Array[Double]): Array[Double] = {
    var i = 0
    while (i < values.length) {
      values(i) *= scale(indices(i))
      i += 1
    }
    values
  }

  private[spark] def getTransformFunc(
      shift: Array[Double],
      scale: Array[Double],
      withShift: Boolean,
      withScale: Boolean): Vector => Vector = {
    (withShift, withScale) match {
      case (true, true) =>
        vector: Vector =>
          val values = vector match {
            case d: DenseVector => d.values.clone()
            case v: Vector => v.toArray
          }
          val newValues = transformWithBoth(shift, scale, values)
          Vectors.dense(newValues)

      case (true, false) =>
        vector: Vector =>
          val values = vector match {
            case d: DenseVector => d.values.clone()
            case v: Vector => v.toArray
          }
          val newValues = transformWithShift(shift, values)
          Vectors.dense(newValues)

      case (false, true) =>
        vector: Vector =>
          vector match {
            case DenseVector(values) =>
              val newValues = transformDenseWithScale(scale, values.clone())
              Vectors.dense(newValues)
            case SparseVector(size, indices, values) =>
              val newValues = transformSparseWithScale(scale, indices, values.clone())
              Vectors.sparse(size, indices, newValues)
            case v =>
              throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.")
          }

      case (false, false) =>
        vector: Vector => vector
    }
  }
}
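
To make the arithmetic in transform() and transformWithBoth() concrete, here is a tiny standalone sketch (the numbers are invented for illustration): each value is shifted by the column mean and multiplied by 1 / std, and the scale factor is forced to 0.0 for a column whose std is 0, exactly as in the transform() method above.

// Mirrors the shift/scale computation in StandardScalerModel.transform and
// the element-wise loop in transformWithBoth (values are illustrative only).
val mean  = Array(1.0, 10.0, 0.0)
val std   = Array(2.0, 5.0, 0.0)
val scale = std.map(v => if (v == 0) 0.0 else 1.0 / v)

val row    = Array(3.0, 20.0, 7.0)
val scaled = row.indices.map(i => (row(i) - mean(i)) * scale(i))
// scaled == Vector(1.0, 2.0, 0.0): a constant column (std == 0) is mapped to 0.0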

References

Related articles in this series:

Spark documentation on feature processing:

  • http://spark.apache.org/docs/latest/api/scala/org/apache/spark/ml/feature/index.html

Concept overview:

  • https://blog.youkuaiyun.com/u014381464/article/details/81101551

Other references:

  • https://segmentfault.com/a/1190000014042959
  • https://www.cnblogs.com/nucdy/p/7994542.html
  • https://blog.youkuaiyun.com/weixin_34117522/article/details/88875270
  • https://blog.youkuaiyun.com/xuejianbest/article/details/85779029