Spark MLlib Param 源码解析

本文详细解析了SparkMLlib中的Param类,介绍了其在机器学习算法中的应用、如何创建和验证参数、以及源码中的关键实现,包括JSON编码和解码功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark MLlib Param 源码解析

1. 适用场景

Param 类是 Spark MLlib 中用于定义算法参数的通用类。它可以用于各种机器学习算法中,包括分类、回归、聚类等。通过 Param 类,开发者可以定义自己的参数,并对参数值进行验证,以确保参数值的有效性。

2. 主要用法及示例

  • 创建一个 Param 对象,并设置默认值:
val param = new Param[Double]("myParent", "myName", "这是一个参数的文档", (value: Double) => value > 0.0)
param.w(10.0)  // 使用默认值创建一个 ParamPair 对象
  • 验证参数值是否有效:
val isValid = param.isValid(10.0)  // 检查参数值是否有效
param.validate(10.0)  // 如果参数值无效,将抛出 IllegalArgumentException 异常
  • 将参数值编码为 JSON 字符串:
val json = param.jsonEncode(10.0)  // 将参数值编码为 JSON 字符串
  • 从 JSON 字符串中解码参数值:
val value = param.jsonDecode(json)  // 从 JSON 字符串中解码参数值

3. 源码分析

Param 类是一个带有自包含文档和可选默认值的参数。它具有以下主要成员:

  • parent:父对象的标识。
  • name:参数名称。
  • doc:参数的文档。
  • isValid:用于验证参数值是否有效的方法。

Param 类还提供了一些构造函数,可以根据不同的参数来创建实例。它还提供了一些方法,如 w->,用于创建参数对。

该类还实现了一些其他方法,包括:

  • validate(value: T): Unit:验证给定的值是否对该参数有效。
  • jsonEncode(value: T): String:将参数值编码为 JSON 字符串。
  • jsonDecode(json: String): T:从 JSON 字符串解码参数值。

Param 类还重写了 toStringhashCodeequals 方法,以便在比较和打印参数时能够正确工作。

Param 伴生对象包含了一个 jsonDecode[T](json: String): T 方法,用于从 JSON 字符串解码参数值。

4. 中文源码

/**
 * :: DeveloperApi ::
 * Param 是一个带有自包含文档和可选默认值的参数。原始类型的参数应该使用专门的版本,
 * 这对于 Java 用户更友好。
 *
 * @param parent 父对象
 * @param name 参数名称
 * @param doc 文档
 * @param isValid 可选的验证方法,用于判断值是否有效。
 *                请参考 [[ParamValidators]] 获取常见验证函数的工厂方法。
 * @tparam T 参数值的类型
 */
@DeveloperApi
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
  extends Serializable {

  def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
    this(parent.uid, name, doc, isValid)

  def this(parent: String, name: String, doc: String) =
    this(parent, name, doc, ParamValidators.alwaysTrue[T])

  def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

  /**
   * 验证给定的值是否对该参数有效。
   *
   * 注意:涉及多个参数和输入/输出列之间交互的参数检查应该在 [[org.apache.spark.ml.PipelineStage.transformSchema()]] 中实现。
   *
   * DEVELOPERS:此方法仅由 [[ParamPair]] 调用,这意味着所有参数都应通过 [[ParamPair]] 指定。
   *
   * @throws IllegalArgumentException 如果值无效
   */
  private[param] def validate(value: T): Unit = {
    if (!isValid(value)) {
      val valueToString = value match {
        case v: Array[_] => v.mkString("[", ",", "]")
        case _ => value.toString
      }
      throw new IllegalArgumentException(
        s"$parent parameter $name given invalid value $valueToString.")
    }
  }

  /** 使用给定值创建一个参数对(用于 Java)。 */
  def w(value: T): ParamPair[T] = this -> value

  /** 使用给定值创建一个参数对(用于 Scala)。 */
  // scalastyle:off
  def ->(value: T): ParamPair[T] = ParamPair(this, value)
  // scalastyle:on

  /** 将参数值编码为 JSON,可以由 `jsonDecode()` 解码。 */
  def jsonEncode(value: T): String = {
    value match {
      case x: String =>
        compact(render(JString(x)))
      case v: Vector =>
        JsonVectorConverter.toJson(v)
      case m: Matrix =>
        JsonMatrixConverter.toJson(m)
      case _ =>
        throw new NotImplementedError(
          "The default jsonEncode only supports string, vector and matrix. " +
            s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
    }
  }

  /** 从 JSON 解码参数值。 */
  def jsonDecode(json: String): T = Param.jsonDecode[T](json)

  private[this] val stringRepresentation = s"${parent}__$name"

  override final def toString: String = stringRepresentation

  override final def hashCode: Int = toString.##

  override final def equals(obj: Any): Boolean = {
    obj match {
      case p: Param[_] => (p.parent == parent) && (p.name == name)
      case _ => false
    }
  }
}

private[ml] object Param {

  /** 从 JSON 解码参数值。 */
  def jsonDecode[T](json: String): T = {
    val jValue = parse(json)
    jValue match {
      case JString(x) =>
        x.asInstanceOf[T]
      case JObject(v) =>
        val keys = v.map(_._1)
        if (keys.contains("class")) {
          implicit val formats = DefaultFormats
          val className = (jValue \ "class").extract[String]
          className match {
            case JsonMatrixConverter.className =>
              val checkFields = Array("numRows", "numCols", "values", "isTransposed", "type")
              require(checkFields.forall(keys.contains), s"Expect a JSON serialized Matrix" +
                s" but cannot find fields ${checkFields.mkString(", ")} in $json.")
              JsonMatrixConverter.fromJson(json).asInstanceOf[T]

            case s => throw new SparkException(s"unrecognized class $s in $json")
          }
        } else {
          // "class" info in JSON was added in Spark 2.3(SPARK-22289). JSON support for Vector was
          // implemented before that and does not have "class" attribute.
          require(keys.contains("type") && keys.contains("values"), s"Expect a JSON serialized" +
            s" vector/matrix but cannot find fields 'type' and 'values' in $json.")
          JsonVectorConverter.fromJson(json).asInstanceOf[T]
        }

      case _ =>
        throw new NotImplementedError(
          "The default jsonDecode only supports string, vector and matrix. " +
            s"${this.getClass.getName} must override jsonDecode to support its value type.")
    }
  }
}

5. 官方链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值