diff --git a/build.sbt b/build.sbt index 41a63f9c57..e0b98ebff7 100644 --- a/build.sbt +++ b/build.sbt @@ -611,7 +611,6 @@ lazy val `scio-test`: Project = project "com.google.http-client" % "google-http-client" % googleHttpClientsVersion, "com.lihaoyi" %% "fansi" % fansiVersion, "com.lihaoyi" %% "pprint" % pprintVersion, - "com.softwaremill.magnolia1_2" %% "magnolia" % magnoliaVersion, "com.spotify" %% "magnolify-guava" % magnolifyVersion, "com.twitter" %% "chill" % chillVersion, "commons-io" % "commons-io" % commonsIoVersion, diff --git a/scio-core/src/main/scala/com/spotify/scio/coders/BeamCoders.scala b/scio-core/src/main/scala/com/spotify/scio/coders/BeamCoders.scala index b26a86934f..3cfd7369d7 100644 --- a/scio-core/src/main/scala/com/spotify/scio/coders/BeamCoders.scala +++ b/scio-core/src/main/scala/com/spotify/scio/coders/BeamCoders.scala @@ -17,110 +17,91 @@ package com.spotify.scio.coders -import com.spotify.scio.coders.{instances => scio} -import com.spotify.scio.values.{SCollection, SideInput} -import org.apache.beam.sdk.{coders => beam} +import com.spotify.scio.coders.CoderMaterializer.CoderOptions +import com.spotify.scio.values.SCollection +import org.apache.beam.sdk.coders.{Coder => BCoder, NullableCoder, StructuredCoder} import org.apache.beam.sdk.values.PCollection import scala.annotation.tailrec +import scala.jdk.CollectionConverters._ /** Utility for extracting [[Coder]]s from Scio types. 
*/ private[scio] object BeamCoders { @tailrec - private def unwrap[T](coder: beam.Coder[T]): beam.Coder[T] = + private def unwrap[T](options: CoderOptions, coder: BCoder[T]): BCoder[T] = coder match { - case c: WrappedCoder[T] => unwrap(c.bcoder) - case c: beam.NullableCoder[T] => c.getValueCoder - case _ => coder + case c: MaterializedCoder[T] => unwrap(options, c.bcoder) + case c: NullableCoder[T] if options.nullableCoders => c.getValueCoder + case _ => coder } - @inline - private def coderElement[T](productCoder: RecordCoder[_])(n: Int): beam.Coder[T] = - productCoder.cs(n)._2.asInstanceOf[beam.Coder[T]] - /** Get coder from an `PCollection[T]`. */ - def getCoder[T](coll: PCollection[T]): Coder[T] = Coder.beam(unwrap(coll.getCoder)) + def getCoder[T](coll: PCollection[T]): Coder[T] = { + val options = CoderOptions(coll.getPipeline.getOptions) + Coder.beam(unwrap(options, coll.getCoder)) + } /** Get coder from an `SCollection[T]`. */ def getCoder[T](coll: SCollection[T]): Coder[T] = getCoder(coll.internal) /** Get key-value coders from an `SCollection[(K, V)]`. 
*/ def getTupleCoders[K, V](coll: SCollection[(K, V)]): (Coder[K], Coder[V]) = { + val options = CoderOptions(coll.context.options) val coder = coll.internal.getCoder - val (k, v) = unwrap(coder) match { - case c: scio.Tuple2Coder[K, V] => - (c.ac, c.bc) - case c: RecordCoder[(K, V)] => - val ac = coderElement[K](c)(0) - val bc = coderElement[V](c)(1) - (ac, bc) - case _ => + Some(unwrap(options, coder)) + .collect { case c: StructuredCoder[_] => c } + .map(_.getComponents.asScala.toList) + .collect { case (c1: BCoder[K]) :: (c2: BCoder[V]) :: Nil => + val k = Coder.beam(unwrap(options, c1)) + val v = Coder.beam(unwrap(options, c2)) + k -> v + } + .getOrElse { throw new IllegalArgumentException( s"Failed to extract key-value coders from Coder[(K, V)]: $coder" ) - } - (Coder.beam(unwrap(k)), Coder.beam(unwrap(v))) + } } def getTuple3Coders[A, B, C](coll: SCollection[(A, B, C)]): (Coder[A], Coder[B], Coder[C]) = { + val options = CoderOptions(coll.context.options) val coder = coll.internal.getCoder - val (a, b, c) = unwrap(coder) match { - case c: scio.Tuple3Coder[A, B, C] => (c.ac, c.bc, c.cc) - case c: RecordCoder[(A, B, C)] => - val ac = coderElement[A](c)(0) - val bc = coderElement[B](c)(1) - val cc = coderElement[C](c)(2) - (ac, bc, cc) - case _ => + Some(unwrap(options, coder)) + .collect { case c: StructuredCoder[_] => c } + .map(_.getComponents.asScala.toList) + .collect { case (c1: BCoder[A]) :: (c2: BCoder[B]) :: (c3: BCoder[C]) :: Nil => + val a = Coder.beam(unwrap(options, c1)) + val b = Coder.beam(unwrap(options, c2)) + val c = Coder.beam(unwrap(options, c3)) + (a, b, c) + } + .getOrElse { throw new IllegalArgumentException( s"Failed to extract tupled coders from Coder[(A, B, C)]: $coder" ) - } - (Coder.beam(unwrap(a)), Coder.beam(unwrap(b)), Coder.beam(unwrap(c))) + } } def getTuple4Coders[A, B, C, D]( coll: SCollection[(A, B, C, D)] ): (Coder[A], Coder[B], Coder[C], Coder[D]) = { + val options = CoderOptions(coll.context.options) val coder = 
coll.internal.getCoder - val (a, b, c, d) = unwrap(coder) match { - case c: scio.Tuple4Coder[A, B, C, D] => (c.ac, c.bc, c.cc, c.dc) - case c: RecordCoder[(A, B, C, D)] => - val ac = coderElement[A](c)(0) - val bc = coderElement[B](c)(1) - val cc = coderElement[C](c)(2) - val dc = coderElement[D](c)(3) - (ac, bc, cc, dc) - case _ => + Some(unwrap(options, coder)) + .collect { case c: StructuredCoder[_] => c } + .map(_.getComponents.asScala.toList) + .collect { + case (c1: BCoder[A]) :: (c2: BCoder[B]) :: (c3: BCoder[C]) :: (c4: BCoder[D]) :: Nil => + val a = Coder.beam(unwrap(options, c1)) + val b = Coder.beam(unwrap(options, c2)) + val c = Coder.beam(unwrap(options, c3)) + val d = Coder.beam(unwrap(options, c4)) + (a, b, c, d) + } + .getOrElse { throw new IllegalArgumentException( s"Failed to extract tupled coders from Coder[(A, B, C, D)]: $coder" ) - } - (Coder.beam(unwrap(a)), Coder.beam(unwrap(b)), Coder.beam(unwrap(c)), Coder.beam(unwrap(d))) - } - - private def getIterableV[V](coder: beam.Coder[Iterable[V]]): beam.Coder[V] = - unwrap(coder) match { - case c: scio.BaseSeqLikeCoder[Iterable, V] @unchecked => c.elemCoder - case _ => - throw new IllegalArgumentException( - s"Failed to extract value coder from Coder[Iterable[V]]: $coder" - ) - } - - /** Get key-value coders from a `SideInput[Map[K, Iterable[V]]]`. 
*/ - def getMultiMapKV[K, V](si: SideInput[Map[K, Iterable[V]]]): (Coder[K], Coder[V]) = { - val coder = si.view.getPCollection.getCoder - val (k, v) = unwrap(coder) match { - // Beam's `View.asMultiMap` - case c: beam.KvCoder[K, V] @unchecked => (c.getKeyCoder, c.getValueCoder) - // `asMapSingletonSideInput` - case c: scio.MapCoder[K, Iterable[V]] @unchecked => (c.kc, getIterableV(c.vc)) - case _ => - throw new IllegalArgumentException( - s"Failed to extract key-value coders from Coder[Map[K, Iterable[V]]: $coder" - ) - } - (Coder.beam(unwrap(k)), Coder.beam(unwrap(v))) + } } } diff --git a/scio-core/src/main/scala/com/spotify/scio/estimators/ApproxDistinctCounter.scala b/scio-core/src/main/scala/com/spotify/scio/estimators/ApproxDistinctCounter.scala index 05a9cba57a..73192c4c68 100644 --- a/scio-core/src/main/scala/com/spotify/scio/estimators/ApproxDistinctCounter.scala +++ b/scio-core/src/main/scala/com/spotify/scio/estimators/ApproxDistinctCounter.scala @@ -60,7 +60,7 @@ case class ApproximateUniqueCounter[T](sampleSize: Int) extends ApproxDistinctCo .asInstanceOf[SCollection[Long]] override def estimateDistinctCountPerKey[K](in: SCollection[(K, T)]): SCollection[(K, Long)] = { - implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in) + implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1 in.toKV .applyTransform(beam.ApproximateUnique.perKey[K, T](sampleSize)) .map(klToTuple) @@ -83,7 +83,7 @@ case class ApproximateUniqueCounterByError[T](maximumEstimationError: Double = 0 .asInstanceOf[SCollection[Long]] override def estimateDistinctCountPerKey[K](in: SCollection[(K, T)]): SCollection[(K, Long)] = { - implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in) + implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1 in.toKV .applyTransform(beam.ApproximateUnique.perKey[K, T](maximumEstimationError)) .map(klToTuple) diff --git 
a/scio-core/src/main/scala/com/spotify/scio/values/PairHashSCollectionFunctions.scala b/scio-core/src/main/scala/com/spotify/scio/values/PairHashSCollectionFunctions.scala index b1138ec433..3fc75b9062 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/PairHashSCollectionFunctions.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/PairHashSCollectionFunctions.scala @@ -39,8 +39,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { */ def hashJoin[W]( rhs: SCollection[(K, W)] - ): SCollection[(K, (V, W))] = + ): SCollection[(K, (V, W))] = { + implicit val wCoder: Coder[W] = BeamCoders.getTupleCoders(rhs)._2 hashJoin(rhs.asMultiMapSingletonSideInput) + } /** * Perform an inner join with a MultiMap `SideInput[Map[K, Iterable[V]]` @@ -56,20 +58,17 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { * * @group join */ - def hashJoin[W]( + def hashJoin[W: Coder]( sideInput: SideInput[Map[K, Iterable[W]]] - ): SCollection[(K, (V, W))] = { - implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2 - self.transform { in => - in.withSideInputs(sideInput) - .flatMap[(K, (V, W))] { (kv, sideInputCtx) => - sideInputCtx(sideInput) - .getOrElse(kv._1, Iterable.empty[W]) - .iterator - .map(w => (kv._1, (kv._2, w))) - } - .toSCollection - } + ): SCollection[(K, (V, W))] = self.transform { in => + in.withSideInputs(sideInput) + .flatMap[(K, (V, W))] { (kv, sideInputCtx) => + sideInputCtx(sideInput) + .getOrElse(kv._1, Iterable.empty[W]) + .iterator + .map(w => (kv._1, (kv._2, w))) + } + .toSCollection } /** @@ -87,8 +86,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { */ def hashLeftOuterJoin[W]( rhs: SCollection[(K, W)] - ): SCollection[(K, (V, Option[W]))] = + ): SCollection[(K, (V, Option[W]))] = { + implicit val wCoder: Coder[W] = BeamCoders.getTupleCoders(rhs)._2 hashLeftOuterJoin(rhs.asMultiMapSingletonSideInput) + } /** * Perform a left outer join with a MultiMap `SideInput[Map[K,
Iterable[V]]` @@ -101,19 +102,16 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { * }}} * @group join */ - def hashLeftOuterJoin[W]( + def hashLeftOuterJoin[W: Coder]( sideInput: SideInput[Map[K, Iterable[W]]] - ): SCollection[(K, (V, Option[W]))] = { - implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2 - self.transform { in => - in.withSideInputs(sideInput) - .flatMap[(K, (V, Option[W]))] { case ((k, v), sideInputCtx) => - val rhsSideMap = sideInputCtx(sideInput) - if (rhsSideMap.contains(k)) rhsSideMap(k).iterator.map(w => (k, (v, Some(w)))) - else Iterator((k, (v, None))) - } - .toSCollection - } + ): SCollection[(K, (V, Option[W]))] = self.transform { in => + in.withSideInputs(sideInput) + .flatMap[(K, (V, Option[W]))] { case ((k, v), sideInputCtx) => + val rhsSideMap = sideInputCtx(sideInput) + if (rhsSideMap.contains(k)) rhsSideMap(k).iterator.map(w => (k, (v, Some(w)))) + else Iterator((k, (v, None))) + } + .toSCollection } /** @@ -124,8 +122,10 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { */ def hashFullOuterJoin[W]( rhs: SCollection[(K, W)] - ): SCollection[(K, (Option[V], Option[W]))] = + ): SCollection[(K, (Option[V], Option[W]))] = { + implicit val wCoder: Coder[W] = BeamCoders.getTupleCoders(rhs)._2 hashFullOuterJoin(rhs.asMultiMapSingletonSideInput) + } /** * Perform a full outer join with a `SideInput[Map[K, Iterable[W]]]`. 
@@ -139,38 +139,35 @@ class PairHashSCollectionFunctions[K, V](val self: SCollection[(K, V)]) { * * @group join */ - def hashFullOuterJoin[W]( + def hashFullOuterJoin[W: Coder]( sideInput: SideInput[Map[K, Iterable[W]]] - ): SCollection[(K, (Option[V], Option[W]))] = { - implicit val wCoder = BeamCoders.getMultiMapKV(sideInput)._2 - self.transform { in => - val leftHashed = in - .withSideInputs(sideInput) - .flatMap { case ((k, v), sideInputCtx) => - val rhsSideMap = sideInputCtx(sideInput) - if (rhsSideMap.contains(k)) { - rhsSideMap(k).iterator - .map[(K, (Option[V], Option[W]), Boolean)](w => (k, (Some(v), Some(w)), true)) - } else { - Iterator((k, (Some(v), None), false)) - } - } - .toSCollection - - val rightHashed = leftHashed - .filter(_._3) - .map(_._1) - .aggregate(Set.empty[K])(_ + _, _ ++ _) - .withSideInputs(sideInput) - .flatMap { (mk, sideInputCtx) => - val m = sideInputCtx(sideInput) - (m.keySet diff mk) - .flatMap(k => m(k).iterator.map[(K, (Option[V], Option[W]))](w => (k, (None, Some(w))))) + ): SCollection[(K, (Option[V], Option[W]))] = self.transform { in => + val leftHashed = in + .withSideInputs(sideInput) + .flatMap { case ((k, v), sideInputCtx) => + val rhsSideMap = sideInputCtx(sideInput) + if (rhsSideMap.contains(k)) { + rhsSideMap(k).iterator + .map[(K, (Option[V], Option[W]), Boolean)](w => (k, (Some(v), Some(w)), true)) + } else { + Iterator((k, (Some(v), None), false)) } - .toSCollection + } + .toSCollection + + val rightHashed = leftHashed + .filter(_._3) + .map(_._1) + .aggregate(Set.empty[K])(_ + _, _ ++ _) + .withSideInputs(sideInput) + .flatMap { (mk, sideInputCtx) => + val m = sideInputCtx(sideInput) + (m.keySet diff mk) + .flatMap(k => m(k).iterator.map[(K, (Option[V], Option[W]))](w => (k, (None, Some(w))))) + } + .toSCollection - leftHashed.map(x => (x._1, x._2)) ++ rightHashed - } + leftHashed.map(x => (x._1, x._2)) ++ rightHashed } /** diff --git 
a/scio-extra/src/main/scala/com/spotify/scio/extra/hll/sketching/SketchHllPlusPlus.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/hll/sketching/SketchHllPlusPlus.scala index d98a5a2e76..848e2d4009 100644 --- a/scio-extra/src/main/scala/com/spotify/scio/extra/hll/sketching/SketchHllPlusPlus.scala +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/hll/sketching/SketchHllPlusPlus.scala @@ -61,7 +61,7 @@ case class SketchHllPlusPlus[T](p: Int, sp: Int) extends ApproxDistinctCounter[T override def estimateDistinctCountPerKey[K]( in: SCollection[(K, T)] ): SCollection[(K, Long)] = { - implicit val (keyCoder, _): (Coder[K], Coder[T]) = BeamCoders.getTupleCoders(in) + implicit val keyCoder: Coder[K] = BeamCoders.getTupleCoders(in)._1 in.toKV .applyTransform( diff --git a/scio-test/src/main/scala/com/spotify/scio/testing/BigtableMatchers.scala b/scio-test/src/main/scala/com/spotify/scio/testing/BigtableMatchers.scala index 610f5fdfab..e057efef30 100644 --- a/scio-test/src/main/scala/com/spotify/scio/testing/BigtableMatchers.scala +++ b/scio-test/src/main/scala/com/spotify/scio/testing/BigtableMatchers.scala @@ -21,7 +21,6 @@ import com.google.bigtable.v2.Mutation import com.google.bigtable.v2.Mutation.MutationCase import com.google.protobuf.ByteString import com.spotify.scio.values.SCollection -import com.spotify.scio.coders.Coder import org.scalatest.matchers.{MatchResult, Matcher} @@ -34,10 +33,6 @@ trait BigtableMatchers extends SCollectionMatchers { type BTRow = (ByteString, Iterable[Mutation]) type BTCollection = SCollection[BTRow] - // Needed because scalac is an idiot - implicit def btCollCoder: Coder[BTRow] = - Coder.gen[(ByteString, Iterable[Mutation])] - /** Provide an implicit BT serializer for common cell value type String. 
*/ implicit def stringBTSerializer(s: String): ByteString = ByteString.copyFromUtf8(s) diff --git a/scio-test/src/test/scala/com/spotify/scio/values/PairHashSCollectionFunctionsTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/PairHashSCollectionFunctionsTest.scala index 53096f6b90..560266806c 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/PairHashSCollectionFunctionsTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/PairHashSCollectionFunctionsTest.scala @@ -77,9 +77,22 @@ class PairHashSCollectionFunctionsTest extends PipelineSpec { runWithContext { sc => val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3))) val p2 = sc.parallelize[(String, Int)](Map.empty).asMultiMapSideInput + val p = p1.hashJoin(p2) + p should containInAnyOrder(Seq.empty[(String, (Int, Int))]) + } + } + + it should "support hashJoin() with transformed .asMultiMapSideInput" in { + runWithContext { sc => + val p1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3))) + val p2 = sc + .parallelize(Seq(("a", "11"), ("b", "12"), ("b", "13"))) + .asMultiMapSideInput + .map(_.map { case (k, vs) => k -> vs.map(_.toInt) }) + val p = p1.hashJoin(p2) p should - containInAnyOrder(Seq.empty[(String, (Int, Int))]) + containInAnyOrder(Seq(("a", (1, 11)), ("a", (2, 11)), ("b", (3, 12)), ("b", (3, 13)))) } } diff --git a/scio-test/src/test/scala/com/spotify/scio/values/PairSCollectionFunctionsTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/PairSCollectionFunctionsTest.scala index 33f3b66076..0e88697c51 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/PairSCollectionFunctionsTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/PairSCollectionFunctionsTest.scala @@ -17,29 +17,68 @@ package com.spotify.scio.values -import com.spotify.scio.coders.Beam +import com.spotify.scio.coders.{Beam, MaterializedCoder} import com.spotify.scio.testing.PipelineSpec import com.spotify.scio.util.random.RandomSamplerUtils import 
com.spotify.scio.hash._ import com.spotify.scio.options.ScioOptions import com.twitter.algebird.Aggregator import magnolify.guava.auto._ -import org.apache.beam.sdk.coders.{StringUtf8Coder, VarIntCoder} +import org.apache.beam.sdk.coders.{NullableCoder, StringUtf8Coder, StructuredCoder, VarIntCoder} import scala.collection.mutable class PairSCollectionFunctionsTest extends PipelineSpec { - "PairSCollection" should "propagate unwrapped coders" in { + "PairSCollection" should "propagate unwrapped coders" in { + runWithContext { sc => + val coll = sc.empty[(String, Int)]() + // internal is wrapped + val internalCoder = coll.internal.getCoder + internalCoder shouldBe a[MaterializedCoder[_]] + val materializedCoder = internalCoder.asInstanceOf[MaterializedCoder[_]] + materializedCoder.bcoder shouldBe a[StructuredCoder[_]] + val tupleCoder = materializedCoder.bcoder.asInstanceOf[StructuredCoder[_]] + val keyCoder = tupleCoder.getComponents.get(0) + keyCoder shouldBe StringUtf8Coder.of() + val valueCoder = tupleCoder.getComponents.get(1) + valueCoder shouldBe VarIntCoder.of() + // implicit SCollection key and value coder aren't + coll.keyCoder shouldBe a[Beam[_]] + val beamKeyCoder = coll.keyCoder.asInstanceOf[Beam[_]] + beamKeyCoder.beam shouldBe StringUtf8Coder.of() + + coll.valueCoder shouldBe a[Beam[_]] + val beamValueCoder = coll.valueCoder.asInstanceOf[Beam[_]] + beamValueCoder.beam shouldBe VarIntCoder.of() + } + } + + it should "propagate unwrapped nullable coders" in { runWithContext { sc => sc.optionsAs[ScioOptions].setNullableCoders(true) val coll = sc.empty[(String, Int)]() - coll.keyCoder shouldBe a[Beam[String]] - // No WrappedCoder nor NullableCoder - coll.keyCoder.asInstanceOf[Beam[String]].beam shouldBe StringUtf8Coder.of() - - coll.valueCoder shouldBe a[Beam[Int]] - coll.valueCoder.asInstanceOf[Beam[Int]].beam shouldBe VarIntCoder.of() + // internal is wrapped + val internalCoder = coll.internal.getCoder + internalCoder shouldBe a[MaterializedCoder[_]]
+ val materializedCoder = internalCoder.asInstanceOf[MaterializedCoder[_]] + materializedCoder.bcoder shouldBe a[NullableCoder[_]] + val nullableTupleCoder = materializedCoder.bcoder.asInstanceOf[NullableCoder[_]] + val tupleCoder = nullableTupleCoder.getValueCoder.asInstanceOf[StructuredCoder[_]] + val keyCoder = tupleCoder.getComponents.get(0) + keyCoder shouldBe a[NullableCoder[_]] + keyCoder.asInstanceOf[NullableCoder[_]].getValueCoder shouldBe StringUtf8Coder.of() + val valueCoder = tupleCoder.getComponents.get(1) + valueCoder shouldBe a[NullableCoder[_]] + valueCoder.asInstanceOf[NullableCoder[_]].getValueCoder shouldBe VarIntCoder.of() + // implicit SCollection key and value coder aren't + coll.keyCoder shouldBe a[Beam[_]] + val beamKeyCoder = coll.keyCoder.asInstanceOf[Beam[_]] + beamKeyCoder.beam shouldBe StringUtf8Coder.of() + + coll.valueCoder shouldBe a[Beam[_]] + val beamValueCoder = coll.valueCoder.asInstanceOf[Beam[_]] + beamValueCoder.beam shouldBe VarIntCoder.of() } } diff --git a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionTest.scala b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionTest.scala index cc22c44e7e..89cc4ca5c0 100644 --- a/scio-test/src/test/scala/com/spotify/scio/values/SCollectionTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/values/SCollectionTest.scala @@ -38,10 +38,10 @@ import org.joda.time.{DateTimeConstants, Duration, Instant} import scala.collection.mutable import scala.jdk.CollectionConverters._ -import com.spotify.scio.coders.{Beam, Coder} +import com.spotify.scio.coders.{Beam, Coder, MaterializedCoder} import com.spotify.scio.options.ScioOptions import com.spotify.scio.schemas.Schema -import org.apache.beam.sdk.coders.StringUtf8Coder +import org.apache.beam.sdk.coders.{NullableCoder, StringUtf8Coder} import java.nio.charset.StandardCharsets @@ -54,14 +54,39 @@ class SCollectionTest extends PipelineSpec { import SCollectionTest._ - "SCollection" should "propagate unwrapped coders" in {
+ "SCollection" should "propagate unwrapped coders" in { + runWithContext { sc => + val coll = sc.empty[String]() + // internal is wrapped + val internalCoder = coll.internal.getCoder + internalCoder shouldBe a[MaterializedCoder[_]] + val materializedCoder = internalCoder.asInstanceOf[MaterializedCoder[_]] + materializedCoder.bcoder shouldBe StringUtf8Coder.of() + // implicit SCollection coder is not + val scioCoder = coll.coder + scioCoder shouldBe a[Beam[_]] + val beamCoder = scioCoder.asInstanceOf[Beam[_]] + beamCoder.beam shouldBe StringUtf8Coder.of() + } + } + + it should "propagate unwrapped nullable coders" in { runWithContext { sc => sc.optionsAs[ScioOptions].setNullableCoders(true) val coll = sc.empty[String]() - coll.coder shouldBe a[Beam[String]] - // No WrappedCoder nor NullableCoder - coll.coder.asInstanceOf[Beam[String]].beam shouldBe StringUtf8Coder.of() + // internal is wrapped + val internalCoder = coll.internal.getCoder + internalCoder shouldBe a[MaterializedCoder[_]] + val materializedCoder = internalCoder.asInstanceOf[MaterializedCoder[_]] + materializedCoder.bcoder shouldBe a[NullableCoder[_]] + val nullableCoder = materializedCoder.bcoder.asInstanceOf[NullableCoder[_]] + nullableCoder.getValueCoder shouldBe StringUtf8Coder.of() + // implicit SCollection coder is not + val scioCoder = coll.coder + scioCoder shouldBe a[Beam[_]] + val beamCoder = scioCoder.asInstanceOf[Beam[_]] + beamCoder.beam shouldBe StringUtf8Coder.of() } }