【spark原理系列】Spark SQL ExpressionEncoder原理用法示例源码分析

本文详细解析了SparkSQL中的ExpressionEncoder,介绍了其适用场景、主要方法、用法示例以及源码实现。重点讨论了如何在自定义对象与SparkSQLDataFrame/Dataset交互时进行有效编码和解码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark SQL ExpressionEncoder源码分析

1.适用场景

ExpressionEncoder是Spark中用于将JVM对象编码为Spark SQL行的泛型编码器。它主要用于以下场景:

  • 在Spark SQL中将自定义对象转换为DataFrame或Dataset时,可以使用ExpressionEncoder来指定对象的编码和解码方式。
  • 当需要在Spark SQL查询过程中直接对对象进行序列化和反序列化操作时,可以使用ExpressionEncoder进行对象的编码和解码。

2.方法总结归纳

ExpressionEncoder的主要方法包括:

  • resolveAndBind(attrs: Seq[Attribute], analyzer: Analyzer): ExpressionEncoder[T]

    • 功能:将deserializer与给定的schema进行解析和绑定,并返回一个新的ExpressionEncoder。
    • 参数:
      • attrs: 给定的属性序列,默认为schema的属性。
      • analyzer: 分析器,默认为SimpleAnalyzer。
    • 返回值:解析和绑定后的ExpressionEncoder。
  • namedExpressions: Seq[NamedExpression]

    • 功能:返回表示该对象序列化形式的一组具有唯一ID的NamedExpression。
    • 返回值:NamedExpression的序列。
  • toRow(t: T): InternalRow

    • 功能:将对象t编码为Spark SQL行的形式。
    • 参数:待编码的对象t
    • 返回值:编码后的InternalRow。
  • fromRow(row: InternalRow): T

    • 功能:从给定的行中提取所需的值,并返回类型为T的对象。
    • 参数:给定的行row。
    • 返回值:解码后的对象。
  • assertUnresolved(): Unit

    • 功能:检查在后续编码器组合的位置上是否已经进行了解析过程。
    • 无返回值。

3.用法及示例

ExpressionEncoder的主要用法是定义和操作对象的编码和解码方式。以下是一个简单的使用示例:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions._

case class Person(name: String, age: Int)

// 定义Person对象的schema
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)
))

// 定义Person对象的编码器
val encoder = ExpressionEncoder[Person](
  schema,
  flat = true,
  Seq(
    BoundReference(0, StringType, nullable = false),
    BoundReference(1, IntegerType, nullable = false)
  ),
  CreateNamedStruct(Seq(
    Literal("name"), BoundReference(0, StringType, nullable = false),
    Literal("age"), BoundReference(1, IntegerType, nullable = false)
  ))
)

// 将对象编码为行
val person = Person("Alice", 25)
val row = encoder.toRow(person)

// 从行中解码出对象
val decodedPerson = encoder.fromRow(row)
println(decodedPerson)  // 输出: Person(Alice,25)

4.中文源码

/**
 * 一个工厂,用于构造编码器,将对象和原始类型转换为使用 catalyst 表达式和代码生成的内部行格式。
 * 默认情况下,当生成对象时,用于从输入行中检索值的表达式将被创建为:
 * - 类会通过 [[UnresolvedAttribute]] 表达式和 [[UnresolvedExtractValue]] 表达式按名称提取其子字段。
 * - 元组将通过 [[BoundReference]] 表达式按位置提取其子字段。
 * - 原始类型将从具有默认名称“value”的第一个序数提取其值。
 */
