【Spark ML系列】Binarizer场景用法示例源码分析

本文详细解析了SparkMLlib库中的Binarizer类,介绍了其在二值化连续特征、设置阈值和默认阈值的应用,提供了相应的代码示例和源码解读。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark MLlib中的Binarizer源码分析

1. 源码适用场景

Binarizer 是 Spark MLlib 中的一个类,用于将连续特征二值化(binarize),根据给定的阈值将连续特征转换为二进制特征。

2. 多种主要用法及其代码示例

设置阈值进行二值化

import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")
  .setThreshold(0.5)

val binaryData = binarizer.transform(data)

上述代码通过创建一个 Binarizer 对象,设置输入列名、输出列名和阈值,然后使用 transform 方法将数据集进行二值化操作。

使用默认阈值二值化

import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")

val binaryData = binarizer.transform(data)

上述代码使用默认阈值(0.0)进行二值化操作,即大于阈值的连续特征值为 1.0,小于等于阈值的连续特征值为 0.0。

3. 中文源码

​```scala
/**
 * 根据阈值将连续特征的列二值化。
 */
@Since("1.4.0")
final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("binarizer"))

  /**
   * 用于二值化连续特征的阈值参数。
   * 大于阈值的特征将被二值化为 1.0。
   * 小于等于阈值的特征将被二值化为 0.0。
   * 默认值为 0.0
   * @group param
   */
  @Since("1.4.0")
  val threshold: DoubleParam =
    new DoubleParam(this, "threshold", "用于二值化连续特征的阈值参数")

  /** @group getParam */
  @Since("1.4.0")
  def getThreshold: Double = $(threshold)

  /** @group setParam */
  @Since("1.4.0")
  def setThreshold(value: Double): this.type = set(threshold, value)

  setDefault(threshold -> 0.0)

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val schema = dataset.schema
    val inputType = schema($(inputCol)).dataType
    val td = $(threshold)

    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
    val binarizerVector = udf { data: Vector =>
      val indices = mutable.ArrayBuilder.make[Int]
      val values = mutable.ArrayBuilder.make[Double]

      data.foreachActive { (index, value) =>
        if (value > td) {
          indices += index
          values +=  1.0
        }
      }

      Vectors.sparse(data.size, indices.result(), values.result()).compressed
    }

    val metadata = outputSchema($(outputCol)).metadata

    inputType match {
      case _:NumericType =>
        dataset.select(col("*"), binarizerDouble(col($(inputCol)).cast("double")).as($(outputCol), metadata))
      case _: VectorUDT =>
        dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
    }
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    val outputColName = $(outputCol)

    val outCol: StructField = inputType match {
      case _: NumericType =>
        BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
      case _: VectorUDT =>
        StructField(outputColName, new VectorUDT)
      case _ =>
        throw new IllegalArgumentException(s"不支持的数据类型 $inputType。")
    }

    if (schema.fieldNames.contains(outputColName)) {
      throw new IllegalArgumentException(s"输出列 $outputColName 已经存在。")
    }
    StructType(schema.fields :+ outCol)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}

@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {

  @Since("1.6.0")
  override def load(path: String): Binarizer = super.load(path)
}
​```

4. 官方链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值