Source Code Analysis of Binarizer in Spark MLlib
1. Applicable Scenarios
Binarizer is a class in Spark MLlib that binarizes continuous features: given a threshold, it maps each continuous feature value to either 0.0 or 1.0.
2. Main Usage Patterns and Code Examples
Binarization with an explicit threshold
```scala
import org.apache.spark.ml.feature.Binarizer

// `data` is assumed to be a DataFrame with a numeric column "continuousFeature"
val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")
  .setThreshold(0.5)

val binaryData = binarizer.transform(data)
```
The code above creates a Binarizer, sets its input column, output column, and threshold, then calls transform to binarize the dataset.
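For a fully self-contained run, the following sketch (the column names and sample values are made up for illustration) shows the expected output; note that a value exactly equal to the threshold maps to 0.0:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.Binarizer

val spark = SparkSession.builder().master("local[*]").appName("BinarizerDemo").getOrCreate()
import spark.implicits._

// Hypothetical sample data: one continuous feature per row
val data = Seq((0, 0.1), (1, 0.8), (2, 0.5)).toDF("id", "continuousFeature")

val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")
  .setThreshold(0.5)

binarizer.transform(data).show()
// +---+-----------------+-------------+
// | id|continuousFeature|binaryFeature|
// +---+-----------------+-------------+
// |  0|              0.1|          0.0|
// |  1|              0.8|          1.0|
// |  2|              0.5|          0.0|   <- equal to the threshold, so 0.0
// +---+-----------------+-------------+
```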
Binarization with the default threshold
```scala
import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")

val binaryData = binarizer.transform(data)
```
The code above uses the default threshold (0.0): continuous feature values greater than the threshold become 1.0, and values less than or equal to the threshold become 0.0.
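As the source in section 3 shows, Binarizer also accepts a Vector input column, binarizing every active entry and returning the result in compressed (sparse or dense, whichever is smaller) form. A sketch, assuming the SparkSession `spark` from the previous example is in scope:

```scala
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.linalg.Vectors

// assumes `spark` (a SparkSession) and spark.implicits._ are already in scope
import spark.implicits._

// Hypothetical vector-valued input
val vecData = Seq(
  (0, Vectors.dense(0.1, 0.8, 0.2)),
  (1, Vectors.dense(0.6, 0.3, 0.9))
).toDF("id", "features")

val vecBinarizer = new Binarizer()
  .setInputCol("features")
  .setOutputCol("binaryFeatures")
  .setThreshold(0.5)

vecBinarizer.transform(vecData).show(false)
// Row 0 -> [0.0,1.0,0.0]  only index 1 (0.8) exceeds the threshold
// Row 1 -> [1.0,0.0,1.0]  (for vectors this small, `compressed` keeps the dense form)
```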
3. Annotated Source Code (comments translated to English)
```scala
import scala.collection.mutable

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Binarize a column of continuous features given a threshold.
 */
@Since("1.4.0")
final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("binarizer"))
/**
* 用于二值化连续特征的阈值参数。
* 大于阈值的特征将被二值化为 1.0。
* 小于等于阈值的特征将被二值化为 0.0。
* 默认值为 0.0
* @group param
*/
@Since("1.4.0")
val threshold: DoubleParam =
new DoubleParam(this, "threshold", "用于二值化连续特征的阈值参数")
/** @group getParam */
@Since("1.4.0")
def getThreshold: Double = $(threshold)
/** @group setParam */
@Since("1.4.0")
def setThreshold(value: Double): this.type = set(threshold, value)
setDefault(threshold -> 0.0)
  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val schema = dataset.schema
val inputType = schema($(inputCol)).dataType
val td = $(threshold)
val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
val binarizerVector = udf { data: Vector =>
val indices = mutable.ArrayBuilder.make[Int]
val values = mutable.ArrayBuilder.make[Double]
data.foreachActive { (index, value) =>
if (value > td) {
indices += index
values += 1.0
}
}
Vectors.sparse(data.size, indices.result(), values.result()).compressed
}
val metadata = outputSchema($(outputCol)).metadata
inputType match {
case _:NumericType =>
dataset.select(col("*"), binarizerDouble(col($(inputCol)).cast("double")).as($(outputCol), metadata))
case _: VectorUDT =>
dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
}
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
val outputColName = $(outputCol)
val outCol: StructField = inputType match {
case _: NumericType =>
BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
case _: VectorUDT =>
StructField(outputColName, new VectorUDT)
case _ =>
throw new IllegalArgumentException(s"不支持的数据类型 $inputType。")
}
if (schema.fieldNames.contains(outputColName)) {
throw new IllegalArgumentException(s"输出列 $outputColName 已经存在。")
}
StructType(schema.fields :+ outCol)
}
@Since("1.4.1")
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}
@Since("1.6.0")
object Binarizer extends DefaultParamsReadable[Binarizer] {
@Since("1.6.0")
override def load(path: String): Binarizer = super.load(path)
}
```
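Because Binarizer mixes in DefaultParamsWritable and its companion object extends DefaultParamsReadable, the configured transformer can be saved and reloaded. A minimal sketch (the path is illustrative):

```scala
import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("continuousFeature")
  .setOutputCol("binaryFeature")
  .setThreshold(0.5)

// Persist only the params (a Binarizer carries no model state), then reload
binarizer.write.overwrite().save("/tmp/binarizer")  // illustrative path
val restored = Binarizer.load("/tmp/binarizer")
assert(restored.getThreshold == 0.5)
```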