1. 适用场景
Bucketizer将连续特征的列映射到特征桶的列。
该源码适用于以下场景:
- 将连续特征划分为离散的桶,以进行进一步的处理或建模。
- 处理包含连续特征的数据集,并将其转换为具有桶特征的新数据集。
- 可以同时映射多个列,适用于批量处理多个连续特征的情况。
2. 多种主要用法及其代码示例
使用单个列进行映射
import org.apache.spark.ml.feature.Bucketizer
// 准备输入数据集
val data = Array(-0.5, 0.1, 1.5, 2.0)
val dataFrame = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")
// 设置桶划分点
val splits = Array(-Double.PositiveInfinity, 0.0, 1.0, Double.PositiveInfinity)
// 创建Bucketizer对象并设置参数
val bucketizer = new Bucketizer()
.setInputCol("features")
.setOutputCol("bucketedFeatures")
.setSplits(splits)
// 应用Bucketizer进行转换
val bucketedData = bucketizer.transform(dataFrame)
bucketedData.show()
输出结果:
+--------+----------------+
|features|bucketedFeatures|
+--------+----------------+
| -0.5| 0.0|
| 0.1| 1.0|
| 1.5| 2.0|
| 2.0| 2.0|
+--------+----------------+
使用多个列进行映射
import org.apache.spark.ml.feature.Bucketizer
// 准备输入数据集
val data = Array((-0.5, 10.0), (0.1, 20.0), (1.5, 30.0), (2.0, 40.0))
val dataFrame = spark.createDataFrame(data).toDF("features1", "features2")
// 设置桶划分点
val splitsArray = Array(
Array(-Double.PositiveInfinity, 0.0, Double.PositiveInfinity),
Array(0.0, 1.0, 2.0, Double.PositiveInfinity)
)
// 创建Bucketizer对象并设置参数
val bucketizer = new Bucketizer()
.setInputCols(Array("features1", "features2"))
.setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
.setSplitsArray(splitsArray)
// 应用Bucketizer进行转换
val bucketedData = bucketizer.transform(dataFrame)
bucketedData.show()
输出结果:
+---------+---------+-------------------+-------------------+
|features1|features2|bucketedFeatures1 |bucketedFeatures2 |
+---------+---------+-------------------+-------------------+
|-0.5 |10.0 |0.0 |1.0 |
|0.1 |20.0 |1.0 |2.0 |
|1.5 |30.0 |2.0 |3.0 |
|2.0 |40.0 |2.0 |4.0 |
+---------+---------+-------------------+-------------------+
3. 中文源码
/**
* `Bucketizer`将连续特征的列映射到特征桶的列。
*
* 自2.3.0版本起,
* `Bucketizer`可以通过设置`inputCols`参数一次性映射多个列。注意,当`inputCol`和`inputCols`参数都设置时,将抛出异常。
* `splits`参数仅用于单列使用,`splitsArray`用于多列。
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
with HasInputCols with HasOutputCols with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("bucketizer"))
/**
* 将连续特征映射为桶的参数。对于n+1个分割点,有n个桶。
* 由分割点x、y定义的桶包含范围[x,y),最后一个桶也包括y。分割点应具有大于或等于3个且严格递增的长度。
* 需要明确提供-inf、inf处的值以覆盖所有Double值;否则,超出指定分割点范围的值将被视为错误。
*
* 另请参见[[handleInvalid]],它可以选择为NaN值创建额外的桶。
*
* @group param
*/
@Since("1.4.0")
val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
"用于将连续特征映射到桶中的分割点。对于n+1个分割点,有n个桶。由分割点x、y定义的桶包含范围[x,y),最后一个桶也包括y。" +
"分割点应具有长度>=3且严格递增。需要明确提供-inf、inf处的值以覆盖所有Double值;否则,超出指定分割点范围的值将被视为错误。",
Bucketizer.checkSplits)
/** @group getParam */
@Since("1.4.0")
def getSplits: Array[Double] = $(splits)
/** @group setParam */
@Since("1.4.0")
def setSplits(value: Array[Double]): this.type = set(splits, value)
/** @group setParam */
@Since("1.4.0")
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
/**
* 参数用于处理无效条目的方式。选项有'skip'(过滤掉具有无效值的行),
* 'error'(抛出错误)或'keep'(将无效值保留在特殊的附加桶中)。
* 注意,在多列情况下,无效处理适用于所有列。对于'error',如果任何列中存在无效值,它将抛出错误;
* 对于'skip',如果任何列中存在无效值,它将跳过具有任何无效值的行等等。
* 默认值:"error"
* @group param
*/
@Since("2.1.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"如何处理无效条目。选项为skip(过滤掉具有无效值的行),error(抛出错误)或keep(将无效值保留在特殊的附加桶中)。",
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
/** @group setParam */
@Since("2.1.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
/**
* 用于指定多个分割参数的参数。此数组中的每个元素都可以用于将连续特征映射到桶中。
*
* @group param
*/
@Since("2.3.0")
val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
"映射连续特征到多列桶的分割点数组。对于每个输入列,n+1个分割点将产生n个桶。由分割点x、y定义的桶包含范围[x,y)," +
"最后一个桶也包括y。分割点应具有长度>=3且严格递增。需要明确提供-inf、inf处的值以覆盖所有Double值;" +
"否则,超出指定分割点范围的值将被视为错误。",
Bucketizer.checkSplitsArray)
/** @group getParam */
@Since("2.3.0")
def getSplitsArray: Array[Array[Double]] = $(splitsArray)
/** @group setParam */
@Since("2.3.0")
def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
/** @group setParam */
@Since("2.3.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
/** @group setParam */
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema)
val (inputColumns, outputColumns) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
}
val (filteredDataset, keepInvalid) = {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// 如果设置了“skip” NaN选项,则过滤掉数据集中的NaN值
(dataset.na.drop(inputColumns).toDF(), false)
} else {
(dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
}
}
val seqOfSplits = if (isSet(inputCols)) {
$(splitsArray).toSeq
} else {
Seq($(splits))
}
val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
}.withName(s"bucketizer_$idx")
}
val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
}
val metadata = outputColumns.map { col =>
transformedSchema(col).metadata
}
filteredDataset.withColumns(outputColumns, newCols, metadata)
}
private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
Seq(outputCols, splitsArray))
if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length &&
getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
}
transformedSchema
} else {
SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
}
}
@Since("1.4.1")
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
}
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
/**
* 要求分割点的长度大于等于3,并且严格递增。
* 不应接受NaN分割点。
*/
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
var i = 0
val n = splits.length - 1
while (i < n) {
if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
i += 1
}
!splits(n).isNaN
}
}
/**
* 检查分割点数组中的每个分割点。
*/
private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
splitsArray.forall(checkSplits(_))
}
/**
* 在多个桶中进行二分搜索以将每个数据点放置到相应的桶中。
* @param splits 分割点数组
* @param feature 数据点
* @param keepInvalid NaN标志。
* 设置为"true"以为NaN值创建一个额外的桶;
* 设置为"false"以报告NaN值的错误
* @return 每个数据点的桶
* @throws SparkException 如果特征值<分割点.head或>分割点.last
*/
private[feature] def binarySearchForBuckets(
splits: Array[Double],
feature: Double,
keepInvalid: Boolean): Double = {
if (feature.isNaN) {
if (keepInvalid) {
splits.length - 1
} else {
throw new SparkException("Bucketizer遇到NaN值。要处理或跳过NaN值,请尝试设置Bucketizer.handleInvalid。")
}
} else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
if (idx >= 0) {
idx
} else {
val insertPos = -idx - 1
if (insertPos == 0 || insertPos == splits.length) {
throw new SparkException(s"特征值 $feature 超出Bucketizer范围[${splits.head}, ${splits.last}]。" +
s"请检查您的特征值或放宽下限/上限约束。")
} else {
insertPos - 1
}
}
}
}
@Since("1.6.0")
override def load(path: String): Bucketizer = super.load(path)
}
```

1600

被折叠的 条评论
为什么被折叠?



