From cfb9bbcf630165f8f9c0d8754730f954c06a31ac Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Tue, 12 Mar 2024 11:16:15 -0400
Subject: [PATCH 1/5] Allow String key type to transform SMB sources with CharSequence key

---
 .../sdk/extensions/smb/SortedBucketIO.java    |  4 +-
 .../scio/smb/SmbVersionParityTest.scala       | 88 ++++++++++++++++++-
 2 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
index 4daa45b3d4..1324fd9dd5 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
@@ -388,7 +388,7 @@ private Coder<K1> keyTypeCoder() {
     Optional<Coder<K1>> c =
         inputs.stream()
             .flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-            .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+            .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
            .findFirst()
             .map(sm -> (Coder<K1>) sm.metadata.getKeyCoder());
     if (!c.isPresent())
@@ -454,7 +454,7 @@ private Coder<K2> keyCoderSecondary() {
             .filter(
                 sm ->
                     sm.metadata.getKeyClassSecondary() != null
-                        && sm.metadata.getKeyClassSecondary() == keyClassSecondary
+                        && sm.metadata.keyClassSecondaryMatches(keyClassSecondary)
                         && sm.metadata.getKeyCoderSecondary() != null)
             .findFirst()
             .map(sm -> (Coder<K2>) sm.metadata.getKeyCoderSecondary());
diff --git a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
index 71c46bd0d6..f213c07c26 100644
--- a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
+++ b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
@@ -30,7 +30,7 @@ import org.apache.beam.sdk.values.TupleTag
 import java.nio.file.Files
 
 class SmbVersionParityTest extends PipelineSpec {
-  private def testRoundtrip(
+  private def testReadRoundtrip(
     write: SortedBucketIO.Write[CharSequence, _, Account],
     read: SortedBucketIO.Read[Account]
   ): Unit = {
@@ -64,11 +64,49 @@ class SmbVersionParityTest extends PipelineSpec {
       .toSeq should contain theSameElementsAs accounts
   }
 
+  private def testTransformRoundtrip(
+    write: SortedBucketIO.Write[CharSequence, Void, Account],
+    read: SortedBucketIO.Read[Account],
+    transform: SortedBucketIO.TransformOutput[String, Void, Account]
+  ): Unit = {
+    val accounts = (1 to 10).map { i =>
+      Account
+        .newBuilder()
+        .setId(i)
+        .setName(i.toString)
+        .setAmount(i.toDouble)
+        .setType(s"type$i")
+        .setAccountStatus(AccountStatus.Active)
+        .build()
+    }
+
+    {
+      val sc = ScioContext()
+      sc.parallelize(accounts)
+        .saveAsSortedBucket(write)
+      sc.run()
+    }
+
+    // Read data
+    val sc = ScioContext()
+    val tap = sc
+      .sortMergeTransform(classOf[String], read)
+      .to(transform)
+      .via { case (_, records, outputCollector) =>
+        records.foreach(outputCollector.accept(_))
+      }
+
+    tap
+      .get(sc.run().waitUntilDone())
+      .value
+      .toSeq should contain theSameElementsAs accounts
+  }
+
   "SortedBucketSource" should "be able to read CharSequence-keyed Avro sources written before 0.14" in {
     val output = Files.createTempDirectory("smb-version-test-avro").toFile
     output.deleteOnExit()
 
-    testRoundtrip(
+    testReadRoundtrip(
       AvroSortedBucketIO
         .write(classOf[CharSequence], "name", classOf[Account])
         .to(output.getAbsolutePath)
         .withNumBuckets(1)
         .withNumShards(1),
       AvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
         .from(output.getAbsolutePath)
     )
   }
 
+  it should "be able to transform CharSequence-keyed Avro sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-avro-tfx").toFile
+    tmpDir.deleteOnExit()
+
+    val writeOutput = tmpDir.toPath.resolve("write")
+    val tfxOutput = tmpDir.toPath.resolve("tfx")
+
+    testTransformRoundtrip(
+      AvroSortedBucketIO
+        .write(classOf[CharSequence], "name", classOf[Account])
+        .to(writeOutput.toString)
+        .withNumBuckets(1)
+        .withNumShards(1),
+      AvroSortedBucketIO
+        .read(new TupleTag[Account], classOf[Account])
+        .from(writeOutput.toString),
+      AvroSortedBucketIO
+        .transformOutput(classOf[String], "name", classOf[Account])
+        .to(tfxOutput.toString)
+    )
+  }
+
   it should "be able to read CharSequence-keyed Parquet sources written before 0.14" in {
     val output = Files.createTempDirectory("smb-version-test-parquet").toFile
     output.deleteOnExit()
 
-    testRoundtrip(
+    testReadRoundtrip(
       ParquetAvroSortedBucketIO
         .write(classOf[CharSequence], "name", classOf[Account])
         .to(output.getAbsolutePath)
         .withNumBuckets(1)
         .withNumShards(1),
       ParquetAvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
         .from(output.getAbsolutePath)
     )
   }
+
+  it should "be able to transform CharSequence-keyed Parquet sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-parquet-tfx").toFile
+    tmpDir.deleteOnExit()
+
+    val writeOutput = tmpDir.toPath.resolve("write")
+    val tfxOutput = tmpDir.toPath.resolve("tfx")
+
+    testTransformRoundtrip(
+      ParquetAvroSortedBucketIO
+        .write(classOf[CharSequence], "name", classOf[Account])
+        .to(writeOutput.toString)
+        .withNumBuckets(1)
+        .withNumShards(1),
+      ParquetAvroSortedBucketIO
+        .read(new TupleTag[Account], classOf[Account])
+        .from(writeOutput.toString),
+      ParquetAvroSortedBucketIO
+        .transformOutput(classOf[String], "name", classOf[Account])
+        .to(tfxOutput.toString)
+    )
+  }
 }

From 67ac03fc39fe7e9c5787e4fba31185eab6b33573 Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Tue, 12 Mar 2024 12:21:17 -0400
Subject: [PATCH 2/5] Support mixed key partitions

---
 .../extensions/smb/AvroBucketMetadata.java    |  21 ++-
 .../beam/sdk/extensions/smb/AvroUtils.java    |   8 ++
 .../sdk/extensions/smb/BucketMetadata.java    |   2 +-
 .../extensions/smb/ParquetBucketMetadata.java |  21 ++-
 .../sdk/extensions/smb/SortedBucketIO.java    |   2 +-
 .../scio/smb/SmbVersionParityTest.scala       | 122 ++++++++++++------
 6 files changed, 108 insertions(+), 68 deletions(-)

diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
index 2be69b9cd7..5c20e1ec33 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadata.java
@@ -134,12 +134,13 @@ public Map<Class<?>, Coder<?>> coderOverrides() {
 
   @Override
   int hashPrimaryKeyMetadata() {
-    return Objects.hash(keyField, getKeyClass());
+    return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
   }
 
   @Override
   int hashSecondaryKeyMetadata() {
-    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+    return Objects.hash(
+        keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
   }
 
   @Override
@@ -194,19 +195,15 @@ public void populateDisplayData(Builder builder) {
 
   @Override
   boolean keyClassMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassMatches(requestedReadType);
-    }
+    return super.keyClassMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
   }
 
   @Override
   boolean keyClassSecondaryMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassSecondaryMatches(requestedReadType);
-    }
+    return super.keyClassSecondaryMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
   }
 }
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
index eee1fd850e..4bfcda12f4 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroUtils.java
@@ -166,6 +166,14 @@ public static Map<Class<?>, Coder<?>> coderOverrides() {
         CharSequence.class, CharSequenceCoder.of());
   }
 
