Allow String key type to transform SMB sources with CharSequence key #5297

Merged · 5 commits · Mar 18, 2024
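For context, a minimal sketch of what this change enables, modeled on the tests in this PR (the paths are hypothetical stand-ins, the imports are a best guess at the relevant scio packages, and Account is the Avro test record used below): sources bucketed on a CharSequence key and sources bucketed on a String key can now be read and transformed together under a single requested String key type.

import com.spotify.scio.ScioContext
import com.spotify.scio.avro.Account
import com.spotify.scio.smb._
import org.apache.beam.sdk.extensions.smb.AvroSortedBucketIO
import org.apache.beam.sdk.values.TupleTag

object MixedKeyTransform {
  def main(args: Array[String]): Unit = {
    val sc = ScioContext()
    // One input was written with a CharSequence key, the other with a String
    // key; after this PR both match a transform that requests String.
    sc.sortMergeTransform(
        classOf[String],
        AvroSortedBucketIO
          .read(new TupleTag[Account], classOf[Account])
          .from("/data/charseq-keyed", "/data/string-keyed")
      )
      .to(
        AvroSortedBucketIO
          .transformOutput(classOf[String], "name", classOf[Account])
          .to("/data/output")
      )
      .via { case (_, records, outputCollector) =>
        records.foreach(outputCollector.accept(_))
      }
    sc.run()
  }
}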
@@ -134,12 +134,13 @@ public Map<Class<?>, Coder<?>> coderOverrides() {

@Override
int hashPrimaryKeyMetadata() {
-   return Objects.hash(keyField, getKeyClass());
+   return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
}

@Override
int hashSecondaryKeyMetadata() {
-   return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+   return Objects.hash(
+       keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
}

@Override
@@ -194,19 +195,15 @@ public void populateDisplayData(Builder builder) {

@Override
<OtherKeyType> boolean keyClassMatches(Class<OtherKeyType> requestedReadType) {
-   if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-     return true;
-   } else {
-     return super.keyClassMatches(requestedReadType);
-   }
+   return super.keyClassMatches(requestedReadType)
+       || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+       || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
}

@Override
<OtherKeyType> boolean keyClassSecondaryMatches(Class<OtherKeyType> requestedReadType) {
-   if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-     return true;
-   } else {
-     return super.keyClassSecondaryMatches(requestedReadType);
-   }
+   return super.keyClassSecondaryMatches(requestedReadType)
+       || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+       || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
}
}
@@ -166,6 +166,14 @@ public static Map<Class<?>, Coder<?>> coderOverrides() {
CharSequence.class, CharSequenceCoder.of());
}

+ static Class castToComparableStringClass(Class cls) {
+   if (cls == String.class) {
+     return CharSequence.class;
+   } else {
+     return cls;
+   }
+ }

private static class ByteBufferCoder extends AtomicCoder<ByteBuffer> {
private static final ByteBufferCoder INSTANCE = new ByteBufferCoder();

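The new castToComparableStringClass helper above is the crux of the change: wherever key classes are hashed or compared, String is first normalized to CharSequence, so metadata written with either key type agrees. A standalone restatement of the rule (illustrative only; the real method is package-private on AvroUtils):

// String normalizes to CharSequence; every other key class passes through.
def normalize(cls: Class[_]): Class[_] =
  if (cls == classOf[String]) classOf[CharSequence] else cls

assert(normalize(classOf[String]) == normalize(classOf[CharSequence])) // now interchangeable
assert(normalize(classOf[java.lang.Integer]) == classOf[java.lang.Integer]) // non-string keys unaffected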
@@ -403,11 +403,9 @@ private <MetadataT extends BucketMetadata> boolean isIntraPartitionCompatibleWit
final Class<? extends BucketMetadata> otherClass = other.getClass();
final Set<Class<? extends BucketMetadata>> compatibleTypes = compatibleMetadataTypes();

-   if (compatibleTypes.isEmpty() && other.getClass() != this.getClass()) {
-     return false;
-   } else if (this.getKeyClass() != other.getKeyClass()
+   if ((other.getClass() != this.getClass())
        && !(compatibleTypes.contains(otherClass)
-           && (other.compatibleMetadataTypes().contains(this.getClass())))) {
+           || other.compatibleMetadataTypes().contains(this.getClass()))) {
      return false;
    }

@clairemcginty (Contributor, Author) commented Mar 12, 2024, on the removed check:

    this check was not correct; it should have been comparing the metadata class itself.
    Added an assertion in AvroBucketMetadataTest for this
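Restated, the corrected guard compares the metadata classes themselves and accepts a compatibility declaration from either side, where the old code compared key classes and required both sides to declare each other. A simplified sketch of the fixed predicate (hypothetical standalone parameters, not the scio API):

// Metadata are intra-partition compatible when their classes match, or when
// either side lists the other's class among its compatible metadata types.
def incompatible(
  thisClass: Class[_],
  thisCompatible: Set[Class[_]],
  otherClass: Class[_],
  otherCompatible: Set[Class[_]]
): Boolean =
  (thisClass != otherClass) &&
    !(thisCompatible.contains(otherClass) || otherCompatible.contains(thisClass))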

@@ -149,12 +149,13 @@ public void populateDisplayData(DisplayData.Builder builder) {

@Override
int hashPrimaryKeyMetadata() {
-   return Objects.hash(keyField, getKeyClass());
+   return Objects.hash(keyField, AvroUtils.castToComparableStringClass(getKeyClass()));
}

@Override
int hashSecondaryKeyMetadata() {
-   return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+   return Objects.hash(
+       keyFieldSecondary, AvroUtils.castToComparableStringClass(getKeyClassSecondary()));
}

@Override
@@ -230,20 +231,16 @@ static <K> K extractKey(Method[] keyGetters, Object value) {

@Override
<OtherKeyType> boolean keyClassMatches(Class<OtherKeyType> requestedReadType) {
-   if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
-     return true;
-   } else {
-     return super.keyClassMatches(requestedReadType);
-   }
+   return super.keyClassMatches(requestedReadType)
+       || AvroUtils.castToComparableStringClass(getKeyClass()) == requestedReadType
+       || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClass();
}

@Override
<OtherKeyType> boolean keyClassSecondaryMatches(Class<OtherKeyType> requestedReadType) {
-   if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
-     return true;
-   } else {
-     return super.keyClassSecondaryMatches(requestedReadType);
-   }
+   return super.keyClassSecondaryMatches(requestedReadType)
+       || AvroUtils.castToComparableStringClass(getKeyClassSecondary()) == requestedReadType
+       || AvroUtils.castToComparableStringClass(requestedReadType) == getKeyClassSecondary();
}

////////////////////////////////////////////////////////////////////////////////
@@ -388,7 +388,7 @@ private Coder<K1> keyTypeCoder() {
Optional<Coder<K1>> c =
inputs.stream()
.flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-   .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+   .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
.findFirst()
.map(sm -> (Coder<K1>) sm.metadata.getKeyCoder());
if (!c.isPresent())
@@ -436,7 +436,7 @@ private Coder<K1> keyCoderPrimary() {
Optional<Coder<K1>> c =
inputs.stream()
.flatMap(i -> i.getSourceMetadata().mapping.values().stream())
-   .filter(sm -> sm.metadata.getKeyClass() == keyClassPrimary)
+   .filter(sm -> sm.metadata.keyClassMatches(keyClassPrimary))
.findFirst()
.map(sm -> (Coder<K1>) sm.metadata.getKeyCoder());
if (!c.isPresent())
@@ -454,7 +454,7 @@ private Coder<K2> keyCoderSecondary() {
.filter(
sm ->
sm.metadata.getKeyClassSecondary() != null
-             && sm.metadata.getKeyClassSecondary() == keyClassSecondary
+             && sm.metadata.keyClassSecondaryMatches(keyClassSecondary)
&& sm.metadata.getKeyCoderSecondary() != null)
.findFirst()
.map(sm -> (Coder<K2>) sm.metadata.getKeyCoderSecondary());
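The three filter changes above matter for mixed-key inputs: an exact == on key classes finds no coder when the transform requests String but a source's metadata records CharSequence. The keyClassMatches predicates apply the same String/CharSequence normalization shown earlier (the sketch below is illustrative, reusing the hypothetical normalize helper):

// Why '==' was too strict for mixed sources.
def normalize(cls: Class[_]): Class[_] =
  if (cls == classOf[String]) classOf[CharSequence] else cls

val sourceKeyClass: Class[_] = classOf[CharSequence] // recorded in on-disk metadata
val requestedKeyClass: Class[_] = classOf[String]    // from sortMergeTransform(classOf[String], ...)

assert(sourceKeyClass != requestedKeyClass)                       // old filter: no coder found
assert(normalize(sourceKeyClass) == normalize(requestedKeyClass)) // new filter: coder resolved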
@@ -363,6 +363,51 @@ public void testSameSourceCompatibility() throws Exception {
Assert.assertFalse(metadata4.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata5));
}

+ @Test
+ public void testParquetMetadataCompatibility() throws Exception {
+   final AvroBucketMetadata<String, Integer, GenericRecord> metadata1 =
+       new AvroBucketMetadata<>(
+           2,
+           1,
+           String.class,
+           "name",
+           Integer.class,
+           "favorite_number",
+           HashType.MURMUR3_32,
+           SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+           AvroGeneratedUser.SCHEMA$);
+
+   final ParquetBucketMetadata<String, Integer, GenericRecord> metadata2 =
+       new ParquetBucketMetadata<>(
+           2,
+           1,
+           String.class,
+           "favorite_color",
+           Integer.class,
+           "favorite_number",
+           HashType.MURMUR3_32,
+           SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+           AvroGeneratedUser.SCHEMA$);
+
+   final ParquetBucketMetadata<String, Integer, GenericRecord> metadata3 =
+       new ParquetBucketMetadata<>(
+           4,
+           1,
+           String.class,
+           "favorite_color",
+           Integer.class,
+           "favorite_number",
+           HashType.MURMUR3_32,
+           SortedBucketIO.DEFAULT_FILENAME_PREFIX,
+           AvroGeneratedUser.SCHEMA$);
+
+   Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryKey(metadata2));
+   Assert.assertFalse(metadata1.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata2));
+
+   Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryKey(metadata3));
+   Assert.assertTrue(metadata2.isPartitionCompatibleForPrimaryAndSecondaryKey(metadata3));
+ }
@Test
public void testKeyTypeCheckingBytes()
throws CannotProvideCoderException, NonDeterministicException {
scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala — 143 additions & 23 deletions
@@ -30,8 +30,8 @@ import org.apache.beam.sdk.values.TupleTag
import java.nio.file.Files

class SmbVersionParityTest extends PipelineSpec {
- private def testRoundtrip(
-     write: SortedBucketIO.Write[CharSequence, _, Account],
+ private def testReadRoundtrip(
+     writes: Seq[SortedBucketIO.Write[_ <: CharSequence, _, Account]],
read: SortedBucketIO.Read[Account]
): Unit = {
val accounts = (1 to 10).map { i =>
@@ -47,9 +47,10 @@

{
val sc = ScioContext()
-     sc.parallelize(accounts)
-       .saveAsSortedBucket(write)
+     val records = sc.parallelize(accounts)
+     writes.foreach(records.saveAsSortedBucket(_))
      sc.run()
+     ()
}

// Read data
@@ -61,38 +62,157 @@
.get(sc.run().waitUntilDone())
.flatMap(_._2)
.value
-     .toSeq should contain theSameElementsAs accounts
+     .toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
}

"SortedBucketSource" should "be able to read CharSequence-keyed Avro sources written before 0.14" in {
val output = Files.createTempDirectory("smb-version-test-avro").toFile
output.deleteOnExit()
private def testTransformRoundtrip(
writes: Seq[SortedBucketIO.Write[_ <: CharSequence, Void, Account]],
read: SortedBucketIO.Read[Account],
transform: SortedBucketIO.TransformOutput[String, Void, Account]
): Unit = {
val accounts = (1 to 10).map { i =>
Account
.newBuilder()
.setId(i)
.setName(i.toString)
.setAmount(i.toDouble)
.setType(s"type$i")
.setAccountStatus(AccountStatus.Active)
.build()
}

{
val sc = ScioContext()
val records = sc.parallelize(accounts)
writes.foreach(records.saveAsSortedBucket(_))
sc.run()
()
}

// Read data
val sc = ScioContext()
val tap = sc
.sortMergeTransform(classOf[String], read)
.to(transform)
.via { case (_, records, outputCollector) =>
records.foreach(outputCollector.accept(_))
}

tap
.get(sc.run().waitUntilDone())
.value
.toSeq should contain theSameElementsAs writes.flatMap(_ => accounts)
}

"SortedBucketSource" should "be able to read mixed CharSequence and String-keyed Avro sources" in {
val tmpDir = Files.createTempDirectory("smb-version-test-mixed-avro-read").toFile
tmpDir.deleteOnExit()

val partition1Output = tmpDir.toPath.resolve("partition1")
val partition2Output = tmpDir.toPath.resolve("partition2")

testRoundtrip(
testReadRoundtrip(
Seq(
AvroSortedBucketIO
.write(classOf[CharSequence], "name", classOf[Account])
.to(partition1Output.toString)
.withNumBuckets(1)
.withNumShards(1),
AvroSortedBucketIO
.write(classOf[String], "name", classOf[Account])
.to(partition2Output.toString)
.withNumBuckets(1)
.withNumShards(1)
),
AvroSortedBucketIO
-     .write(classOf[CharSequence], "name", classOf[Account])
-     .to(output.getAbsolutePath)
-     .withNumBuckets(1)
-     .withNumShards(1),
.read(new TupleTag[Account], classOf[Account])
+     .from(partition1Output.toString, partition2Output.toString)
)
}

it should "be able to transform mixed CharSequence- and String-keyed Avro sources written before 0.14" in {
val tmpDir = Files.createTempDirectory("smb-version-test-avro-tfx").toFile
tmpDir.deleteOnExit()

val partition1Output = tmpDir.toPath.resolve("partition1")
val partition2Output = tmpDir.toPath.resolve("partition2")
val tfxOutput = tmpDir.toPath.resolve("tfx")

testTransformRoundtrip(
Seq(
AvroSortedBucketIO
.write(classOf[CharSequence], "name", classOf[Account])
.to(partition1Output.toString)
.withNumBuckets(1)
.withNumShards(1),
AvroSortedBucketIO
.write(classOf[String], "name", classOf[Account])
.to(partition2Output.toString)
.withNumBuckets(1)
.withNumShards(1)
),
AvroSortedBucketIO
.read(new TupleTag[Account], classOf[Account])
-     .from(output.getAbsolutePath)
+     .from(partition1Output.toString, partition2Output.toString),
+   AvroSortedBucketIO
+     .transformOutput(classOf[String], "name", classOf[Account])
+     .to(tfxOutput.toString)
)
}

it should "be able to read CharSequence-keyed Parquet sources written before 0.14" in {
val output = Files.createTempDirectory("smb-version-test-parquet").toFile
output.deleteOnExit()
it should "be able to read mixed CharSequence- and String-keyed-keyed Parquet sources written before 0.14" in {
val tmpDir = Files.createTempDirectory("smb-version-test-mixed-parquet-read").toFile
tmpDir.deleteOnExit()

val partition1Output = tmpDir.toPath.resolve("partition1")
val partition2Output = tmpDir.toPath.resolve("partition2")

-   testRoundtrip(
+   testReadRoundtrip(
+     Seq(
+       ParquetAvroSortedBucketIO
+         .write(classOf[CharSequence], "name", classOf[Account])
+         .to(partition1Output.toString)
+         .withNumBuckets(1)
+         .withNumShards(1),
+       ParquetAvroSortedBucketIO
+         .write(classOf[String], "name", classOf[Account])
+         .to(partition2Output.toString)
+         .withNumBuckets(1)
+         .withNumShards(1)
+     ),
ParquetAvroSortedBucketIO
-     .write(classOf[CharSequence], "name", classOf[Account])
-     .to(output.getAbsolutePath)
-     .withNumBuckets(1)
-     .withNumShards(1),
      .read(new TupleTag[Account], classOf[Account])
+     .from(partition1Output.toString, partition2Output.toString)
)
}

it should "be able to transform mixed CharSequence- and String-keyed Parquet sources written before 0.14" in {
val tmpDir = Files.createTempDirectory("smb-version-test-parquet-tfx").toFile
tmpDir.deleteOnExit()

val partition1Output = tmpDir.toPath.resolve("partition1")
val partition2Output = tmpDir.toPath.resolve("partition2")
val tfxOutput = tmpDir.toPath.resolve("tfx")

testTransformRoundtrip(
Seq(
ParquetAvroSortedBucketIO
.write(classOf[CharSequence], "name", classOf[Account])
.to(partition1Output.toString)
.withNumBuckets(1)
.withNumShards(1),
ParquetAvroSortedBucketIO
.write(classOf[String], "name", classOf[Account])
.to(partition2Output.toString)
.withNumBuckets(1)
.withNumShards(1)
),
ParquetAvroSortedBucketIO
.read(new TupleTag[Account], classOf[Account])
-     .from(output.getAbsolutePath)
+     .from(partition1Output.toString, partition2Output.toString),
+   ParquetAvroSortedBucketIO
+     .transformOutput(classOf[String], "name", classOf[Account])
+     .to(tfxOutput.toString)
)
}
}