diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
index 2be69b9cd7..5c20e1ec33 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
@@ -134,12 +134,13 @@ public Map<Class<?>, Coder<?>> coderOverrides() {
 
   @Override
   int hashPrimaryKeyMetadata() {
-    return Objects.hash(keyField, getKeyClass());
+    return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
   }
 
   @Override
   int hashSecondaryKeyMetadata() {
-    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+    return Objects.hash(
+        keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
   }
 
   @Override
@@ -194,19 +195,15 @@ public void populateDisplayData(Builder builder) {
 
   @Override
   boolean keyClassMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassMatches(requestedReadType);
-    }
+    return super.keyClassMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
   }
 
   @Override
   boolean keyClassSecondaryMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassSecondaryMatches(requestedReadType);
-    }
+    return super.keyClassSecondaryMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
   }
 }
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
index eee1fd850e..4bfcda12f4 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
@@ -166,6 +166,14 @@ public static Map<Class<?>, Coder<?>> coderOverrides() {
         CharSequence.class, CharSequenceCoder.of());
   }
 
+  static Class<?> castToComparableStringClass(Class<?> cls) {
+    if (cls == String.class) {
+      return CharSequence.class;
+    } else {
+      return cls;
+    }
+  }
+
   private static class ByteBufferCoder extends AtomicCoder<ByteBuffer> {
 
     private static final ByteBufferCoder INSTANCE = new ByteBufferCoder();
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
index 74f619b066..5c634f94fd 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
@@ -403,11 +403,9 @@ private boolean isIntraPartitionCompatibleWit
     final Class<?> otherClass = other.getClass();
     final Set<Class<? extends BucketMetadata>> compatibleTypes = compatibleMetadataTypes();
 
-    if (compatibleTypes.isEmpty() && other.getClass() != this.getClass()) {
-      return false;
-    } else if (this.getKeyClass() != other.getKeyClass()
+    if ((other.getClass() != this.getClass())
         && !(compatibleTypes.contains(otherClass)
-            && (other.compatibleMetadataTypes().contains(this.getClass())))) {
+            || other.compatibleMetadataTypes().contains(this.getClass()))) {
       return false;
     }
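The String-to-CharSequence normalization above is easiest to see in isolation. Below is a minimal standalone Java sketch (hypothetical class name; only the castToComparableStringClass body is taken from the AvroUtils hunk above) of why the primary-key metadata hashes now agree across the two key-class spellings:

    // Sketch: both key-class spellings normalize to the same Class token,
    // so Objects.hash(keyField, normalizedKeyClass) is identical whether the
    // metadata was written with String or CharSequence keys.
    import java.util.Objects;

    final class KeyNormalizationSketch {
      // Mirrors AvroUtils.castToComparableStringClass from the hunk above.
      static Class<?> castToComparableStringClass(Class<?> cls) {
        return cls == String.class ? CharSequence.class : cls;
      }

      public static void main(String[] args) {
        int h1 = Objects.hash("name", castToComparableStringClass(String.class));
        int h2 = Objects.hash("name", castToComparableStringClass(CharSequence.class));
        System.out.println(h1 == h2); // true: the two metadata hashes now match
      }
    }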
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
index 2016c1258c..b23f495223 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
@@ -149,12 +149,13 @@ public void populateDisplayData(DisplayData.Builder builder) {
 
   @Override
   int hashPrimaryKeyMetadata() {
-    return Objects.hash(keyField, getKeyClass());
+    return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
   }
 
   @Override
   int hashSecondaryKeyMetadata() {
-    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+    return Objects.hash(
+        keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
   }
 
   @Override
@@ -230,20 +231,16 @@ static <K> K extractKey(Method[] keyGetters, Object value) {
 
   @Override
   boolean keyClassMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassMatches(requestedReadType);
-    }
+    return super.keyClassMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
   }
 
   @Override
   boolean keyClassSecondaryMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassSecondaryMatches(requestedReadType);
-    }
+    return super.keyClassSecondaryMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
   }
 
   ////////////////////////////////////////////////////////////////////////////////
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
index 4daa45b3d4..5ab2e6a0c1 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
@@ -388,7 +388,7 @@ private Coder<K> keyTypeCoder() {
       Optional<Coder<K>> c =
           inputs.stream()
               .flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-              .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+              .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
               .findFirst()
               .map(sm -> (Coder<K>) sm.metadata.getKeyCoder());
       if (!c.isPresent())
@@ -436,7 +436,7 @@ private Coder<K1> keyCoderPrimary() {
       Optional<Coder<K1>> c =
           inputs.stream()
               .flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-              .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+              .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
              .findFirst()
              .map(sm -> (Coder<K1>) sm.metadata.getKeyCoder());
       if (!c.isPresent())
@@ -454,7 +454,7 @@
               .filter(
                   sm ->
                       sm.metadata.getKeyClassSecondary() != null
-                          && sm.metadata.getKeyClassSecondary() == keyClassSecondary
+                          && sm.metadata.keyClassSecondaryMatches(keyClassSecondary)
                           && sm.metadata.getKeyCoderSecondary() != null)
               .findFirst()
               .map(sm -> (Coder<K2>) sm.metadata.getKeyCoderSecondary());
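The widened keyClassMatches/keyClassSecondaryMatches predicate is symmetric: a CharSequence-keyed source now satisfies a String read request and vice versa, which is what lets the SortedBucketIO coder lookups above switch from == on the key class to keyClassMatches. A standalone sketch of that predicate (hypothetical class; the super call is stubbed as plain reference equality here):

    // Sketch: exact match, or match after normalizing String -> CharSequence
    // in either direction. Previously only CharSequence metadata matched a
    // String request, not the reverse.
    final class KeyClassMatchSketch {
      static Class<?> normalize(Class<?> cls) {
        return cls == String.class ? CharSequence.class : cls;
      }

      // Stand-in for BucketMetadata.keyClassMatches.
      static boolean keyClassMatches(Class<?> keyClass, Class<?> requested) {
        return keyClass == requested
            || normalize(keyClass) == requested
            || normalize(requested) == keyClass;
      }

      public static void main(String[] args) {
        System.out.println(keyClassMatches(CharSequence.class, String.class)); // true (old behavior)
        System.out.println(keyClassMatches(String.class, CharSequence.class)); // true (new direction)
        System.out.println(keyClassMatches(String.class, Integer.class)); // false
      }
    }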
diff --git a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
index 5813537657..2f2492d076 100644
--- a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
+++ b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
@@ -363,6 +363,51 @@ public void testSameSourceCompatibility() throws Exception {
     Assert.assertFalse(metadata4.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata5));
   }
 
+  @Test
+  public void testParquetMetadataCompatibility() throws Exception {
+    final AvroBucketMetadata<String, Integer, GenericRecord> metadata1 =
+        new AvroBucketMetadata<>(
+            2,
+            1,
+            String.class,
+            "name",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    final ParquetBucketMetadata<String, Integer, GenericRecord> metadata2 =
+        new ParquetBucketMetadata<>(
+            2,
+            1,
+            String.class,
+            "favorite_color",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    final ParquetBucketMetadata<String, Integer, GenericRecord> metadata3 =
+        new ParquetBucketMetadata<>(
+            4,
+            1,
+            String.class,
+            "favorite_color",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryKey(metadata2));
+    Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata2));
+
+    Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryKey(metadata3));
+    Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata3));
+  }
+
   @Test
   public void testKeyTypeCheckingBytes()
       throws CannotProvideCoderException, NonDeterministicException {
diff --git a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
index 71c46bd0d6..f36c21b876 100644
--- a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
+++ b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
@@ -30,8 +30,8 @@ import org.apache.beam.sdk.values.TupleTag
 import java.nio.file.Files
 
 class SmbVersionParityTest extends PipelineSpec {
-  private def testRoundtrip(
-    write: SortedBucketIO.Write[CharSequence, _, Account],
+  private def testReadRoundtrip(
+    writes: Seq[SortedBucketIO.Write[_ <: CharSequence, _, Account]],
     read: SortedBucketIO.Read[Account]
   ): Unit = {
     val accounts = (1 to 10).map { i =>
@@ -47,9 +47,10 @@
 
     {
       val sc = ScioContext()
-      sc.parallelize(accounts)
-        .saveAsSortedBucket(write)
+      val records = sc.parallelize(accounts)
+      writes.foreach(records.saveAsSortedBucket(_))
       sc.run()
+      ()
     }
 
     // Read data
@@ -61,38 +62,157 @@
       .get(sc.run().waitUntilDone())
       .flatMap(_._2)
       .value
-      .toSeq should contain theSameElementsAs accounts
+      .toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
   }
 
-  "SortedBucketSource" should "be able to read CharSequence-keyed Avro sources written before 0.14" in {
-    val output = Files.createTempDirectory("smb-version-test-avro").toFile
-    output.deleteOnExit()
+  private def testTransformRoundtrip(
+    writes: Seq[SortedBucketIO.Write[_ <: CharSequence, Void, Account]],
+    read: SortedBucketIO.Read[Account],
+    transform: SortedBucketIO.TransformOutput[String, Void, Account]
+  ): Unit = {
+    val accounts = (1 to 10).map { i =>
+      Account
+        .newBuilder()
+        .setId(i)
+        .setName(i.toString)
+        .setAmount(i.toDouble)
+        .setType(s"type$i")
+        .setAccountStatus(AccountStatus.Active)
+        .build()
+    }
+
+    {
+      val sc = ScioContext()
+      val records = sc.parallelize(accounts)
+      writes.foreach(records.saveAsSortedBucket(_))
+      sc.run()
+      ()
+    }
+
+    // Read data
+    val sc = ScioContext()
+    val tap = sc
+      .sortMergeTransform(classOf[String], read)
+      .to(transform)
+      .via { case (_, records, outputCollector) =>
+        records.foreach(outputCollector.accept(_))
+      }
+
+    tap
+      .get(sc.run().waitUntilDone())
+      .value
+      .toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
+  }
+
+  "SortedBucketSource" should "be able to read mixed CharSequence- and String-keyed Avro sources" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-mixed-avro-read").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
 
-    testRoundtrip(
+    testReadRoundtrip(
+      Seq(
+        AvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        AvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       AvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(output.getAbsolutePath)
-        .withNumBuckets(1)
-        .withNumShards(1),
+        .read(new TupleTag[Account], classOf[Account])
+        .from(partition1Output.toString, partition2Output.toString)
+    )
+  }
+
+  it should "be able to transform mixed CharSequence- and String-keyed Avro sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-avro-tfx").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
+    val tfxOutput = tmpDir.toPath.resolve("tfx")
+
+    testTransformRoundtrip(
+      Seq(
+        AvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        AvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       AvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(output.getAbsolutePath)
+        .from(partition1Output.toString, partition2Output.toString),
+      AvroSortedBucketIO
+        .transformOutput(classOf[String], "name", classOf[Account])
+        .to(tfxOutput.toString)
     )
   }
 
-  it should "be able to read CharSequence-keyed Parquet sources written before 0.14" in {
-    val output = Files.createTempDirectory("smb-version-test-parquet").toFile
-    output.deleteOnExit()
+  it should "be able to read mixed CharSequence- and String-keyed Parquet sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-mixed-parquet-read").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
 
-    testRoundtrip(
+    testReadRoundtrip(
+      Seq(
+        ParquetAvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        ParquetAvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       ParquetAvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(output.getAbsolutePath)
-        .withNumBuckets(1)
-        .withNumShards(1),
+        .read(new TupleTag[Account], classOf[Account])
+        .from(partition1Output.toString, partition2Output.toString)
+    )
+  }
+
+  it should "be able to transform mixed CharSequence- and String-keyed Parquet sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-parquet-tfx").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
+    val tfxOutput = tmpDir.toPath.resolve("tfx")
+
+    testTransformRoundtrip(
+      Seq(
+        ParquetAvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        ParquetAvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       ParquetAvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(output.getAbsolutePath)
+        .from(partition1Output.toString, partition2Output.toString),
+      ParquetAvroSortedBucketIO
+        .transformOutput(classOf[String], "name", classOf[Account])
+        .to(tfxOutput.toString)
     )
   }
 }
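For context, the end-to-end scenario the new parity tests lock in, sketched against the Java API (the builder calls are those exercised in the diff; the declared Write/Read types, the paths, and the omitted pipeline wiring are assumptions):

    import org.apache.beam.sdk.extensions.smb.AvroSortedBucketIO;
    import org.apache.beam.sdk.extensions.smb.SortedBucketIO;
    import org.apache.beam.sdk.values.TupleTag;
    // Account is the compiled Avro record used by the tests above.

    final class MixedKeyPartitionsSketch {
      static SortedBucketIO.Read<Account> mixedRead() {
        // One partition written pre-0.14 style with CharSequence keys...
        SortedBucketIO.Write<CharSequence, Void, Account> oldStylePartition =
            AvroSortedBucketIO.write(CharSequence.class, "name", Account.class)
                .to("/tmp/smb/partition1") // placeholder path
                .withNumBuckets(1)
                .withNumShards(1);

        // ...and one written with String keys.
        SortedBucketIO.Write<String, Void, Account> newStylePartition =
            AvroSortedBucketIO.write(String.class, "name", Account.class)
                .to("/tmp/smb/partition2") // placeholder path
                .withNumBuckets(1)
                .withNumShards(1);

        // With the metadata fix, a single read can span both partitions,
        // because String- and CharSequence-keyed metadata now hash and
        // match identically.
        return AvroSortedBucketIO.read(new TupleTag<Account>(), Account.class)
            .from("/tmp/smb/partition1", "/tmp/smb/partition2");
      }
    }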