Spark MLlib Params 源码解析

本文详细解析了SparkMLlib中的Params特质,介绍了其在组件参数管理中的作用,包括如何定义、设置、获取和解释参数,以及提供的一些实用方法和示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark MLlib Params 源码解析

1. 源码适用场景

Params 是一个 trait,用于组件中需要使用参数的情况。它还提供了一个内部的参数映射来存储附加到实例上的参数值。

2. 源码总结

这段代码定义了一个 Params 特质和它的伴生对象 Params

Params 特质是一个用于接受参数的组件,并提供一个内部参数映射来存储与实例相关联的参数值。它具有以下主要成员:

  • params: Array[Param[_]]:返回按名称排序的所有参数的数组。
  • explainParam(param: Param[_]): String:解释一个参数,返回包含参数名称、文档以及默认值和用户提供的值的字符串。
  • explainParams(): String:解释该实例的所有参数。
  • isSet(param: Param[_]): Boolean:检查参数是否已显式设置。
  • isDefined(param: Param[_]): Boolean:检查参数是否已显式设置或具有默认值。
  • hasParam(paramName: String): Boolean:测试该实例是否包含具有给定名称的参数。
  • getParam(paramName: String): Param[Any]:根据参数名称获取参数。
  • set[T](param: Param[T], value: T): this.type:在嵌入的参数映射中设置参数。
  • get[T](param: Param[T]): Option[T]:可选地返回参数的用户提供的值。
  • clear(param: Param[_]): this.type:清除参数的用户提供的值。
  • getOrDefault[T](param: Param[T]): T:获取参数的值或其默认值。
  • setDefault[T](param: Param[T], value: T): this.type:为参数设置默认值。
  • setDefault(paramPairs: ParamPair[_]*): this.type:为一组参数设置默认值。
  • getDefault[T](param: Param[T]): Option[T]:获取参数的默认值。
  • hasDefault[T](param: Param[T]): Boolean:测试输入参数是否具有默认值。
  • copy(extra: ParamMap): Params:使用相同的 UID 和额外参数创建此实例的副本。
  • extractParamMap(extra: ParamMap): ParamMap:提取嵌入的默认参数值和用户提供的值,并与额外值合并到一个平面参数映射中。

Params 伴生对象包含了一些辅助方法,如 setDefault,用于设置默认参数值。

通过使用 Params 特质和 Params 伴生对象,我们可以在组件中定义、设置和获取参数,并进行参数解释和复制。

3. 用法及示例

获取所有按名称排序的参数:

val params: Array[Param[_]] = this.params

解释单个参数的含义:

val param: Param[_] = ...
val explanation: String = this.explainParam(param)

解释所有参数的含义:

val explanations: String = this.explainParams()

检查参数是否已设置:

val param: Param[_] = ...
val isSet: Boolean = this.isSet(param)

检查参数是否已设置或具有默认值:

val param: Param[_] = ...
val isDefined: Boolean = this.isDefined(param)

检查是否存在具有给定名称的参数:

val paramName: String = ...
val hasParam: Boolean = this.hasParam(paramName)

根据参数名称获取参数:

val paramName: String = ...
val param: Param[Any] = this.getParam(paramName)

设置参数值:

val param: Param[T] = ...
val value: T = ...
this.set(param, value)

获取参数值(如果已设置):

val param: Param[T] = ...
val valueOpt: Option[T] = this.get(param)

清除参数值:

val param: Param[_] = ...
this.clear(param)

获取参数值或默认值:

val param: Param[T] = ...
val value: T = this.getOrDefault(param)

创建带有相同 UID 和额外参数的副本:

val extra: ParamMap = ...
val copy: Params = this.copy(extra)

提取参数映射:

val paramMap: ParamMap = this.extractParamMap()

3. 中文源码

/**
 * :: DeveloperApi ::
 * 该接口用于组件接受参数。这还提供了一个内部参数映射,用于存储附加到实例的参数值。
 */
@DeveloperApi
trait Params extends Identifiable with Serializable {

  /**
   * 返回按名称排序的所有参数。默认实现使用 Java 反射列出所有没有参数并返回 [[Param]] 的公共方法。
   *
   * @note 开发者不应在构造函数中使用此方法,因为我们无法保证此变量在其他参数之前被初始化。
   */
  lazy val params: Array[Param[_]] = {
    val methods = this.getClass.getMethods
    methods.filter { m =>
        Modifier.isPublic(m.getModifiers) &&
          classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
          m.getParameterTypes.isEmpty
      }.sortBy(_.getName)
      .map(m => m.invoke(this).asInstanceOf[Param[_]])
  }

  /**
   * 解释一个参数。
   * @param param 输入参数,必须属于该实例。
   * @return 包含输入参数名称、文档以及可选的默认值和用户提供的值的字符串
   */
  def explainParam(param: Param[_]): String = {
    shouldOwn(param)
    val valueStr = if (isDefined(param)) {
      val defaultValueStr = getDefault(param).map("默认值: " + _)
      val currentValueStr = get(param).map("当前值: " + _)
      (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")")
    } else {
      "(未定义)"
    }
    s"${param.name}: ${param.doc} $valueStr"
  }

  /**
   * 解释该实例的所有参数。参见 `explainParam()`。
   */
  def explainParams(): String = {
    params.map(explainParam).mkString("\n")
  }

  /** 检查参数是否已显式设置。 */
  final def isSet(param: Param[_]): Boolean = {
    shouldOwn(param)
    paramMap.contains(param)
  }

  /** 检查参数是否已显式设置或具有默认值。 */
  final def isDefined(param: Param[_]): Boolean = {
    shouldOwn(param)
    defaultParamMap.contains(param) || paramMap.contains(param)
  }

