Skip to content

Commit

Permalink
Merge 252a9e9 into 6759be8
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Jun 16, 2023
2 parents 6759be8 + 252a9e9 commit 66ee5eb
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 150 deletions.
1 change: 0 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ lazy val `scio-test`: Project = project
"com.google.http-client" % "google-http-client" % googleHttpClientsVersion,
"com.lihaoyi" %% "fansi" % fansiVersion,
"com.lihaoyi" %% "pprint" % pprintVersion,
"com.softwaremill.magnolia1_2" %% "magnolia" % magnoliaVersion,
"com.spotify" %% "magnolify-guava" % magnolifyVersion,
"com.twitter" %% "chill" % chillVersion,
"commons-io" % "commons-io" % commonsIoVersion,
Expand Down
117 changes: 49 additions & 68 deletions scio-core/src/main/scala/com/spotify/scio/coders/BeamCoders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,110 +17,91 @@

package com.spotify.scio.coders

import com.spotify.scio.coders.{instances => scio}
import com.spotify.scio.values.{SCollection, SideInput}
import org.apache.beam.sdk.{coders => beam}
import com.spotify.scio.coders.CoderMaterializer.CoderOptions
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.coders.{Coder => BCoder, NullableCoder, StructuredCoder}
import org.apache.beam.sdk.values.PCollection

import scala.annotation.tailrec
import scala.jdk.CollectionConverters._