object ExpressionEncoder {
  def apply[T : TypeTag](): ExpressionEncoder[T] = {
    // 将不可序列化的 TypeTag 转换为 StructType 和 ClassTag。
    val mirror = ScalaReflection.mirror
    val tpe = typeTag[T].in(mirror).tpe

    if (ScalaReflection.optionOfProductType(tpe)) {
      throw new UnsupportedOperationException(
        "无法为 Product 类型的 Option 创建编码器,因为 Product 类型表示为行,并且整个行不能在 Spark SQL 中为 null,像普通数据库一样。" +
          "如果确实想要顶级 null Product 对象,可以将类型包装在 Tuple1 中,例如,不要创建 `Dataset[Option[MyClass]]`,而是可以执行类似这样的操作 " +
          "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
    }

    val cls = mirror.runtimeClass(tpe)
    val flat = !ScalaReflection.definedByConstructorParams(tpe)

    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
    val nullSafeInput = if (flat) {
      inputObject
    } else {
      // 对于 Product 类型的输入对象,如果为 null,则无法将其编码为行,因为 Spark SQL 不允许顶级行为 null,只有列可以为 null。
      AssertNotNull(inputObject, Seq("top level Product input object"))
    }
    val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
    val deserializer = ScalaReflection.deserializerFor[T]

    val schema = serializer.dataType

    new ExpressionEncoder[T](
      schema,
      flat,
      serializer.flatten,
      deserializer,
      ClassTag[T](cls))
  }

  // TODO: 改进 Java bean 编码器的错误消息。
  def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
    val schema = JavaTypeInference.inferDataType(beanClass)._1
    assert(schema.isInstanceOf[StructType])

    val serializer = JavaTypeInference.serializerFor(beanClass)
    val deserializer = JavaTypeInference.deserializerFor(beanClass)

    new ExpressionEncoder[T](
      schema.asInstanceOf[StructType],
      flat = false,
      serializer.flatten,
      deserializer,
      ClassTag[T](beanClass))
  }

  /**
   * 给定 N 个编码器的集合,构造一个新的编码器,将对象作为 N 元组中的项生成。
   * 注意,这些编码器应该是未解析的,以保留有关名称/位置绑定的信息。
   */
  def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
    encoders.foreach(_.assertUnresolved())

    val schema = StructType(encoders.zipWithIndex.map {
      case (e, i) =>
        val (dataType, nullable) = if (e.flat) {
          e.schema.head.dataType -> e.schema.head.nullable
        } else {
          e.schema -> true
        }
        StructField(s"_${i + 1}", dataType, nullable)
    })

    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
  val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
  val newInputObject = Invoke(
    BoundReference(0, ObjectType(cls), nullable = true),
    s"_${index + 1}",
    originalInputObject.dataType)

  val newSerializer = enc.serializer.map(_.transformUp {
    case b: BoundReference if b == originalInputObject => newInputObject
  })

  val serializerExpr = if (enc.flat) {
    newSerializer.head
  } else {
    // 对于非扁平编码器,将输入对象合并为元组编码器后,输入对象不再是顶级对象,因此可以为空。因此,我们应该使用“If”和空值检查来正确处理空值情况下的“CreateStruct”。例如,对于 Encoder[(Int, String)],序列化表达式将创建2个列,并且无法处理输入元组为空的情况。这不是一个问题,因为有一个检查来确保输入对象不会为空。然而,如果使用此编码器来创建更大的元组编码器,则原始输入对象变为新输入元组的字段,并且可能为空。因此,我们不应该直接在此处创建结构体,而应添加空值/None检查,如果空值/None检查失败,则返回一个空结构体。
    val struct = CreateStruct(newSerializer)
    val nullCheck = Or(
      IsNull(newInputObject),
      Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
    If(nullCheck, Literal.create(null, struct.dataType), struct)
  }
  Alias(serializerExpr, s"_${index + 1}")()
}

val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
  if (enc.flat) {
    enc.deserializer.transform {
      case g: GetColumnByOrdinal => g.copy(ordinal = index)
    }
  } else {
    val input = GetColumnByOrdinal(index, enc.schema)
    val deserialized = enc.deserializer.transformUp {
      case UnresolvedAttribute(nameParts) =>
        assert(nameParts.length == 1)
        UnresolvedExtractValue(input, Literal(nameParts.head))
      case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
    }
    If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
  }
}

val deserializer =
  NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)

new ExpressionEncoder[Any](
  schema,
  flat = false,
  serializer,
  deserializer,
  ClassTag(cls))
 }

  // Tuple1
  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]

  def tuple[T1, T2](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
    tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]

  def tuple[T1, T2, T3](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
    tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]

  def tuple[T1, T2, T3, T4](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3],
      e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
    tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]

  def tuple[T1, T2, T3, T4, T5](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3],
      e4: ExpressionEncoder[T4],
      e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
    tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
}

/**
 * 用于JVM对象的通用编码器,使用Catalyst表达式作为“序列化器”和“反序列化器”。
 *
 * @param schema 将 `T` 转换为 Spark SQL 行后的模式。
 * @param serializer 一组表达式,每个顶级字段对应一个表达式,可用于将原始对象中的值提取到 [[InternalRow]] 中。
 * @param deserializer 给定 [[InternalRow]],用于构造一个对象的表达式。
 * @param clsTag `T` 的类标签。
 */
