Spark MLlib Param 源码解析
1. 适用场景
Param 类是 Spark MLlib 中用于定义算法参数的通用类。它可以用于各种机器学习算法中,包括分类、回归、聚类等。通过 Param 类,开发者可以定义自己的参数,并对参数值进行验证,以确保参数值的有效性。
2. 主要用法及示例
- 创建一个 Param 对象,并设置默认值:
val param = new Param[Double]("myParent", "myName", "这是一个参数的文档", (value: Double) => value > 0.0)
param.w(10.0) // 使用默认值创建一个 ParamPair 对象
- 验证参数值是否有效:
val isValid = param.isValid(10.0) // 检查参数值是否有效
param.validate(10.0) // 如果参数值无效,将抛出 IllegalArgumentException 异常
- 将参数值编码为 JSON 字符串:
val json = param.jsonEncode(10.0) // 将参数值编码为 JSON 字符串
- 从 JSON 字符串中解码参数值:
val value = param.jsonDecode(json) // 从 JSON 字符串中解码参数值
3. 源码分析
Param
类是一个带有自包含文档和可选默认值的参数。它具有以下主要成员:
parent
:父对象的标识。name
:参数名称。doc
:参数的文档。isValid
:用于验证参数值是否有效的方法。
Param
类还提供了一些构造函数,可以根据不同的参数来创建实例。它还提供了一些方法,如 w
和 ->
,用于创建参数对。
该类还实现了一些其他方法,包括:
validate(value: T): Unit
:验证给定的值是否对该参数有效。jsonEncode(value: T): String
:将参数值编码为 JSON 字符串。jsonDecode(json: String): T
:从 JSON 字符串解码参数值。
Param
类还重写了 toString
、hashCode
和 equals
方法,以便在比较和打印参数时能够正确工作。
Param
伴生对象包含了一个 jsonDecode[T](json: String): T
方法,用于从 JSON 字符串解码参数值。
4. 中文源码
/**
* :: DeveloperApi ::
* Param 是一个带有自包含文档和可选默认值的参数。原始类型的参数应该使用专门的版本,
* 这对于 Java 用户更友好。
*
* @param parent 父对象
* @param name 参数名称
* @param doc 文档
* @param isValid 可选的验证方法,用于判断值是否有效。
* 请参考 [[ParamValidators]] 获取常见验证函数的工厂方法。
* @tparam T 参数值的类型
*/
@DeveloperApi
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
extends Serializable {
def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
this(parent.uid, name, doc, isValid)
def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue[T])
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
/**
* 验证给定的值是否对该参数有效。
*
* 注意:涉及多个参数和输入/输出列之间交互的参数检查应该在 [[org.apache.spark.ml.PipelineStage.transformSchema()]] 中实现。
*
* DEVELOPERS:此方法仅由 [[ParamPair]] 调用,这意味着所有参数都应通过 [[ParamPair]] 指定。
*
* @throws IllegalArgumentException 如果值无效
*/
private[param] def validate(value: T): Unit = {
if (!isValid(value)) {
val valueToString = value match {
case v: Array[_] => v.mkString("[", ",", "]")
case _ => value.toString
}
throw new IllegalArgumentException(
s"$parent parameter $name given invalid value $valueToString.")
}
}
/** 使用给定值创建一个参数对(用于 Java)。 */
def w(value: T): ParamPair[T] = this -> value
/** 使用给定值创建一个参数对(用于 Scala)。 */
// scalastyle:off
def ->(value: T): ParamPair[T] = ParamPair(this, value)
// scalastyle:on
/** 将参数值编码为 JSON,可以由 `jsonDecode()` 解码。 */
def jsonEncode(value: T): String = {
value match {
case x: String =>
compact(render(JString(x)))
case v: Vector =>
JsonVectorConverter.toJson(v)
case m: Matrix =>
JsonMatrixConverter.toJson(m)
case _ =>
throw new NotImplementedError(
"The default jsonEncode only supports string, vector and matrix. " +
s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
}
}
/** 从 JSON 解码参数值。 */
def jsonDecode(json: String): T = Param.jsonDecode[T](json)
private[this] val stringRepresentation = s"${parent}__$name"
override final def toString: String = stringRepresentation
override final def hashCode: Int = toString.##
override final def equals(obj: Any): Boolean = {
obj match {
case p: Param[_] => (p.parent == parent) && (p.name == name)
case _ => false
}
}
}
private[ml] object Param {
/** 从 JSON 解码参数值。 */
def jsonDecode[T](json: String): T = {
val jValue = parse(json)
jValue match {
case JString(x) =>
x.asInstanceOf[T]
case JObject(v) =>
val keys = v.map(_._1)
if (keys.contains("class")) {
implicit val formats = DefaultFormats
val className = (jValue \ "class").extract[String]
className match {
case JsonMatrixConverter.className =>
val checkFields = Array("numRows", "numCols", "values", "isTransposed", "type")
require(checkFields.forall(keys.contains), s"Expect a JSON serialized Matrix" +
s" but cannot find fields ${checkFields.mkString(", ")} in $json.")
JsonMatrixConverter.fromJson(json).asInstanceOf[T]
case s => throw new SparkException(s"unrecognized class $s in $json")
}
} else {
// "class" info in JSON was added in Spark 2.3(SPARK-22289). JSON support for Vector was
// implemented before that and does not have "class" attribute.
require(keys.contains("type") && keys.contains("values"), s"Expect a JSON serialized" +
s" vector/matrix but cannot find fields 'type' and 'values' in $json.")
JsonVectorConverter.fromJson(json).asInstanceOf[T]
}
case _ =>
throw new NotImplementedError(
"The default jsonDecode only supports string, vector and matrix. " +
s"${this.getClass.getName} must override jsonDecode to support its value type.")
}
}
}