/** Utility for extracting [[Coder]]s from Scio types. */
private[scio] object BeamCoders {
@tailrec
private def unwrap[T](coder: beam.Coder[T]): beam.Coder[T] =
private def unwrap[T](options: CoderOptions, coder: BCoder[T]): BCoder[T] =
coder match {
case c: WrappedCoder[T] => unwrap(c.bcoder)
case c: beam.NullableCoder[T] => c.getValueCoder
case _ => coder
case c: MaterializedCoder[T] => unwrap(options, c.bcoder)
case c: NullableCoder[T] if options.nullableCoders => c.getValueCoder
case _ => coder
}

@inline
private def coderElement[T](productCoder: RecordCoder[_])(n: Int): beam.Coder[T] =
productCoder.cs(n)._2.asInstanceOf[beam.Coder[T]]

/** Get coder from a `PCollection[T]`. */
def getCoder[T](coll: PCollection[T]): Coder[T] = Coder.beam(unwrap(coll.getCoder))
def getCoder[T](coll: PCollection[T]): Coder[T] = {
val options = CoderOptions(coll.getPipeline.getOptions)
Coder.beam(unwrap(options, coll.getCoder))
}

/** Get coder from an `SCollection[T]`. */
def getCoder[T](coll: SCollection[T]): Coder[T] = getCoder(coll.internal)

/** Get key-value coders from an `SCollection[(K, V)]`. */
def getTupleCoders[K, V](coll: SCollection[(K, V)]): (Coder[K], Coder[V]) = {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
val (k, v) = unwrap(coder) match {
case c: scio.Tuple2Coder[K, V] =>
(c.ac, c.bc)
case c: RecordCoder[(K, V)] =>
val ac = coderElement[K](c)(0)
val bc = coderElement[V](c)(1)
(ac, bc)
case _ =>
Some(unwrap(options, coder))
.collect { case c: StructuredCoder[_] => c }
.map(_.getComponents.asScala.toList)
.collect { case (c1: BCoder[K]) :: (c2: BCoder[V]) :: Nil =>
val k = Coder.beam(unwrap(options, c1))
val v = Coder.beam(unwrap(options, c2))
k -> v
}
.getOrElse {
throw new IllegalArgumentException(
s"Failed to extract key-value coders from Coder[(K, V)]: $coder"
)
}
(Coder.beam(unwrap(k)), Coder.beam(unwrap(v)))
}
}

def getTuple3Coders[A, B, C](coll: SCollection[(A, B, C)]): (Coder[A], Coder[B], Coder[C]) = {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
val (a, b, c) = unwrap(coder) match {
case c: scio.Tuple3Coder[A, B, C] => (c.ac, c.bc, c.cc)
case c: RecordCoder[(A, B, C)] =>
val ac = coderElement[A](c)(0)
val bc = coderElement[B](c)(1)
val cc = coderElement[C](c)(2)
(ac, bc, cc)
case _ =>
Some(unwrap(options, coder))
.collect { case c: StructuredCoder[_] => c }
.map(_.getComponents.asScala.toList)
.collect { case (c1: BCoder[A]) :: (c2: BCoder[B]) :: (c3: BCoder[C]) :: Nil =>
val a = Coder.beam(unwrap(options, c1))
val b = Coder.beam(unwrap(options, c2))
val c = Coder.beam(unwrap(options, c3))
(a, b, c)
}
.getOrElse {
throw new IllegalArgumentException(
s"Failed to extract tupled coders from Coder[(A, B, C)]: $coder"
)
}
(Coder.beam(unwrap(a)), Coder.beam(unwrap(b)), Coder.beam(unwrap(c)))
}
}

def getTuple4Coders[A, B, C, D](
coll: SCollection[(A, B, C, D)]
): (Coder[A], Coder[B], Coder[C], Coder[D]) = {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
val (a, b, c, d) = unwrap(coder) match {
case c: scio.Tuple4Coder[A, B, C, D] => (c.ac, c.bc, c.cc, c.dc)
case c: RecordCoder[(A, B, C, D)] =>
val ac = coderElement[A](c)(0)
val bc = coderElement[B](c)(1)
val cc = coderElement[C](c)(2)
val dc = coderElement[D](c)(3)
(ac, bc, cc, dc)
case _ =>
Some(unwrap(options, coder))
.collect { case c: StructuredCoder[_] => c }
.map(_.getComponents.asScala.toList)
.collect {
case (c1: BCoder[A]) :: (c2: BCoder[B]) :: (c3: BCoder[C]) :: (c4: BCoder[D]) :: Nil =>
val a = Coder.beam(unwrap(options, c1))
val b = Coder.beam(unwrap(options, c2))
val c = Coder.beam(unwrap(options, c3))
val d = Coder.beam(unwrap(options, c4))
(a, b, c, d)
}
.getOrElse {
throw new IllegalArgumentException(
s"Failed to extract tupled coders from Coder[(A, B, C, D)]: $coder"
)
}
(Coder.beam(unwrap(a)), Coder.beam(unwrap(b)), Coder.beam(unwrap(c)), Coder.beam(unwrap(d)))
}

private def getIterableV[V](coder: beam.Coder[Iterable[V]]): beam.Coder[V] =
unwrap(coder) match {
case c: scio.BaseSeqLikeCoder[Iterable, V] @unchecked => c.elemCoder
case _ =>
throw new IllegalArgumentException(
s"Failed to extract value coder from Coder[Iterable[V]]: $coder"
)
}

/** Get key-value coders from a `SideInput[Map[K, Iterable[V]]]`. */
def getMultiMapKV[K, V](si: SideInput[Map[K, Iterable[V]]]): (Coder[K], Coder[V]) = {
val coder = si.view.getPCollection.getCoder
val (k, v) = unwrap(coder) match {
// Beam's `View.asMultiMap`
case c: beam.KvCoder[K, V] @unchecked => (c.getKeyCoder, c.getValueCoder)
// `asMapSingletonSideInput`
case c: scio.MapCoder[K, Iterable[V]] @unchecked => (c.kc, getIterableV(c.vc))
case _ =>
throw new IllegalArgumentException(
s"Failed to extract key-value coders from Coder[Map[K, Iterable[V]]: $coder"
)
}
(Coder.beam(unwrap(k)), Coder.beam(unwrap(v)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ case class ApproximateUniqueCounter[T](sampleSize: Int) extends ApproxDistinctCo
.asInstanceOf[SCollection[Long]]

override def estimateDistinctCountPerKey[K](in: SCollection[(K, T)]): SCollection[(K, Long)] = {
implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in)
implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1
in.toKV
.applyTransform(beam.ApproximateUnique.perKey[K, T](sampleSize))
.map(klToTuple)
Expand All @@ -83,7 +83,7 @@ case class ApproximateUniqueCounterByError[T](maximumEstimationError: Double = 0
.asInstanceOf[SCollection[Long]]

override def estimateDistinctCountPerKey[K](in: SCollection[(K, T)]): SCollection[(K, Long)] = {
implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in)
implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1
in.toKV
.applyTransform(beam.ApproximateUnique.perKey[K, T](maximumEstimationError))
.map(klToTuple)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
*/
def hashJoin[W](
rhs: SCollection[(K, W)]
): SCollection[(K, (V, W))] =
): SCollection[(K, (V, W))] = {
implicit val wCoder = BeamCoders.getTupleCoders(rhs)._2
hashJoin(rhs.asMultiMapSingletonSideInput)
}

/**
* Perform an inner join with a MultiMap `SideInput[Map[K, Iterable[W]]]`
Expand All @@ -56,20 +58,17 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
*
* @group join
*/
def hashJoin[W](
def hashJoin[W: Coder](
sideInput: SideInput[Map[K, Iterable[W]]]
): SCollection[(K, (V, W))] = {
implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2
self.transform { in =>
in.withSideInputs(sideInput)
.flatMap[(K, (V, W))] { (kv, sideInputCtx) =>
sideInputCtx(sideInput)
.getOrElse(kv._1, Iterable.empty[W])
.iterator
.map(w => (kv._1, (kv._2, w)))
}
.toSCollection
}
): SCollection[(K, (V, W))] = self.transform { in =>
in.withSideInputs(sideInput)
.flatMap[(K, (V, W))] { (kv, sideInputCtx) =>
sideInputCtx(sideInput)
.getOrElse(kv._1, Iterable.empty[W])
.iterator
.map(w => (kv._1, (kv._2, w)))
}
.toSCollection
}

/**
Expand All @@ -87,8 +86,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
*/
def hashLeftOuterJoin[W](
rhs: SCollection[(K, W)]
): SCollection[(K, (V, Option[W]))] =
): SCollection[(K, (V, Option[W]))] = {
implicit val wCoder: Coder[W] = BeamCoders.getTupleCoders(rhs)._2
hashLeftOuterJoin(rhs.asMultiMapSingletonSideInput)
}

/**
* Perform a left outer join with a MultiMap `SideInput[Map[K, Iterable[W]]]`
Expand All @@ -101,19 +102,16 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
* }}}
* @group join
*/
def hashLeftOuterJoin[W](
def hashLeftOuterJoin[W: Coder](
sideInput: SideInput[Map[K, Iterable[W]]]
): SCollection[(K, (V, Option[W]))] = {
implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2
self.transform { in =>
in.withSideInputs(sideInput)
.flatMap[(K, (V, Option[W]))] { case ((k, v), sideInputCtx) =>
val rhsSideMap = sideInputCtx(sideInput)
if (rhsSideMap.contains(k)) rhsSideMap(k).iterator.map(w => (k, (v, Some(w))))
else Iterator((k, (v, None)))
}
.toSCollection
}
): SCollection[(K, (V, Option[W]))] = self.transform { in =>
in.withSideInputs(sideInput)
.flatMap[(K, (V, Option[W]))] { case ((k, v), sideInputCtx) =>
val rhsSideMap = sideInputCtx(sideInput)
if (rhsSideMap.contains(k)) rhsSideMap(k).iterator.map(w => (k, (v, Some(w))))
else Iterator((k, (v, None)))
}
.toSCollection
}

/**
Expand All @@ -124,8 +122,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
*/
def hashFullOuterJoin[W](
rhs: SCollection[(K, W)]
): SCollection[(K, (Option[V], Option[W]))] =
): SCollection[(K, (Option[V], Option[W]))] = {
implicit val wCoder: Coder[W] = BeamCoders.getTupleCoders(rhs)._2
hashFullOuterJoin(rhs.asMultiMapSingletonSideInput)
}

/**
* Perform a full outer join with a `SideInput[Map[K, Iterable[W]]]`.
Expand All @@ -139,38 +139,35 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
*
* @group join
*/
def hashFullOuterJoin[W](
def hashFullOuterJoin[W: Coder](
sideInput: SideInput[Map[K, Iterable[W]]]
): SCollection[(K, (Option[V], Option[W]))] = {
implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2
self.transform { in =>
val leftHashed = in
.withSideInputs(sideInput)
.flatMap { case ((k, v), sideInputCtx) =>
val rhsSideMap = sideInputCtx(sideInput)
if (rhsSideMap.contains(k)) {
rhsSideMap(k).iterator
.map[(K, (Option[V], Option[W]), Boolean)](w => (k, (Some(v), Some(w)), true))
} else {
Iterator((k, (Some(v), None), false))
}
}
.toSCollection

val rightHashed = leftHashed
.filter(_._3)
.map(_._1)
.aggregate(Set.empty[K])(_ + _, _ ++ _)
.withSideInputs(sideInput)
.flatMap { (mk, sideInputCtx) =>
val m = sideInputCtx(sideInput)
(m.keySet diff mk)
.flatMap(k => m(k).iterator.map[(K, (Option[V], Option[W]))](w => (k, (None, Some(w)))))
): SCollection[(K, (Option[V], Option[W]))] = self.transform { in =>
val leftHashed = in
.withSideInputs(sideInput)
.flatMap { case ((k, v), sideInputCtx) =>
val rhsSideMap = sideInputCtx(sideInput)
if (rhsSideMap.contains(k)) {
rhsSideMap(k).iterator
.map[(K, (Option[V], Option[W]), Boolean)](w => (k, (Some(v), Some(w)), true))
} else {
Iterator((k, (Some(v), None), false))
}
.toSCollection
}
.toSCollection

val rightHashed = leftHashed
.filter(_._3)
.map(_._1)
.aggregate(Set.empty[K])(_ + _, _ ++ _)
.withSideInputs(sideInput)
.flatMap { (mk, sideInputCtx) =>
val m = sideInputCtx(sideInput)
(m.keySet diff mk)
.flatMap(k => m(k).iterator.map[(K, (Option[V], Option[W]))](w => (k, (None, Some(w)))))
}
.toSCollection

leftHashed.map(x => (x._1, x._2)) ++ rightHashed
}
leftHashed.map(x => (x._1, x._2)) ++ rightHashed
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ case class SketchHllPlusPlus[T](p: Int, sp: Int) extends ApproxDistinctCounter[T
override def estimateDistinctCountPerKey[K](
in: SCollection[(K, T)]
): SCollection[(K, Long)] = {
implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in)
implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1

in.toKV
.applyTransform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import com.google.bigtable.v2.Mutation
import com.google.bigtable.v2.Mutation.MutationCase
import com.google.protobuf.ByteString
import com.spotify.scio.values.SCollection
import com.spotify.scio.coders.Coder

import org.scalatest.matchers.{MatchResult, Matcher}

Expand All @@ -34,10 +33,6 @@ trait BigtableMatchers extends SCollectionMatchers {
type BTRow = (ByteString, Iterable[Mutation])
type BTCollection = SCollection[BTRow]

// Needed as a workaround for a scalac limitation
implicit def btCollCoder: Coder[BTRow] =
Coder.gen[(ByteString, Iterable[Mutation])]

/** Provide an implicit BT serializer for common cell value type String. */
implicit def stringBTSerializer(s: String): ByteString =
ByteString.copyFromUtf8(s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,22 @@ class PairHashSCollectionFunctionsTest extends PipelineSpec {
runWithContext { sc =>
val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
val p2 = sc.parallelize[(String, Int)](Map.empty).asMultiMapSideInput
val p = p1.hashJoin(p2)
p should containInAnyOrder(Seq.empty[(String, (Int, Int))])
}
}

it should "support hashJoin() with transformed .asMultiMapSideInput" in {
runWithContext { sc =>
val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
val p2 = sc
.parallelize(Seq(("a", "11"), ("b", "12"), ("b", "13")))
.asMultiMapSideInput
.map(_.map { case (k, vs) => k -> vs.map(_.toInt) })

val p = p1.hashJoin(p2)
p should
containInAnyOrder(Seq.empty[(String, (Int, Int))])
containInAnyOrder(Seq(("a", (1, 11)), ("a", (2, 11)), ("b", (3, 12)), ("b", (3, 13))))
}
}

Expand Down
Loading

0 comments on commit 66ee5eb

Please sign in to comment.