+  static Class<?> castToComparableStringClass(Class<?> cls) {
+    if (cls == String.class) {
+      return CharSequence.class;
+    } else {
+      return cls;
+    }
+  }
+
   private static class ByteBufferCoder extends AtomicCoder<ByteBuffer> {
     private static final ByteBufferCoder INSTANCE = new ByteBufferCoder();
 
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
index 74f619b066..7b31833700 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
@@ -405,7 +405,7 @@ private boolean isIntraPartitionCompatibleWith(
 
     if (compatibleTypes.isEmpty() && other.getClass() != this.getClass()) {
       return false;
-    } else if (this.getKeyClass() != other.getKeyClass()
+    } else if (!this.keyClassMatches(other.getKeyClass())
         && !(compatibleTypes.contains(otherClass)
             && (other.compatibleMetadataTypes().contains(this.getClass())))) {
       return false;
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
index 2016c1258c..b23f495223 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
@@ -149,12 +149,13 @@ public void populateDisplayData(DisplayData.Builder builder) {
 
   @Override
   int hashPrimaryKeyMetadata() {
-    return Objects.hash(keyField, getKeyClass());
+    return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
   }
 
   @Override
   int hashSecondaryKeyMetadata() {
-    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+    return Objects.hash(
+        keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
   }
 
   @Override
@@ -230,20 +231,16 @@ static <K> K extractKey(Method[] keyGetters, Object value) {
 
   @Override
   boolean keyClassMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassMatches(requestedReadType);
-    }
+    return super.keyClassMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
   }
 
   @Override
   boolean keyClassSecondaryMatches(Class<?> requestedReadType) {
-    if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-      return true;
-    } else {
-      return super.keyClassSecondaryMatches(requestedReadType);
-    }
+    return super.keyClassSecondaryMatches(requestedReadType)
+        || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+        || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
   }
 
   ////////////////////////////////////////////////////////////////////////////////
diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
index 1324fd9dd5..5ab2e6a0c1 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java
@@ -436,7 +436,7 @@ private Coder<K1> keyCoderPrimary() {
     Optional<Coder<K1>> c =
         inputs.stream()
             .flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-            .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+            .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
             .findFirst()
             .map(sm -> (Coder<K1>) sm.metadata.getKeyCoder());
     if (!c.isPresent())
diff --git a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
index f213c07c26..f36c21b876 100644
--- a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
+++ b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
@@ -31,7 +31,7 @@ import java.nio.file.Files
 
 class SmbVersionParityTest extends PipelineSpec {
   private def testReadRoundtrip(
-    write: SortedBucketIO.Write[CharSequence, _, Account],
+    writes: Seq[SortedBucketIO.Write[_ <: CharSequence, _, Account]],
     read: SortedBucketIO.Read[Account]
   ): Unit = {
     val accounts = (1 to 10).map { i =>
@@ -47,9 +47,10 @@ class SmbVersionParityTest extends PipelineSpec {
 
     {
       val sc = ScioContext()
-      sc.parallelize(accounts)
-        .saveAsSortedBucket(write)
+      val records = sc.parallelize(accounts)
+      writes.foreach(records.saveAsSortedBucket(_))
       sc.run()
+      ()
     }
 
     // Read data
@@ -61,11 +62,11 @@ class SmbVersionParityTest extends PipelineSpec {
       .get(sc.run().waitUntilDone())
       .flatMap(_._2)
       .value
-      .toSeq should contain theSameElementsAs accounts
+      .toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
   }
 
   private def testTransformRoundtrip(
-    write: SortedBucketIO.Write[CharSequence, Void, Account],
+    writes: Seq[SortedBucketIO.Write[_ <: CharSequence, Void, Account]],
     read: SortedBucketIO.Read[Account],
     transform: SortedBucketIO.TransformOutput[String, Void, Account]
   ): Unit = {
@@ -82,9 +83,10 @@ class SmbVersionParityTest extends PipelineSpec {
 
     {
       val sc = ScioContext()
-      sc.parallelize(accounts)
-        .saveAsSortedBucket(write)
+      val records = sc.parallelize(accounts)
+      writes.foreach(records.saveAsSortedBucket(_))
       sc.run()
+      ()
     }
 
     // Read data
@@ -99,79 +101,115 @@ class SmbVersionParityTest extends PipelineSpec {
     tap
       .get(sc.run().waitUntilDone())
       .value
-      .toSeq should contain theSameElementsAs accounts
+      .toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
   }
 
-  "SortedBucketSource" should "be able to read CharSequence-keyed Avro sources written before 0.14" in {
-    val output = Files.createTempDirectory("smb-version-test-avro").toFile
-    output.deleteOnExit()
+  "SortedBucketSource" should "be able to read mixed CharSequence- and String-keyed Avro sources" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-mixed-avro-read").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
 
     testReadRoundtrip(
-      AvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(output.getAbsolutePath)
-        .withNumBuckets(1)
-        .withNumShards(1),
+      Seq(
+        AvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        AvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       AvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(output.getAbsolutePath)
+        .from(partition1Output.toString, partition2Output.toString)
     )
   }
 
-  it should "be able to transform CharSequence-keyed Avro sources written before 0.14" in {
+  it should "be able to transform mixed CharSequence- and String-keyed Avro sources written before 0.14" in {
     val tmpDir = Files.createTempDirectory("smb-version-test-avro-tfx").toFile
     tmpDir.deleteOnExit()
 
-    val writeOutput = tmpDir.toPath.resolve("write")
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
     val tfxOutput = tmpDir.toPath.resolve("tfx")
 
     testTransformRoundtrip(
-      AvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(writeOutput.toString)
-        .withNumBuckets(1)
-        .withNumShards(1),
+      Seq(
+        AvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        AvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       AvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(writeOutput.toString),
+        .from(partition1Output.toString, partition2Output.toString),
       AvroSortedBucketIO
         .transformOutput(classOf[String], "name", classOf[Account])
        .to(tfxOutput.toString)
     )
   }
 
-  it should "be able to read CharSequence-keyed Parquet sources written before 0.14" in {
-    val output = Files.createTempDirectory("smb-version-test-parquet").toFile
-    output.deleteOnExit()
+  it should "be able to read mixed CharSequence- and String-keyed Parquet sources written before 0.14" in {
+    val tmpDir = Files.createTempDirectory("smb-version-test-mixed-parquet-read").toFile
+    tmpDir.deleteOnExit()
+
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
 
     testReadRoundtrip(
-      ParquetAvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(output.getAbsolutePath)
-        .withNumBuckets(1)
-        .withNumShards(1),
+      Seq(
+        ParquetAvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        ParquetAvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       ParquetAvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(output.getAbsolutePath)
+        .from(partition1Output.toString, partition2Output.toString)
     )
   }
 
-  it should "be able to transform CharSequence-keyed Parquet sources written before 0.14" in {
+  it should "be able to transform mixed CharSequence- and String-keyed Parquet sources written before 0.14" in {
     val tmpDir = Files.createTempDirectory("smb-version-test-parquet-tfx").toFile
     tmpDir.deleteOnExit()
 
-    val writeOutput = tmpDir.toPath.resolve("write")
+    val partition1Output = tmpDir.toPath.resolve("partition1")
+    val partition2Output = tmpDir.toPath.resolve("partition2")
     val tfxOutput = tmpDir.toPath.resolve("tfx")
 
     testTransformRoundtrip(
-      ParquetAvroSortedBucketIO
-        .write(classOf[CharSequence], "name", classOf[Account])
-        .to(writeOutput.toString)
-        .withNumBuckets(1)
-        .withNumShards(1),
+      Seq(
+        ParquetAvroSortedBucketIO
+          .write(classOf[CharSequence], "name", classOf[Account])
+          .to(partition1Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1),
+        ParquetAvroSortedBucketIO
+          .write(classOf[String], "name", classOf[Account])
+          .to(partition2Output.toString)
+          .withNumBuckets(1)
+          .withNumShards(1)
+      ),
       ParquetAvroSortedBucketIO
         .read(new TupleTag[Account], classOf[Account])
-        .from(writeOutput.toString),
+        .from(partition1Output.toString, partition2Output.toString),
       ParquetAvroSortedBucketIO
         .transformOutput(classOf[String], "name", classOf[Account])
         .to(tfxOutput.toString)
     )
   }
 }

From 7b69d9d40aa8d9f847214bc8832ab1483076aac5 Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Tue, 12 Mar 2024 12:28:26 -0400
Subject: [PATCH 3/5] cleanup assertion

---
 .../apache/beam/sdk/extensions/smb/BucketMetadata.java | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
index 7b31833700..446c4c8d13 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
@@ -403,11 +403,9 @@ private boolean isIntraPartitionCompatibleWith(
     final Class<?> otherClass = other.getClass();
     final Set<Class<? extends BucketMetadata>> compatibleTypes = compatibleMetadataTypes();
 
-    if (compatibleTypes.isEmpty() && other.getClass() != this.getClass()) {
-      return false;
-    } else if (!this.keyClassMatches(other.getKeyClass())
-        && !(compatibleTypes.contains(otherClass)
-            && (other.compatibleMetadataTypes().contains(this.getClass())))) {
+    if ((other.getClass() != this.getClass())
+        && (!compatibleTypes.contains(otherClass)
+            || !other.compatibleMetadataTypes().contains(this.getClass()))) {
       return false;
     }

From 2bd06ece221c85e7e0b1976cf1fc9f91fc3743b8 Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Tue, 12 Mar 2024 12:30:04 -0400
Subject: [PATCH 4/5] cleanup assertion

---
 .../org/apache/beam/sdk/extensions/smb/BucketMetadata.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
index 446c4c8d13..5c634f94fd 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/BucketMetadata.java
@@ -404,8 +404,8 @@ private boolean isIntraPartitionCompatibleWith(
     final Set<Class<? extends BucketMetadata>> compatibleTypes = compatibleMetadataTypes();
 
     if ((other.getClass() != this.getClass())
-        && (!compatibleTypes.contains(otherClass)
-            || !other.compatibleMetadataTypes().contains(this.getClass()))) {
+        && !(compatibleTypes.contains(otherClass)
+            || other.compatibleMetadataTypes().contains(this.getClass()))) {
       return false;
     }

From 7b5abad5580def239769a49f7859d4eaaeb5585f Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Tue, 12 Mar 2024 12:33:40 -0400
Subject: [PATCH 5/5] Add test for Avro/Parquet compat

---
 .../smb/AvroBucketMetadataTest.java           | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
index 5813537657..2f2492d076 100644
--- a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
+++ b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/AvroBucketMetadataTest.java
@@ -363,6 +363,51 @@ public void testSameSourceCompatibility() throws Exception {
     Assert.assertFalse(metadata4.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata5));
   }
 
+  @Test
+  public void testParquetMetadataCompatibility() throws Exception {
+    final AvroBucketMetadata<String, Integer, AvroGeneratedUser> metadata1 =
+        new AvroBucketMetadata<>(
+            2,
+            1,
+            String.class,
+            "name",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    final ParquetBucketMetadata<String, Integer, AvroGeneratedUser> metadata2 =
+        new ParquetBucketMetadata<>(
+            2,
+            1,
+            String.class,
+            "favorite_color",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    final ParquetBucketMetadata<String, Integer, AvroGeneratedUser> metadata3 =
+        new ParquetBucketMetadata<>(
+            4,
+            1,
+            String.class,
+            "favorite_color",
+            Integer.class,
+            "favorite_number",
+            HashType.MURMUR3_32,
+            SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+            AvroGeneratedUser.SCHEMA$);
+
+    Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryKey(metadata2));
+    Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata2));
+
+    Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryKey(metadata3));
+    Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata3));
+  }
+
   @Test
   public void testKeyTypeCheckingBytes()
       throws CannotProvideCoderException, NonDeterministicException {