Spark MLlib Params 源码解析
1. 源码适用场景
Params 是一个 trait,用于组件中需要使用参数的情况。它还提供了一个内部的参数映射来存储附加到实例上的参数值。
2. 源码总结
这段代码定义了一个 Params
特质和它的伴生对象 Params
。
Params
特质是一个用于接受参数的组件,并提供一个内部参数映射来存储与实例相关联的参数值。它具有以下主要成员:
params: Array[Param[_]]
:返回按名称排序的所有参数的数组。explainParam(param: Param[_]): String
:解释一个参数,返回包含参数名称、文档以及默认值和用户提供的值的字符串。explainParams(): String
:解释该实例的所有参数。isSet(param: Param[_]): Boolean
:检查参数是否已显式设置。isDefined(param: Param[_]): Boolean
:检查参数是否已显式设置或具有默认值。hasParam(paramName: String): Boolean
:测试该实例是否包含具有给定名称的参数。getParam(paramName: String): Param[Any]
:根据参数名称获取参数。set[T](param: Param[T], value: T): this.type
:在嵌入的参数映射中设置参数。get[T](param: Param[T]): Option[T]
:可选地返回参数的用户提供的值。clear(param: Param[_]): this.type
:清除参数的用户提供的值。getOrDefault[T](param: Param[T]): T
:获取参数的值或其默认值。setDefault[T](param: Param[T], value: T): this.type
:为参数设置默认值。setDefault(paramPairs: ParamPair[_]*): this.type
:为一组参数设置默认值。getDefault[T](param: Param[T]): Option[T]
:获取参数的默认值。hasDefault[T](param: Param[T]): Boolean
:测试输入参数是否具有默认值。copy(extra: ParamMap): Params
:使用相同的 UID 和额外参数创建此实例的副本。extractParamMap(extra: ParamMap): ParamMap
:提取嵌入的默认参数值和用户提供的值,并与额外值合并到一个平面参数映射中。
Params
伴生对象包含了一些辅助方法,如 setDefault
,用于设置默认参数值。
通过使用 Params
特质和 Params
伴生对象,我们可以在组件中定义、设置和获取参数,并进行参数解释和复制。
3. 用法及示例
获取所有按名称排序的参数:
val params: Array[Param[_]] = this.params
解释单个参数的含义:
val param: Param[_] = ...
val explanation: String = this.explainParam(param)
解释所有参数的含义:
val explanations: String = this.explainParams()
检查参数是否已设置:
val param: Param[_] = ...
val isSet: Boolean = this.isSet(param)
检查参数是否已设置或具有默认值:
val param: Param[_] = ...
val isDefined: Boolean = this.isDefined(param)
检查是否存在具有给定名称的参数:
val paramName: String = ...
val hasParam: Boolean = this.hasParam(paramName)
根据参数名称获取参数:
val paramName: String = ...
val param: Param[Any] = this.getParam(paramName)
设置参数值:
val param: Param[T] = ...
val value: T = ...
this.set(param, value)
获取参数值(如果已设置):
val param: Param[T] = ...
val valueOpt: Option[T] = this.get(param)
清除参数值:
val param: Param[_] = ...
this.clear(param)
获取参数值或默认值:
val param: Param[T] = ...
val value: T = this.getOrDefault(param)
创建带有相同 UID 和额外参数的副本:
val extra: ParamMap = ...
val copy: Params = this.copy(extra)
提取参数映射:
val paramMap: ParamMap = this.extractParamMap()
3. 中文源码
/**
* :: DeveloperApi ::
* 该接口用于组件接受参数。这还提供了一个内部参数映射,用于存储附加到实例的参数值。
*/
@DeveloperApi
trait Params extends Identifiable with Serializable {
/**
* 返回按名称排序的所有参数。默认实现使用 Java 反射列出所有没有参数并返回 [[Param]] 的公共方法。
*
* @note 开发者不应在构造函数中使用此方法,因为我们无法保证此变量在其他参数之前被初始化。
*/
lazy val params: Array[Param[_]] = {
val methods = this.getClass.getMethods
methods.filter { m =>
Modifier.isPublic(m.getModifiers) &&
classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
m.getParameterTypes.isEmpty
}.sortBy(_.getName)
.map(m => m.invoke(this).asInstanceOf[Param[_]])
}
/**
* 解释一个参数。
* @param param 输入参数,必须属于该实例。
* @return 包含输入参数名称、文档以及可选的默认值和用户提供的值的字符串
*/
def explainParam(param: Param[_]): String = {
shouldOwn(param)
val valueStr = if (isDefined(param)) {
val defaultValueStr = getDefault(param).map("默认值: " + _)
val currentValueStr = get(param).map("当前值: " + _)
(defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
} else {
"(未定义)"
}
s"${param.name}: ${param.doc} $valueStr"
}
/**
* 解释该实例的所有参数。参见 `explainParam()`。
*/
def explainParams(): String = {
params.map(explainParam).mkString("\n")
}
/** 检查参数是否已显式设置。 */
final def isSet(param: Param[_]): Boolean = {
shouldOwn(param)
paramMap.contains(param)
}
/** 检查参数是否已显式设置或具有默认值。 */
final def isDefined(param: Param[_]): Boolean = {
shouldOwn(param)
defaultParamMap.contains(param) || paramMap.contains(param)
}
/** 测试该实例是否包含具有给定名称的参数。 */
def hasParam(paramName: String): Boolean = {
params.exists(_.name == paramName)
}
/** 根据参数名称获取参数。 */
def getParam(paramName: String): Param[Any] = {
params.find(_.name == paramName).getOrElse {
throw new NoSuchElementException(s"参数 $paramName 不存在。")
}.asInstanceOf[Param[Any]]
}
/**
* 在嵌入的参数映射中设置参数。
*/
final def set[T](param: Param[T], value: T): this.type = {
set(param -> value)
}
/**
* 在嵌入的参数映射中设置参数(按名称)。
*/
protected final def set(param: String, value: Any): this.type = {
set(getParam(param), value)
}
/**
* 在嵌入的参数映射中设置参数。
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
this
}
/**
* 可选地返回参数的用户提供的值。
*/
final def get[T](param: Param[T]): Option[T] = {
shouldOwn(param)
paramMap.get(param)
}
/**
* 清除输入参数的用户提供的值。
*/
final def clear(param: Param[_]): this.type = {
shouldOwn(param)
paramMap.remove(param)
this
}
/**
* 获取嵌入的参数映射中的参数值或其默认值。如果都未设置,则抛出异常。
*/
final def getOrDefault[T](param: Param[T]): T = {
shouldOwn(param)
get(param).orElse(getDefault(param)).getOrElse(
throw new NoSuchElementException(s"未找到参数 ${param.name} 的默认值"))
}
/**
* `getOrDefault()` 的别名。
*/
protected final def $[T](param: Param[T]): T = getOrDefault(param)
/**
* 为参数设置默认值。
* @param param 要设置默认值的参数。确保在调用此方法之前初始化该参数。
* @param value 默认值
*/
protected final def setDefault[T](param: Param[T], value: T): this.type = {
defaultParamMap.put(param -> value)
this
}
/**
* 为一组参数设置默认值。
*
* 注意:Java 开发者应该使用单参数的 `setDefault`。在 Scala 编译器中,使用 varargs 注释可能会导致编译错误,这是由于 Scala 编译器的错误引起的。
* 参见 SPARK-9268。
*
* @param paramPairs 指定要设置其默认值的参数及其默认值的一组参数对。确保在调用此方法之前初始化这些参数。
*/
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
}
this
}
/**
* 获取参数的默认值。
*/
final def getDefault[T](param: Param[T]): Option[T] = {
shouldOwn(param)
defaultParamMap.get(param)
}
/**
* 测试输入参数是否具有默认值。
*/
final def hasDefault[T](param: Param[T]): Boolean = {
shouldOwn(param)
defaultParamMap.contains(param)
}
/**
* 使用相同的 UID 和一些额外参数创建此实例的副本。
* 子类应实现此方法并正确设置返回类型。参见 `defaultCopy()`。
*/
def copy(extra: ParamMap): Params
/**
* 带有额外参数的复制的默认实现。
* 尝试使用相同的 UID 创建一个新实例。
* 然后将嵌入和额外参数复制到新实例,并返回该新实例。
*/
protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
copyValues(that, extra).asInstanceOf[T]
}
/**
* 提取嵌入的默认参数值和用户提供的值,然后将它们与输入中的额外值合并到一个平面参数映射中,
* 其中后面的值(如果存在冲突)将被使用,即:默认参数值 < 用户提供的值 < 额外值。
*/
final def extractParamMap(extra: ParamMap): ParamMap = {
defaultParamMap ++ paramMap ++ extra
}
/**
* 不带额外值的 `extractParamMap`。
*/
final def extractParamMap(): ParamMap = {
extractParamMap(ParamMap.empty)
}
/** 用于用户提供的值的内部参数映射。 */
private[ml] val paramMap: ParamMap = ParamMap.empty
/** 用于默认值的内部参数映射。 */
private[ml] val defaultParamMap: ParamMap = ParamMap.empty
/** 验证输入参数是否属于此实例。 */
private def shouldOwn(param: Param[_]): Unit = {
require(param.parent == uid && hasParam(param.name), s"参数 $param 不属于 $this.")
}
/**
* 将参数值从此实例复制到另一个实例,以便它们共享的参数。
*
* 这会分别处理默认参数和显式设置的参数。
* 默认参数从 `defaultParamMap` 复制并到 `paramMap`,而显式设置的参数则从 `paramMap` 复制并到 `defaultParamMap`。
* 警告:这假设此 [[Params]] 实例和目标实例共享相同的默认参数集。
*
* @param to 目标实例,应与此源实例共享相同的默认参数集
* @param extra 要复制到目标的 `paramMap` 的额外参数
* @return 复制了参数值的目标实例
*/
protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
val map = paramMap ++ extra
params.foreach { param =>
// 复制默认参数
if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
}
// 复制显式设置的参数
if (map.contains(param) && to.hasParam(param.name)) {
to.set(param.name, map(param))
}
}
to
}
}
private[ml] object Params {
/**
* 为 `Params` 设置默认参数值。
*/
private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
params.defaultParamMap.put(param -> value)
}
}