case class ExpressionEncoder[T](
    schema: StructType,
    flat: Boolean,
    serializer: Seq[Expression],
    deserializer: Expression,
    clsTag: ClassTag[T])
  extends Encoder[T] {

  if (flat) require(serializer.size == 1)

  // 序列化器表达式用于将对象编码为行,而对象通常是在运算符内部生成的中间值,并非来自子运算符的输出。
  // 这与普通表达式非常不同,“AttributeReference”在此处无效(中间值不是属性)。
  // 我们假设所有序列化器表达式都使用相同的“BoundReference”引用对象,并且如果它们没有这样做,则抛出异常。
  assert(serializer.forall(_.references.isEmpty), "序列化器不能引用任何属性。")
  assert(serializer.flatMap { ser =>
    val boundRefs = ser.collect { case b: BoundReference => b }
    assert(boundRefs.nonEmpty, "每个序列化器表达式应至少包含一个`BoundReference`。")
    boundRefs
  }.distinct.length <= 1, "所有序列化器表达式必须使用相同的 BoundReference。")

  /**
   * 返回此编码器的新副本,其中“反序列化器”已解析并绑定到给定的模式。
   *
   * 注意,理想情况下,编码器用作serde表达式的容器,解析和绑定工作应在查询框架内部发生。
   * 但在某些情况下,我们需要将编码器直接用作函数进行序列化(例如Dataset.collect),
   * 然后我们可以使用此方法在查询框架外部进行解析和绑定。
   */
  def resolveAndBind(
      attrs: Seq[Attribute] = schema.toAttributes,
      analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = {
    val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this)
    val analyzedPlan = analyzer.execute(dummyPlan)
    analyzer.checkAnalysis(analyzedPlan)
    val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
    val bound = BindReferences.bindReference(resolved, attrs)
    copy(deserializer = bound)
  }

  @transient
  private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)

  @transient
  private lazy val inputRow = new GenericInternalRow(1)

  @transient
  private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)

  /**
   * 返回表示该对象序列化形式的一组具有唯一标识符的 [[NamedExpression]]。
   */
  def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
    case (_, ne: NamedExpression) => ne.newInstance()
    case (name, e) => Alias(e, name)()
  }

  /**
   * 将 `t` 的编码版本作为 Spark SQL 行返回。
   * 注意,允许多次调用 toRow 返回相同的实际 [[InternalRow]] 对象。
   * 因此,如果需要,在进行另一次调用之前,调用方应复制结果。
   */
  def toRow(t: T): InternalRow = try {
    inputRow(0) = t
    extractProjection(inputRow)
  } catch {
    case e: Exception =>
      throw new RuntimeException(
        s"编码时出错:$e\n${serializer.map(_.simpleString).mkString("\n")}", e)
  }

/**
 * 将 `t` 的编码版本作为 Spark SQL 行返回。
 * 注意,允许多次调用 toRow 返回相同的实际 [[InternalRow]] 对象。
 * 因此,如果需要,在进行另一次调用之前,调用方应复制结果。
 */
def toRow(t: T): InternalRow = try {
  inputRow(0) = t
  extractProjection(inputRow)
} catch {
  case e: Exception =>
    throw new RuntimeException(
      s"编码时出错: $e\n${serializer.map(_.simpleString).mkString("\n")}", e)
}

/**
 * 返回一个类型为 `T` 的对象,从提供的行中提取所需的值。
 * 注意,在调用此函数之前,必须将编码器解析和绑定到特定模式。
 */
def fromRow(row: InternalRow): T = try {
  constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
  case e: Exception =>
    throw new RuntimeException(s"解码时出错: $e\n${deserializer.simpleString}", e)
}

/**
 * 解析给定模式的过程会丢弃按序号而不是按名称绑定给定字段的信息。
 * 此方法检查在我们计划稍后组合编码器的位置上是否已经执行了此过程。
 */
def assertUnresolved(): Unit = {
  (deserializer +:  serializer).foreach(_.foreach {
    case a: AttributeReference if a.name != "loopVar" =>
      sys.error(s"预期未解析的编码器,但找到了 $a。")
    case _ =>
  })
}

protected val attrs = serializer.flatMap(_.collect {
  case _: UnresolvedAttribute => ""
  case a: Attribute => s"#${a.exprId}"
  case b: BoundReference => s"[${b.ordinal}]"
})

protected val schemaString =
  schema
    .zip(attrs)
    .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")

override def toString: String = s"类[$schemaString]"
}

5.官方链接

ExpressionEncoder的官方链接:ExpressionEncoder Scala Doc

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值