SPARK ML MODEL源码分析
目录
源码
/**
* :: DeveloperApi ::
* A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
*
* @tparam M model type
*/
@DeveloperApi
abstract class Model[M <: Model[M]] extends Transformer {
/**
* The parent estimator that produced this model.
* @note For ensembles' component Models, this value can be null.
*/
@transient var parent: Estimator[M] = _
/**
* Sets the parent of this model (Java API).
*/
def setParent(parent: Estimator[M]): M = {
this.parent = parent
this.asInstanceOf[M]
}
/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null
override def copy(extra: ParamMap): M
}
正文
1. 源码加上中文注释
/**
* :: DeveloperApi ::
* 已拟合的模型,即由估计器(Estimator)产生的转换器(Transformer)。
*
* @tparam M 模型类型
*/
@DeveloperApi
abstract class Model[M <: Model[M]] extends Transformer {
/**
* 生成该模型的父级估计器。
* @note 对于集成模型的组件模型,该值可能为 null。
*/
@transient var parent: Estimator[M] = _
/**
* 设置该模型的父级估计器(Java API)。
*/
def setParent(parent: Estimator[M]): M = {
this.parent = parent
this.asInstanceOf[M]
}
/** 表示该 [[Model]] 是否有对应的父级估计器。 */
def hasParent: Boolean = parent != null
override def copy(extra: ParamMap): M
}
2. 多种主要用法及其代码示例
-
创建一个模型对象并设置父级估计器:
val model = new MyModel() val estimator = new MyEstimator() model.setParent(estimator)
-
判断模型是否有父级估计器:
val hasParent = model.hasParent
-
复制模型并设置额外的参数:
val copiedModel = model.copy(new ParamMap())
3. 源码适用场景
该抽象类是 Spark ML 中模型的基类,所有具体的模型都需要继承自它。它提供了设置和获取父级估计器的方法,并定义了一个抽象的复制方法。适用于以下场景:
- 定义自定义模型时,可以继承该类并实现相应的方法。
- 通过设置父级估计器,可以追踪模型的来源和关系。