【Spark ML系列】LinearSVC原理源码继承关系分析
文章目录
- 【Spark ML系列】LinearSVC原理源码继承关系分析
- 一、class LinearSVC
- 二、LinearSVCModel
- 三、LinearSVC原理
- 四、"LinearSVC" 和 "Linear SVM" 一样吗?
一、class LinearSVC
class LinearSVC @Since("2.2.0") (
@Since("2.2.0") override val uid: String)
extends Classifier[Vector, LinearSVC, LinearSVCModel]
with LinearSVCParams with DefaultParamsWritable {
1. extends Classifier(通用类)
abstract class Classifier[
FeaturesType,
E <: Classifier[FeaturesType, E, M],
M <: ClassificationModel[FeaturesType, M]]
extends Predictor[FeaturesType, E, M]
with ClassifierParams {
Classifier 是一个抽象类,继承自 Predictor 和 ClassifierParams。
它定义了分类器的基本行为,并提供了一些公共方法,如获取类别数量和转换数据集。
实现类

/**
* 单标签二分类或多分类。
* 类别被索引为 {0, 1, ..., numClasses - 1}。
*
* @tparam FeaturesType 输入特征的类型,例如 `Vector`
* @tparam E 具体的估计器类型
* @tparam M 具体的模型类型
*/
abstract class Classifier[
FeaturesType,
E <: Classifier[FeaturesType, E, M],
M <: ClassificationModel[FeaturesType, M]]
extends Predictor[FeaturesType, E, M] with ClassifierParams {
/**
* 获取类别数量。首先查找列元数据中的类别数值,如果缺失,则假设类别被索引为 0,1,...,numClasses-1,
* 并通过找到最大标签值来计算类别数。
*
* 类别验证(确保所有类别都是整数且 >= 0)需要在其他地方处理,比如在 `extractLabeledPoints()` 中。
*
* @param dataset 包含列 [[labelCol]] 的数据集
* @param maxNumClasses 从数据中推断时允许的最大类别数。如果元数据中指定了 numClasses,则忽略 maxNumClasses。
* @return 类别数量
* @throws IllegalArgumentException 如果元数据未指定 numClasses,并且实际的 numClasses 超过了 maxNumClasses
*/
protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
}
/** @group setParam */
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
// TODO: defaultEvaluator (follow-up PR)
}
1.1 with ClassifierParams(所有分类器)
/** 分类器的参数。 */
private[spark] trait ClassifierParams
extends PredictorParams
with HasRawPredictionCol {
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}
}
ClassifierParams 是分类器的参数特质,继承自 PredictorParams 和 HasRawPredictionCol。它重写了 validateAndTransformSchema 方法来验证并转换输入数据集的模式。
1.1.1 extends PredictorParams(预测回归和分类的参数)
PredictorParams 是预测器的参数特质,继承自 Params,并包含了标签列、特征列和预测列的参数,用于预测(回归和分类)的参数特质
/**
* (private[ml]) 用于预测(回归和分类)的参数特质。
*/
private[ml] trait PredictorParams extends Params
with HasLabelCol with HasFeaturesCol with HasPredictionCol {
/**
* 使用提供的参数映射验证和转换输入模式。
*
* @param schema 输入模式
* @param fitting 是否为拟合过程
* @param featuresDataType FeaturesType 的 SQL 数据类型。
* 例如,对于向量特征使用 `VectorUDT`。
* @return 输出模式
*/
protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
// TODO: 支持将 Array[Double] 和 Array[Float] 转换为 Vector,当 FeaturesType = Vector 时
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
SchemaUtils.checkNumericType(schema, $(labelCol))
this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
SchemaUtils.checkNumericType(schema, $(p.weightCol))
}
case _ =>
}
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
}
1.1.2 with HasRawPredictionCol
1.2 extends Predictor
abstract class Predictor[
FeaturesType,
Learner <: Predictor[FeaturesType, Learner, M],
M <: PredictionModel[FeaturesType, M]]
extends Estimator[M]
with PredictorParams {
Predictor 是一个抽象类,继承自 Estimator 和 PredictorParams。它定义了预测器的基本行为,并提供了一些公共方法,如设置标签列、特征列和预测列、拟合模型等.
实现类


/**
* 预测问题(回归和分类)的抽象类。它接受所有数值类型的标签,并在 `fit()` 中自动将其转换为 DoubleType。
* 如果该预测器支持权重,则它接受所有数值类型的权重,将在 `fit()` 中自动转换为 DoubleType。
*
* @tparam FeaturesType 特征的类型
* 例如,对于向量特征使用 `VectorUDT`。
* @tparam Learner 该类的具体实现。如果您继承了此类型,请使用此类型参数来指定具体类型。
* @tparam M 该类的具体实现,继承自 [[PredictionModel]]。如果您继承了此类型,请使用此类型参数来指定相应模型的具体类型。
*/
abstract class Predictor[
FeaturesType,
Learner <: Predictor[FeaturesType, Learner, M],
M <: PredictionModel[FeaturesType, M]]
extends Estimator[M] with PredictorParams {
/** @group setParam */
def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
/** @group setParam */
def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
/** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
override def fit(dataset: Dataset[_]): M = {
// 处理一些内容,例如模式验证。
// 开发者只需要实现 train() 方法。
transformSchema(dataset.schema, logging = true)
// 将 LabelCol 转换为 DoubleType 并保留元数据。
val labelMeta = dataset.schema($(labelCol)).metadata
val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
// 将 WeightCol 转换为 DoubleType 并保留元数据。
val casted = this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
val weightMeta = dataset.schema($(p.weightCol)).metadata
labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
} else {
labelCasted
}
case _ => labelCasted
}
copyValues(train(casted).setParent(this))
}
override def copy(extra: ParamMap): Learner
/**
* 使用给定的数据集和参数训练模型。
* 开发者可以实现此方法来替代 `fit()`,以避免处理模式验证并将参数复制到模型中。
*
* @param dataset 训练数据集
* @return 拟合的模型
*/
protected def train(dataset: Dataset[_]): M
/**
* 返回与 FeaturesType 类型参数对应的 SQL 数据类型。
*
* 这用于 `validateAndTransformSchema()`。
* 这个解决方案是因为 Scala 和 Java 在 SQL 上有不同的 API。
*
* 默认值为 VectorUDT,但如果 FeaturesType 不是向量,则可能会被重写。
*/
private[ml] def featuresDataType: DataType = new VectorUDT
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, featuresDataType)
}
}
1.2.1 extends Estimator[M]
抽象类 Estimator,用于将模型拟合到数据上。它是 Spark ML 中的一个基本概念,用于表示机器学习算法中的训练过程。
Estimator 类具有以下功能:
- 定义了三个 fit 方法,用于将模型拟合到输入数据上。其中,第一个 fit 方法使用可选参数对单个模型进行拟合,第二个 fit 方法使用提供的参数映射对单个模型进行拟合,第三个 fit 方法用于拟合多个模型。
- fit 方法根据传入的参数进行模型拟合,并返回拟合后的模型。
- fit 方法可以重写,以实现特定的算法优化。
- copy 方法用于复制 Estimator 对象,并在复制对象中添加额外的参数。
Estimator 类为具体的机器学习算法提供了一个统一的接口,使得用户可以方便地使用和扩展。用户可以继承 Estimator 类并实现自己的算法逻辑。
/**
* 用于将模型拟合到数据的估计器的抽象类。
*/
abstract class Estimator[M <: Model[M]] extends PipelineStage {
/**
* 使用可选参数将单个模型拟合到输入数据中。
*
* @param dataset 输入数据集
* @param firstParamPair 第一个参数对,覆盖嵌入参数
* @param otherParamPairs 其他参数对。这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
* @return 拟合的模型
*/
@Since("2.0.0")
@varargs
def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
val map = new ParamMap()
.put(firstParamPair)
.put(otherParamPairs: _*)
fit(dataset, map)
}
/**
* 使用提供的参数映射将单个模型拟合到输入数据中。
*
* @param dataset 输入数据集
* @param paramMap 参数映射。
* 这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
* @return 拟合的模型
*/
@Since("2.0.0")
def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
copy(paramMap).fit(dataset)
}
/**
* 将模型拟合到输入数据中。
*/
@Since("2.0.0")
def fit(dataset: Dataset[_]): M
/**
* 使用多个参数集将多个模型拟合到输入数据中。
* 默认实现在每个参数映射上使用for循环。
* 子类可以重写此方法以优化多模型训练。
*
* @param dataset 输入数据集
* @param paramMaps 参数映射的数组。
* 这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
* @return 拟合的模型,与输入参数映射相匹配
*/
@Since("2.0.0")
def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
override def copy(extra: ParamMap): Estimator[M]
}
1.2.2 with PredictorParams
1.2.3 M <: PredictionModel
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M]
with PredictorParams {
PredictionModel 是预测模型的抽象类,继承自 Model 和 PredictorParams。
它提供了一些通用的方法,如设置特征列和预测列、获取特征数、进行转换操作等。
/**
* 预测任务(回归和分类)的模型抽象类。
*
* @tparam FeaturesType 特征的类型
* 例如,对于向量特征使用 `VectorUDT`。
* @tparam M 具体实现的 [[PredictionModel]]。如果您继承了此类型,请使用此类型参数来指定相应模型的具体类型。
*/
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M] with PredictorParams {
/** @group setParam */
def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
/** @group setParam */
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
/** 返回模型训练时使用的特征数。如果未知,则返回 -1。*/
@Since("1.6.0")
def numFeatures: Int = -1
/**
* 返回与 FeaturesType 类型参数对应的 SQL 数据类型。
*
* 这用于 `validateAndTransformSchema()`。
* 这个解决方案是因为 Scala 和 Java 在 SQL 上有不同的 API。
*
* 默认值为 VectorUDT,但如果 FeaturesType 不是向量,则可能会被重写。
*/
protected def featuresDataType: DataType = new VectorUDT
override def transformSchema(schema: StructType): StructType = {
var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
}
outputSchema
}
/**
* 通过读取 [[featuresCol]],调用 `predict` 方法,并将预测结果存储为新列 [[predictionCol]] 来转换数据集。
*
* @param dataset 输入数据集
* @return 具有类型为 `Double` 的 [[predictionCol]] 的转换后的数据集
*/
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() 不执行任何操作,因为未设置输出列。")
dataset.toDF
}
}
protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
outputSchema($(predictionCol)).metadata)
}
/**
* 针对给定的特征预测标签。
* 此方法用于实现 `transform()` 并输出 [[predictionCol]]。
*/
@Since("2.4.0")
def predict(features: FeaturesType): Double
}
1.3 M <: ClassificationModel
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M]
with ClassifierParams {
ClassificationModel 是分类模型的抽象类,继承自 PredictionModel 和 ClassifierParams。它提供了一些通用的方法,如设置原始预测列和预测列、获取类别数量、进行转换操作等。
1.3.1 extends PredictionModel
1.3.2 with ClassifierParams
实现类

源码
/**
* 由 [[Classifier]] 生成的模型。
* 类别被索引为 {0, 1, ..., numClasses - 1}。
*
* @tparam FeaturesType 输入特征的类型,例如 `Vector`
* @tparam M 具体的模型类型
*/
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {
/** @group setParam */
def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
/** 类别数(标签可以取的值的数量)。*/
def numClasses: Int
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(schema,
$(predictionCol), numClasses)
}
if ($(rawPredictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(rawPredictionCol), numClasses)
}
outputSchema
}
/**
* 通过读取 [[featuresCol]] 并根据参数指定的方式添加新列进行数据集转换:
* - 将预测标签作为类型为 `Double` 的 [[predictionCol]]
* - 将原始预测值(置信度)作为类型为 `Vector` 的 [[rawPredictionCol]]
*
* @param dataset 输入数据集
* @return 转换后的数据集
*/
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
// 只输出选定的列。
// 这里稍微有些复杂,因为它尝试避免重复计算。
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
val predictRawUDF = udf { features: Any =>
predictRaw(features.asInstanceOf[FeaturesType])
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
outputSchema($(rawPredictionCol)).metadata)
numColsOutput += 1
}
if (getPredictionCol != "") {
val predCol = if (getRawPredictionCol != "") {
udf(raw2prediction _).apply(col(getRawPredictionCol))
} else {
val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
predictUDF(col(getFeaturesCol))
}
outputData = outputData.withColumn(getPredictionCol, predCol,
outputSchema($(predictionCol)).metadata)
numColsOutput += 1
}
if (numColsOutput == 0) {
logWarning(s"$uid: ClassificationModel.transform() 未执行任何操作,因为未设置输出列。")
}
outputData.toDF
}
final override def transformImpl(dataset: Dataset[_]): DataFrame =
throw new UnsupportedOperationException(s"不支持在 $getClass 中调用 transformImpl 方法")
/**
* 针对给定的特征预测标签。
* 这个方法用于实现 `transform()` 并输出 [[predictionCol]]。
*
* 对于分类,默认实现是从 `predictRaw()` 中选择最大值的索引作为预测结果。
*/
override def predict(features: FeaturesType): Double = {
raw2prediction(predictRaw(features))
}
/**
* 针对每个可能的类别进行原始预测。
* "原始" 预测的含义在不同的算法之间可能有所不同,但它直观地给出了对每个可能类别的置信度(较大的值表示更高的置信度)。
* 此内部方法用于实现 `transform()` 并输出 [[rawPredictionCol]]。
*
* @return 向量,其中第 i 个元素是类别 i 的原始预测值。
* 这些原始预测值可以是任意实数,其中较大的值表示对该类别的更高置信度。
*/
@Since("3.0.0")
def predictRaw(features: FeaturesType): Vector
/**
* 根据给定的原始预测向量选择预测的标签。
* 可以重写此方法以支持偏好特定标签的阈值。
* @return 预测的标签
*/
protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
/**
* 如果已设置原始预测列和预测列,则此方法返回当前模型,
* 否则会为它们生成新列,并将它们设置为当前模型的列。
*/
private[classification] def findSummaryModel():
(ClassificationModel[FeaturesType, M], String, String) = {
val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(rawPredictionCol).isEmpty) {
copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getRawPredictionCol, model.getPredictionCol)
}
}
2.with LinearSVCParams
private[classification] trait LinearSVCParams
extends ClassifierParams
with HasRegParam with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
with HasAggregationDepth with HasThreshold with HasMaxBlockSizeInMB
这段代码定义了Apache Spark中线性支持向量机(Linear Support Vector Classification, LinearSVC)分类器的参数。使用这些参数, 可以在训练和调整线性支持向量机模型时对其进行配置。
在这段代码中,我们定义了一个名为LinearSVCParams的特质,它包含了线性支持向量机分类器的一些参数。这些参数包括正则化参数、最大迭代次数、是否拟合截距、收敛阈值、是否进行标准化、权重列等。
特别要注意的是,在LinearSVCParams特质中,我们为二分类预测中的阈值参数添加了注释。对于线性支持向量机(LinearSVC),该阈值应用于原始预测值(rawPrediction),而不是概率。该阈值可以是任意实数,其中正无穷将使所有预测为0.0,负无穷将使所有预测为1.0。默认情况下,阈值为0.0。
此外,我们还设置了一些参数的默认值,例如正则化参数为0.0、最大迭代次数为100、是否拟合截距为true、收敛阈值为1E-6、是否进行标准化为true、聚合深度为2、最大块大小为0.0。
2.1 extends ClassifierParams
/** 线性支持向量机分类器的参数。*/
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
with HasAggregationDepth with HasThreshold with HasMaxBlockSizeInMB {
/**
* 二分类预测中的阈值参数。
* 对于线性支持向量机(LinearSVC),该阈值应用于原始预测值(rawPrediction),而不是概率。
* 该阈值可以是任何实数,其中正无穷将使所有预测为0.0,
* 负无穷将使所有预测为1.0。
* 默认值为0.0。
*
* @group param
*/
final override val threshold: DoubleParam = new DoubleParam(this, "threshold",
"应用于原始预测值的二分类预测中的阈值")
setDefault(regParam -> 0.0, maxIter -> 100, fitIntercept -> true, tol -> 1E-6,
standardization -> true, threshold -> 0.0, aggregationDepth -> 2, maxBlockSizeInMB -> 0.0)
}
二、LinearSVCModel
class LinearSVCModel private[classification] (
@Since("2.2.0") override val uid: String,
@Since("2.2.0") val coefficients: Vector,
@Since("2.2.0") val intercept: Double)
extends ClassificationModel[Vector, LinearSVCModel]
with LinearSVCParams
with MLWritable
with HasTrainingSummary[LinearSVCTrainingSummary] {
1. extends ClassificationModel
2. with LinearSVCParams
注意:核心逻辑在此,以上全是陪衬(其他模型共用接口或类)。
以下代码定义了 LinearSVCModel 类及其相关辅助类和特质。它们实现了线性支持向量机模型的训练、预测和评估等功能,并提供了一些方法来处理模型的保存和加载。
/**
* 由 [[LinearSVC]] 训练的线性支持向量机模型
*/
@Since("2.2.0")
class LinearSVCModel private[classification] (
@Since("2.2.0") override val uid: String,
@Since("2.2.0") val coefficients: Vector,
@Since("2.2.0") val intercept: Double)
extends ClassificationModel[Vector, LinearSVCModel]
with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] {
@Since("2.2.0")
override val numClasses: Int = 2
@Since("2.2.0")
override val numFeatures: Int = coefficients.size
@Since("2.2.0")
def setThreshold(value: Double): this.type = set(threshold, value)
private val margin: Vector => Double = (features) => {
BLAS.dot(features, coefficients) + intercept
}
/**
* 获取训练集上模型的摘要。如果 `hasSummary` 为 false,则抛出异常
*/
@Since("3.1.0")
override def summary: LinearSVCTrainingSummary = super.summary
/**
* 在测试数据集上评估模型。
*
* @param dataset 要在其上评估模型的测试数据集。
*/
@Since("3.1.0")
def evaluate(dataset: Dataset[_]): LinearSVCSummary = {
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
// 处理可能缺失或无效的 rawPrediction 或 prediction 列
val (summaryModel, rawPrediction, predictionColName) = findSummaryModel()
new LinearSVCSummaryImpl(summaryModel.transform(dataset),
rawPrediction, predictionColName, $(labelCol), weightColName)
}
override def predict(features: Vector): Double = {
if (margin(features) > $(threshold)) 1.0 else 0.0
}
@Since("3.0.0")
override def predictRaw(features: Vector): Vector = {
val m = margin(features)
Vectors.dense(-m, m)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
if (rawPrediction(1) > $(threshold)) 1.0 else 0.0
}
@Since("2.2.0")
override def copy(extra: ParamMap): LinearSVCModel = {
copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)
}
@Since("2.2.0")
override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
@Since("3.0.0")
override def toString: String = {
s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
}
}
/**
* `LinearSVCModel` 的可读取器
*/
@Since("2.2.0")
object LinearSVCModel extends MLReadable[LinearSVCModel] {
@Since("2.2.0")
override def read: MLReader[LinearSVCModel] = new LinearSVCReader
@Since("2.2.0")
override def load(path: String): LinearSVCModel = super.load(path)
/** [[LinearSVCModel]] 的 [[MLWriter]] 实例 */
private[LinearSVCModel]
class LinearSVCWriter(instance: LinearSVCModel)
extends MLWriter with Logging {
private case class Data(coefficients: Vector, intercept: Double)
override protected def saveImpl(path: String): Unit = {
// 保存元数据和参数
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class LinearSVCReader extends MLReader[LinearSVCModel] {
/** 加载模型时与元数据进行校验 */
private val className = classOf[LinearSVCModel].getName
override def load(path: String): LinearSVCModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val Row(coefficients: Vector, intercept: Double) =
data.select("coefficients", "intercept").head()
val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
metadata.getAndSetParams(model)
model
}
}
}
/**
* 线性支持向量机结果的抽象类
*/
sealed trait LinearSVCSummary extends BinaryClassificationSummary
/**
* 线性支持向量机训练结果的抽象类
*/
sealed trait LinearSVCTrainingSummary extends LinearSVCSummary with TrainingSummary
/**
* 给定模型的线性支持向量机结果
*
* @param predictions 模型 `transform` 方法输出的 DataFrame
* @param scoreCol "predictions" 中给出每个实例的原始预测值的字段
* @param predictionCol "predictions" 中给出数据实例的预测值的字段,类型为 double
* @param labelCol "predictions" 中给出每个实例的真实标签的字段
* @param weightCol "predictions" 中给出每个实例的权重的字段
*/
private class LinearSVCSummaryImpl(
@transient override val predictions: DataFrame,
override val scoreCol: String,
override val predictionCol: String,
override val labelCol: String,
override val weightCol: String)
extends LinearSVCSummary
/**
* 线性支持向量机训练结果
*
* @param predictions 模型 `transform` 方法输出的 DataFrame
* @param scoreCol "predictions" 中给出每个实例的原始预测值的字段
* @param predictionCol "predictions" 中给出数据实例的预测值的字段,类型为 double
* @param labelCol "predictions" 中给出每个实例的真实标签的字段
* @param weightCol "predictions" 中给出每个实例的权重的字段
* @param objectiveHistory 每次迭代的目标函数(经过缩放的损失 + 正则化项)
*/
private class LinearSVCTrainingSummaryImpl(
predictions: DataFrame,
scoreCol: String,
predictionCol: String,
labelCol: String,
weightCol: String,
override val objectiveHistory: Array[Double])
extends LinearSVCSummaryImpl(
predictions, scoreCol, predictionCol, labelCol, weightCol)
with LinearSVCTrainingSummary
三、LinearSVC原理
线性支持向量机(Linear Support Vector Machine,简称 Linear SVM)是一种经典的二分类算法,它基于支持向量机(SVM)算法并使用线性核函数。
线性SVM的原理如下:
-
数据预处理:首先对输入数据进行预处理,包括特征缩放和特征选择等操作。这可以提高算法的性能和收敛速度。
-
定义目标变量和特征变量:将待分类的样本数据分为两类,分别标记为正例和负例。同时,确定用于分类的特征变量。
-
寻找最佳超平面:线性SVM的目标是在特征空间中找到一个最佳的超平面,将正例和负例分开。这个超平面被称为决策边界。
-
定义优化问题:线性SVM的优化问题是通过最大化间隔来找到最佳超平面。间隔指的是从训练样本到超平面的最小距离。最大化间隔可以增加模型的泛化能力。
-
解决优化问题:将优化问题转化为凸优化问题,并使用二次规划等方法求解。通过求解这个优化问题,可以得到最佳的超平面参数。
-
预测新样本:在训练完成后,可以使用训练得到的超平面对新样本进行分类预测。根据样本点在超平面的位置,判断其属于正例还是负例。
线性SVM的优点包括:
- 线性SVM在高维空间中表现良好,并且可以处理大规模数据集。
- 通过最大化间隔,线性SVM能够提高模型的泛化能力,降低过拟合风险。
- 线性SVM对于异常值和噪声具有较好的鲁棒性。
然而,
线性SVM也存在一些限制:
- 线性SVM只适用于线性可分的数据集。当数据集不能被一个超平面完全分开时,线性SVM无法很好地工作。
- 线性SVM对于处理大量特征的数据集可能会出现计算复杂度高的问题。这时需要使用特征选择等方法来减少特征数量。
总的来说,线性SVM是一种强大的二分类算法,尤其适用于线性可分的问题。它通过寻找最佳超平面,实现了高性能和泛化能力的平衡。
四、“LinearSVC” 和 “Linear SVM” 一样吗?
“LinearSVC” 和 “Linear SVM” 是指同一个算法,即线性支持向量机(Linear Support Vector Machine)。
在Scikit-learn库中,“LinearSVC” 是用于实现线性支持向量机分类器的类名。它使用线性核函数,并采用一对多(One-vs-Rest)策略处理多类分类问题。这个类提供了一些参数和方法,用于调整模型的超参数和进行预测。
而 “Linear SVM” 是对线性支持向量机算法的一种常见命名方式。它强调了算法的核心思想和特点:使用线性核函数,在高维空间中找到一个最佳的超平面来分割数据。
因此,可以认为 “LinearSVC” 和 “Linear SVM” 是等价的,都指代了基于线性核函数的支持向量机算法。它们都适用于处理线性可分的分类问题,并具有相似的原理和功能。
本文围绕Spark ML中的LinearSVC展开,分析了其源码继承关系,包括class LinearSVC和LinearSVCModel的继承结构。阐述了LinearSVC原理,如数据预处理、寻找最佳超平面等,还指出其优缺点。最后说明“LinearSVC”和“Linear SVM”本质相同,都用于处理线性可分分类问题。
1275

被折叠的 条评论
为什么被折叠?