  /** 测试该实例是否包含具有给定名称的参数。 */
  def hasParam(paramName: String): Boolean = {
    params.exists(_.name == paramName)
  }

  /** 根据参数名称获取参数。 */
  def getParam(paramName: String): Param[Any] = {
    params.find(_.name == paramName).getOrElse {
      throw new NoSuchElementException(s"参数 $paramName 不存在。")
    }.asInstanceOf[Param[Any]]
  }

  /**
   * 在嵌入的参数映射中设置参数。
   */
  final def set[T](param: Param[T], value: T): this.type = {
    set(param -> value)
  }

  /**
   * 在嵌入的参数映射中设置参数(按名称)。
   */
  protected final def set(param: String, value: Any): this.type = {
    set(getParam(param), value)
  }

  /**
   * 在嵌入的参数映射中设置参数。
   */
  protected final def set(paramPair: ParamPair[_]): this.type = {
    shouldOwn(paramPair.param)
    paramMap.put(paramPair)
    this
  }

  /**
   * 可选地返回参数的用户提供的值。
   */
  final def get[T](param: Param[T]): Option[T] = {
    shouldOwn(param)
    paramMap.get(param)
  }

  /**
   * 清除输入参数的用户提供的值。
   */
  final def clear(param: Param[_]): this.type = {
    shouldOwn(param)
    paramMap.remove(param)
    this
  }

  /**
   * 获取嵌入的参数映射中的参数值或其默认值。如果都未设置,则抛出异常。
   */
  final def getOrDefault[T](param: Param[T]): T = {
    shouldOwn(param)
    get(param).orElse(getDefault(param)).getOrElse(
      throw new NoSuchElementException(s"未找到参数 ${param.name} 的默认值"))
  }

  /**
   * `getOrDefault()` 的别名。
   */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

  /**
   * 为参数设置默认值。
   * @param param 要设置默认值的参数。确保在调用此方法之前初始化该参数。
   * @param value 默认值
   */
  protected final def setDefault[T](param: Param[T], value: T): this.type = {
    defaultParamMap.put(param -> value)
    this
  }

  /**
   * 为一组参数设置默认值。
   *
   * 注意:Java 开发者应该使用单参数的 `setDefault`。在 Scala 编译器中,使用 varargs 注释可能会导致编译错误,这是由于 Scala 编译器的错误引起的。
   * 参见 SPARK-9268。
   *
   * @param paramPairs 指定要设置其默认值的参数及其默认值的一组参数对。确保在调用此方法之前初始化这些参数。
   */
  protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
    paramPairs.foreach { p =>
      setDefault(p.param.asInstanceOf[Param[Any]], p.value)
    }
    this
  }

  /**
   * 获取参数的默认值。
   */
  final def getDefault[T](param: Param[T]): Option[T] = {
    shouldOwn(param)
    defaultParamMap.get(param)
  }

  /**
   * 测试输入参数是否具有默认值。
   */
  final def hasDefault[T](param: Param[T]): Boolean = {
    shouldOwn(param)
    defaultParamMap.contains(param)
  }

  /**
   * 使用相同的 UID 和一些额外参数创建此实例的副本。
   * 子类应实现此方法并正确设置返回类型。参见 `defaultCopy()`。
   */
  def copy(extra: ParamMap): Params

  /**
   * 带有额外参数的复制的默认实现。
   * 尝试使用相同的 UID 创建一个新实例。
   * 然后将嵌入和额外参数复制到新实例,并返回该新实例。
   */
  protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
    val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
    copyValues(that, extra).asInstanceOf[T]
  }

  /**
   * 提取嵌入的默认参数值和用户提供的值,然后将它们与输入中的额外值合并到一个平面参数映射中,
   * 其中后面的值(如果存在冲突)将被使用,即:默认参数值 < 用户提供的值 < 额外值。
   */
  final def extractParamMap(extra: ParamMap): ParamMap = {
    defaultParamMap ++ paramMap ++ extra
  }

  /**
   * 不带额外值的 `extractParamMap`。
   */
  final def extractParamMap(): ParamMap = {
    extractParamMap(ParamMap.empty)
  }

  /** 用于用户提供的值的内部参数映射。 */
  private[ml] val paramMap: ParamMap = ParamMap.empty

  /** 用于默认值的内部参数映射。 */
  private[ml] val defaultParamMap: ParamMap = ParamMap.empty

  /** 验证输入参数是否属于此实例。 */
  private def shouldOwn(param: Param[_]): Unit = {
    require(param.parent == uid && hasParam(param.name), s"参数 $param 不属于 $this.")
  }

  /**
   * 将参数值从此实例复制到另一个实例,以便它们共享的参数。
   *
   * 这会分别处理默认参数和显式设置的参数。
   * 默认参数从 `defaultParamMap` 复制并到 `paramMap`,而显式设置的参数则从 `paramMap` 复制并到 `defaultParamMap`。
   * 警告:这假设此 [[Params]] 实例和目标实例共享相同的默认参数集。
   *
   * @param to 目标实例,应与此源实例共享相同的默认参数集
   * @param extra 要复制到目标的 `paramMap` 的额外参数
   * @return 复制了参数值的目标实例
   */
  protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
    val map = paramMap ++ extra
    params.foreach { param =>
      // 复制默认参数
      if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
      }
      // 复制显式设置的参数
      if (map.contains(param) && to.hasParam(param.name)) {
        to.set(param.name, map(param))
      }
    }
    to
  }
}

private[ml] object Params {
  /**
   * 为 `Params` 设置默认参数值。
   */
  private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
    params.defaultParamMap.put(param -> value)
  }
}

5. 官方链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值