Skip to content

Commit

Permalink
CoderTest changes for v0.13.x (#4806)
Browse files Browse the repository at this point in the history
* Improve coder tests (#4664)

* Refactor coder tests

* Fixes

* fix scala 2.12 errors

* One more fix

* One more fix

* fix

* one more fix

* one more fix

* one more fix

* added file headers

* added more checks on primitives

* Fixing checks failures

* Fixing checks

* Fix tests

* Removed coderShouldThrowOn

* fix

* Added more checks and fixed a few coders

* Update scio-test/src/test/scala/com/spotify/scio/coders/CoderTest.scala

Co-authored-by: Michel Davit <micheld@spotify.com>

* fixes after review

* Fixing WrappedArray issue

* fix format issue

* Ran scalafixAll on scio-core

* modify syntax with custom options

* fix

---------

Co-authored-by: Michel Davit <micheld@spotify.com>

* fix scalafix errors

* fix copyright date

* Update LowPriorityCoderDerivation.scala

---------

Co-authored-by: Michel Davit <micheld@spotify.com>
  • Loading branch information
shnapz and RustedBones authored May 18, 2023
1 parent f139a0f commit 62751fe
Show file tree
Hide file tree
Showing 12 changed files with 1,163 additions and 822 deletions.
366 changes: 3 additions & 363 deletions scio-core/src/main/scala/com/spotify/scio/coders/Coder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@

package com.spotify.scio.coders

import java.io.{InputStream, OutputStream}
import com.spotify.scio.IsJavaBean
import com.spotify.scio.coders.instances._
import com.spotify.scio.transforms.BaseAsyncLookupDoFn
import org.apache.beam.sdk.coders.Coder.NonDeterministicException
import org.apache.beam.sdk.coders.{Coder => BCoder, CustomCoder, StructuredCoder}
import org.apache.beam.sdk.util.common.ElementByteSizeObserver
import org.apache.beam.sdk.coders.{Coder => BCoder}
import org.apache.beam.sdk.values.KV

import java.util.{Collections, List => JList, Objects}
import java.util.{List => JList}
import scala.annotation.implicitNotFound
import scala.jdk.CollectionConverters._
import scala.collection.compat._
import scala.collection.{mutable => m, BitSet, SortedSet}
import scala.reflect.ClassTag
Expand Down Expand Up @@ -133,362 +129,6 @@ final case class KVCoder[K, V] private (koder: Coder[K], voder: Coder[V]) extend
// GroupByKey aggregate are special because they can't be wrapped
final case class AggregateCoder[T] private (coder: Coder[T]) extends Coder[java.lang.Iterable[T]]

///////////////////////////////////////////////////////////////////////////////
// Materialized beam coders
///////////////////////////////////////////////////////////////////////////////
final private class SingletonCoder[T](
val typeName: String,
supply: () => T
) extends CustomCoder[T] {
@transient private lazy val singleton = supply()

override def toString: String = s"SingletonCoder[$typeName]"

override def equals(obj: Any): Boolean = obj match {
case that: SingletonCoder[_] => typeName == that.typeName
case _ => false
}

override def hashCode(): Int = typeName.hashCode

override def encode(value: T, outStream: OutputStream): Unit = {}
override def decode(inStream: InputStream): T = singleton
override def verifyDeterministic(): Unit = {}
override def consistentWithEquals(): Boolean = true
override def isRegisterByteSizeObserverCheap(value: T): Boolean = true
override def getEncodedElementByteSize(value: T): Long = 0
}

final private class DisjunctionCoder[T, Id](
val typeName: String,
val idCoder: BCoder[Id],
val coders: Map[Id, BCoder[T]],
id: T => Id
) extends CustomCoder[T] {

override def toString: String = {
val body = coders.map { case (id, coder) => s"$id -> $coder" }.mkString(", ")
s"DisjunctionCoder[$typeName]($body)"
}

override def equals(obj: Any): Boolean = obj match {
case that: DisjunctionCoder[_, _] =>
typeName == that.typeName && idCoder == that.idCoder && coders == that.coders
case _ =>
false
}

override def hashCode(): Int = Objects.hash(typeName, coders)

def encode(value: T, os: OutputStream): Unit = {
val i = id(value)
idCoder.encode(i, os)
coders(i).encode(value, os)
}

def decode(is: InputStream): T = {
val i = idCoder.decode(is)
coders(i).decode(is)
}

override def verifyDeterministic(): Unit = {
var cause = Option.empty[NonDeterministicException]
val reasons = List.newBuilder[String]
coders.foreach { case (id, c) =>
try {
c.verifyDeterministic()
} catch {
case e: NonDeterministicException =>
cause = Some(e)
reasons += s"case $id is using non-deterministic $c"
}
}

cause.foreach { e =>
throw new NonDeterministicException(this, reasons.result().asJava, e)
}
}

override def consistentWithEquals(): Boolean =
coders.values.forall(_.consistentWithEquals())

override def structuralValue(value: T): AnyRef =
if (consistentWithEquals()) {
value.asInstanceOf[AnyRef]
} else {
coders(id(value)).structuralValue(value)
}
}

// Coder used internally specifically for Magnolia derived coders.
// It's technically possible to define Product coders only in terms of `Coder.transform`
// This is just faster
final private[scio] class RecordCoder[T](
val typeName: String,
val cs: IndexedSeq[(String, BCoder[Any])],
construct: Seq[Any] => T,
destruct: T => IndexedSeq[Any]
) extends CustomCoder[T] {

override def toString: String = {
val body = cs.map { case (l, c) => s"$l -> $c" }.mkString(", ")
s"RecordCoder[$typeName]($body)"
}

override def equals(obj: Any): Boolean = obj match {
case that: RecordCoder[_] =>
typeName == that.typeName && cs == that.cs
case _ =>
false
}

override def hashCode(): Int = Objects.hash(typeName, cs)

@inline def onErrorMsg[A](msg: => String)(f: => A): A =
try {
f
} catch {
case e: Exception => throw CoderStackTrace.append(e, msg)
}

override def encode(value: T, os: OutputStream): Unit = {
val vs = destruct(value)
var idx = 0
while (idx < cs.length) {
val (l, c) = cs(idx)
val v = vs(idx)
onErrorMsg(
s"Exception while trying to `encode` an instance of $typeName: Can't encode field $l value $v"
) {
c.encode(v, os)
}
idx += 1
}
}

override def decode(is: InputStream): T = {
val vs = Array.ofDim[Any](cs.length)
var idx = 0
while (idx < cs.length) {
val (l, c) = cs(idx)
val v = onErrorMsg(
s"Exception while trying to `decode` an instance of $typeName: Can't decode field $l"
) {
c.decode(is)
}
vs.update(idx, v)
idx += 1
}
construct(vs)
}

// delegate methods for determinism and equality checks

override def verifyDeterministic(): Unit = {
var cause = Option.empty[NonDeterministicException]
val reasons = List.newBuilder[String]
cs.foreach { case (l, c) =>
try {
c.verifyDeterministic()
} catch {
case e: NonDeterministicException =>
cause = Some(e)
reasons += s"field $l is using non-deterministic $c"
}
}

cause.foreach { e =>
throw new NonDeterministicException(this, reasons.result().asJava, e)
}
}

override def consistentWithEquals(): Boolean = cs.forall(_._2.consistentWithEquals())

override def structuralValue(value: T): AnyRef =
if (consistentWithEquals()) {
value.asInstanceOf[AnyRef]
} else {
val svs = Array.ofDim[Any](cs.length)
val vs = destruct(value)
var idx = 0
while (idx < cs.length) {
val (l, c) = cs(idx)
val v = vs(idx)
val sv = onErrorMsg(
s"Exception while trying compute `structuralValue` for field $l with value $v"
) {
c.structuralValue(v)
}
svs.update(idx, sv)
idx += 1
}
// return a scala Seq which defines proper equality for structuralValue comparison
svs.toSeq
}

// delegate methods for byte size estimation
override def isRegisterByteSizeObserverCheap(value: T): Boolean = {
val vs = destruct(value)
var isCheap = true
var idx = 0
while (isCheap && idx < cs.length) {
val (_, c) = cs(idx)
val v = vs(idx)
isCheap = c.isRegisterByteSizeObserverCheap(v)
idx += 1
}
isCheap
}

override def registerByteSizeObserver(value: T, observer: ElementByteSizeObserver): Unit = {
val vs = destruct(value)
var idx = 0
while (idx < cs.length) {
val (_, c) = cs(idx)
val v = vs(idx)
c.registerByteSizeObserver(v, observer)
idx += 1
}
}
}

final private[scio] class TransformCoder[T, U](
val typeName: String,
val bcoder: BCoder[U],
to: T => U,
from: U => T
) extends CustomCoder[T] {

override def toString: String = s"TransformCoder[$typeName]($bcoder)"

override def equals(obj: Any): Boolean = obj match {
case that: TransformCoder[_, _] =>
// Assume that all TransformCoder from typeName using bcoder are equal
typeName == that.typeName && bcoder == that.bcoder
case _ =>
false
}

override def hashCode(): Int = Objects.hash(typeName, bcoder)
override def encode(value: T, os: OutputStream): Unit =
bcoder.encode(to(value), os)

override def encode(value: T, os: OutputStream, context: BCoder.Context): Unit =
bcoder.encode(to(value), os, context)

override def decode(is: InputStream): T =
from(bcoder.decode(is))

override def decode(is: InputStream, context: BCoder.Context): T =
from(bcoder.decode(is, context))

override def verifyDeterministic(): Unit =
bcoder.verifyDeterministic()

// Here we make the assumption that mapping functions are idempotent
override def consistentWithEquals(): Boolean =
bcoder.consistentWithEquals()

override def structuralValue(value: T): AnyRef =
bcoder.structuralValue(to(value))

override def isRegisterByteSizeObserverCheap(value: T): Boolean =
bcoder.isRegisterByteSizeObserverCheap(to(value))

override def registerByteSizeObserver(value: T, observer: ElementByteSizeObserver): Unit =
bcoder.registerByteSizeObserver(to(value), observer)
}

sealed abstract private[scio] class WrappedCoder[T] extends StructuredCoder[T] {
def bcoder: BCoder[T]

override def getCoderArguments: JList[_ <: BCoder[_]] =
Collections.singletonList(bcoder)

override def encode(value: T, os: OutputStream): Unit =
bcoder.encode(value, os)
override def encode(value: T, os: OutputStream, context: BCoder.Context): Unit =
bcoder.encode(value, os, context)
override def decode(is: InputStream): T =
bcoder.decode(is)
override def decode(is: InputStream, context: BCoder.Context): T =
bcoder.decode(is, context)
override def verifyDeterministic(): Unit =
bcoder.verifyDeterministic()
override def consistentWithEquals(): Boolean =
bcoder.consistentWithEquals()
override def structuralValue(value: T): AnyRef =
bcoder.structuralValue(value)
override def isRegisterByteSizeObserverCheap(value: T): Boolean =
bcoder.isRegisterByteSizeObserverCheap(value)
override def registerByteSizeObserver(value: T, observer: ElementByteSizeObserver): Unit =
bcoder.registerByteSizeObserver(value, observer)
}

final private[scio] class RefCoder[T](var bcoder: BCoder[T]) extends WrappedCoder[T] {
def this() = this(null)
override def toString: String = bcoder.toString
}

final private[scio] class LazyCoder[T](val typeName: String, bc: => BCoder[T])
extends WrappedCoder[T] {

@transient override lazy val bcoder: BCoder[T] = bc

override def toString: String = s"LazyCoder[$typeName]"

// stop call stack and only compare on typeName
override def equals(obj: Any): Boolean = obj match {
case that: LazyCoder[_] => typeName == that.typeName
case _ => false
}

override def hashCode(): Int = typeName.hashCode

// stop call stack and not interfere with other result
override def verifyDeterministic(): Unit = {}

// stop call stack and not interfere with other result
override def consistentWithEquals(): Boolean = true

// stop call stack and not interfere with other result
override def isRegisterByteSizeObserverCheap(value: T): Boolean = true
}

// Contains the materialization stack trace to provide a helpful stacktrace if an exception happens
final private[scio] class MaterializedCoder[T](
val bcoder: BCoder[T],
materializationStackTrace: Array[StackTraceElement]
) extends WrappedCoder[T] {

def this(bcoder: BCoder[T]) = this(bcoder, CoderStackTrace.prepare)

override def toString: String = bcoder.toString

@inline private def catching[A](a: => A) =
try {
a
} catch {
case ex: Throwable =>
// prior to scio 0.8, a wrapped exception was thrown. It is no longer the case, as some
// backends (e.g. Flink) use exceptions as a way to signal from the Coder to the layers
// above here; we therefore must alter the type of exceptions passing through this block.
throw CoderStackTrace.append(ex, materializationStackTrace)
}

override def encode(value: T, os: OutputStream): Unit =
catching(super.encode(value, os))

override def encode(value: T, os: OutputStream, context: BCoder.Context): Unit =
catching(super.encode(value, os, context))

override def decode(is: InputStream): T =
catching(super.decode(is))

override def decode(is: InputStream, context: BCoder.Context): T =
catching(super.decode(is, context))
}

/**
* Coder Grammar is used to explicitly specify Coder derivation for types used in pipelines.
*
Expand Down Expand Up @@ -681,7 +321,7 @@ private[coders] object CoderStackTrace {
.currentThread()
.getStackTrace
.dropWhile(!_.getClassName.contains(CoderMaterializer.getClass.getName))
.take(10)
.take(15)

def append[T <: Throwable](
cause: T,
Expand Down
Loading

0 comments on commit 62751fe

Please sign in to comment.