Implement hashCode() and equals() for FileOperations subclasses #5265

Merged 1 commit on Feb 21, 2024
AvroFileOperations.java
@@ -22,6 +22,7 @@
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.Map;
import java.util.Objects;
import org.apache.avro.Schema;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileStream;
@@ -110,6 +111,25 @@ Schema getSchema() {
return schemaSupplier.get();
}

@Override
public int hashCode() {
return Objects.hash(getSchema(), codec.getCodec().toString(), metadata, datumFactory);
}

@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (!(obj instanceof AvroFileOperations)) {
return false;
}
final AvroFileOperations<ValueT> that = (AvroFileOperations<ValueT>) obj;
return that.getSchema().equals(this.getSchema())
&& that.codec.getCodec().toString().equals(this.codec.getCodec().toString())

[Review comment, Contributor Author] Context: CodecFactory has a specialized override of toString that prints the codec type.

&& Objects.equals(this.metadata, that.metadata)
&& that.datumFactory.equals(this.datumFactory);
}
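
For context on the toString() comparison above: Avro's CodecFactory does not override equals(), so two factories for the same codec are only reference-equal, while its toString() prints the codec type (per the review comment). A minimal sketch of the behavior this relies on; the class name and printed values are illustrative assumptions:

import org.apache.avro.file.CodecFactory;

public class CodecToStringSketch {
  public static void main(String[] args) {
    CodecFactory a = CodecFactory.snappyCodec();
    CodecFactory b = CodecFactory.snappyCodec();
    // CodecFactory inherits Object.equals, so distinct factory instances never compare equal.
    System.out.println(a.equals(b)); // false (assuming snappyCodec() returns a fresh instance)
    // toString() is overridden to print the codec type, making it usable as a value key.
    System.out.println(a.toString().equals(b.toString())); // true: both print the codec type
  }
}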

private static class SerializableSchemaString implements Serializable {
private final String schema;

FileOperations.java
@@ -28,6 +28,7 @@
import java.nio.file.StandardOpenOption;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
@@ -141,6 +142,18 @@ public void populateDisplayData(Builder builder) {
builder.add(DisplayData.item("mimeType", mimeType));
}

@Override
public int hashCode() {
return Objects.hash(getClass().getName(), compression, mimeType);
}

@Override
public boolean equals(Object obj) {
return obj != null
&& obj.getClass() == getClass()
&& this.compression == ((FileOperations<?>) obj).compression
&& this.mimeType.equals(((FileOperations<?>) obj).mimeType);
}
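
Note the getClass() comparison (rather than instanceof) in this base implementation: two distinct FileOperations subclasses with identical compression and mimeType must still compare unequal, and hashCode() mixes in the class name for the same reason. A small sketch of the distinction, using hypothetical classes:

import java.util.Objects;

public class ClassBasedEqualitySketch {
  static class Base {
    final int x = 1;
    @Override
    public int hashCode() {
      return Objects.hash(getClass().getName(), x);
    }
    @Override
    public boolean equals(Object obj) {
      // getClass() check: instances of different concrete classes are never equal.
      return obj != null && obj.getClass() == getClass() && ((Base) obj).x == x;
    }
  }
  static class Sub extends Base {}

  public static void main(String[] args) {
    System.out.println(new Base().equals(new Sub())); // false: different concrete classes
    System.out.println(new Sub().equals(new Sub()));  // true: same class, same state
  }
}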

/** Per-element file reader. */
public abstract static class Reader<V> implements Serializable {
private transient Supplier<?> cleanupFn = null;
ParquetAvroFileOperations.java
@@ -20,7 +20,11 @@
import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
@@ -125,6 +129,37 @@ Schema getSchema() {
return schemaSupplier.get();
}

@Override
public int hashCode() {
return Objects.hash(
getSchema(),
compression,
StreamSupport.stream(conf.get().spliterator(), false)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)),
projectionSupplier != null ? projectionSupplier.get() : null,
predicate);
}

@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (!(obj instanceof ParquetAvroFileOperations)) {
return false;
}
final ParquetAvroFileOperations<ValueT> that = (ParquetAvroFileOperations<ValueT>) obj;

return that.getSchema().equals(this.getSchema())
&& that.compression.name().equals(this.compression.name())
&& Objects.equals(
this.projectionSupplier != null ? this.projectionSupplier.get() : null,
that.projectionSupplier != null ? that.projectionSupplier.get() : null)
&& Objects.equals(this.predicate, that.predicate)
&& StreamSupport.stream(that.conf.get().spliterator(), false)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
.equals(
StreamSupport.stream(this.conf.get().spliterator(), false)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
}
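
Hadoop's Configuration is Iterable over its Map.Entry pairs but does not implement value-based equals(), which is why both sides are materialized into plain maps before comparing. A standalone sketch of the same idea (assumes hadoop-common on the classpath; the class and method names are illustrative):

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.hadoop.conf.Configuration;

public class ConfigurationEqualitySketch {
  // Collect a Configuration's entries into a Map that supports value equality.
  static Map<String, String> toMap(Configuration conf) {
    return StreamSupport.stream(conf.spliterator(), false)
        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  }

  public static void main(String[] args) {
    Configuration c1 = new Configuration(false); // false: skip default resources
    c1.set("foo", "bar");
    Configuration c2 = new Configuration(false);
    c2.set("foo", "bar");
    System.out.println(c1.equals(c2));               // false: identity-based equals
    System.out.println(toMap(c1).equals(toMap(c2))); // true: same key/value entries
  }
}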

////////////////////////////////////////
// Reader
////////////////////////////////////////
ParquetTypeFileOperations.scala
@@ -30,7 +30,10 @@ import org.apache.parquet.filter2.predicate.FilterPredicate
import org.apache.parquet.hadoop.{ParquetReader, ParquetWriter}
import org.apache.parquet.hadoop.metadata.CompressionCodecName

import scala.jdk.CollectionConverters._

import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Objects

object ParquetTypeFileOperations {

@@ -91,6 +94,23 @@ case class ParquetTypeFileOperations[T](
override protected def createSink(): FileIO.Sink[T] = ParquetTypeSink(compression, conf)

override def getCoder: BCoder[T] = CoderMaterializer.beamWithDefault(coder)

override def hashCode(): Int =
Objects.hash(
compression.name(),
conf.get().iterator().asScala.map(e => (e.getKey, e.getValue)).toMap,
predicate
)

override def equals(obj: Any): Boolean = obj match {
case ParquetTypeFileOperations(compressionThat, confThat, predicateThat) =>
this.compression.name() == compressionThat.name() &&
this.predicate == predicateThat &&
conf.get().iterator().asScala.map(e => (e.getKey, e.getValue)).toMap ==
confThat.get().iterator().asScala.map(e => (e.getKey, e.getValue)).toMap
case _ => false
}
}

private case class ParquetTypeReader[T](
AvroFileOperationsTest.java
@@ -45,6 +45,7 @@
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
@@ -72,6 +73,9 @@ public void testGenericRecord() throws Exception {
AvroFileOperations.of(GenericRecordDatumFactory$.INSTANCE, USER_SCHEMA)
.withCodec(CodecFactory.snappyCodec())
.withMetadata(TEST_METADATA);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output).resolve("file.avro", StandardResolveOptions.RESOLVE_FILE);

@@ -85,6 +89,7 @@
.build())
.collect(Collectors.toList());
final FileOperations.Writer<GenericRecord> writer = fileOperations.createWriter(file);

for (GenericRecord record : records) {
writer.write(record);
}
@@ -105,6 +110,9 @@
AvroFileOperations.of(new SpecificRecordDatumFactory<>(AvroGeneratedUser.class), schema)
.withCodec(CodecFactory.snappyCodec())
.withMetadata(TEST_METADATA);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output).resolve("file.avro", StandardResolveOptions.RESOLVE_FILE);

JsonFileOperationsTest.java
@@ -30,6 +30,7 @@
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.util.SerializableUtils;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Rule;
@@ -52,6 +53,8 @@ public void testCompression() throws Exception {

private void test(Compression compression) throws Exception {
final JsonFileOperations fileOperations = JsonFileOperations.of(compression);
Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output).resolve("file.json", ResolveOptions.StandardResolveOptions.RESOLVE_FILE);

ParquetAvroFileOperationsTest.java
@@ -37,6 +37,7 @@
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
@@ -93,6 +94,8 @@ public void testGenericRecord() throws Exception {
final ParquetAvroFileOperations<GenericRecord> fileOperations =
ParquetAvroFileOperations.of(USER_SCHEMA);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final List<GenericRecord> actual = new ArrayList<>();
fileOperations.iterator(file).forEachRemaining(actual::add);

@@ -103,6 +106,9 @@
public void testSpecificRecord() throws Exception {
final ParquetAvroFileOperations<AvroGeneratedUser> fileOperations =
ParquetAvroFileOperations.of(AvroGeneratedUser.class);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output)
.resolve("file.parquet", ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
@@ -135,6 +141,7 @@ public void testLogicalTypes() throws Exception {

final ParquetAvroFileOperations<TestLogicalTypes> fileOperations =
ParquetAvroFileOperations.of(TestLogicalTypes.class).withConfiguration(conf);

final ResourceId file =
fromFolder(output)
.resolve("file.parquet", ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
@@ -179,6 +186,8 @@ public void testGenericProjection() throws Exception {
.withCompression(CompressionCodecName.ZSTD)
.withProjection(projection);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final List<GenericRecord> expected =
USER_RECORDS.stream()
.map(r -> new GenericRecordBuilder(USER_SCHEMA).set("name", r.get("name")).build())
@@ -201,6 +210,8 @@ public void testSpecificRecordWithProjection() throws Exception {
final ParquetAvroFileOperations<AvroGeneratedUser> fileOperations =
ParquetAvroFileOperations.of(AvroGeneratedUser.class).withProjection(projection);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output)
.resolve("file.parquet", ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
@@ -290,6 +301,8 @@ public void testPredicate() throws Exception {
final ParquetAvroFileOperations<GenericRecord> fileOperations =
ParquetAvroFileOperations.of(USER_SCHEMA).withFilterPredicate(predicate);

Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final List<GenericRecord> expected =
USER_RECORDS.stream().filter(r -> (int) r.get("age") <= 5).collect(Collectors.toList());
final List<GenericRecord> actual = new ArrayList<>();
@@ -314,6 +327,35 @@ public void testDisplayData() {
MatcherAssert.assertThat(displayData, hasDisplayItem("schema", USER_SCHEMA.getFullName()));
}

@Test
public void testConfigurationEquality() {
final Configuration configuration1 = new Configuration();
configuration1.set("foo", "bar");

final ParquetAvroFileOperations<GenericRecord> fileOperations1 =
ParquetAvroFileOperations.of(USER_SCHEMA).withConfiguration(configuration1);

// Copy of configuration with same keys
final Configuration configuration2 = new Configuration();
configuration2.set("foo", "bar");

final ParquetAvroFileOperations<GenericRecord> fileOperations2 =
ParquetAvroFileOperations.of(USER_SCHEMA).withConfiguration(configuration2);

// Assert that operations built from equal Configurations compare equal
Assert.assertEquals(fileOperations1, SerializableUtils.ensureSerializable(fileOperations2));

// Configuration with different keys
final Configuration configuration3 = new Configuration();
configuration3.set("bar", "baz");

final ParquetAvroFileOperations<GenericRecord> fileOperations3 =
ParquetAvroFileOperations.of(USER_SCHEMA).withConfiguration(configuration3);

// Assert that operations built from different Configurations compare unequal
Assert.assertNotEquals(fileOperations1, SerializableUtils.ensureSerializable(fileOperations3));
}

private void writeFile(ResourceId file) throws IOException {
final ParquetAvroFileOperations<GenericRecord> fileOperations =
ParquetAvroFileOperations.of(USER_SCHEMA).withCompression(CompressionCodecName.ZSTD);
TensorFlowFileOperationsTest.java
@@ -30,6 +30,7 @@
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.util.SerializableUtils;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Rule;
@@ -58,6 +59,8 @@ public void testCompressed() throws Exception {

private void test(Compression compression) throws Exception {
final TensorFlowFileOperations fileOperations = TensorFlowFileOperations.of(compression);
Assert.assertEquals(fileOperations, SerializableUtils.ensureSerializable(fileOperations));

final ResourceId file =
fromFolder(output)
.resolve("file.tfrecord", ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
ParquetTypeFileOperationsTest.scala
@@ -23,6 +23,8 @@ import com.spotify.scio.CoreSysProps
import org.apache.beam.sdk.io.LocalResources
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.util.SerializableUtils
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.scalatest.flatspec.AnyFlatSpec
@@ -50,6 +52,8 @@ class ParquetTypeFileOperationsTest extends AnyFlatSpec with Matchers {
writeFile(file)

val fileOps = ParquetTypeFileOperations[User]()
SerializableUtils.ensureSerializable(fileOps) should equal(fileOps)

val actual = fileOps.iterator(file).asScala.toSeq

actual shouldBe users
@@ -64,6 +68,7 @@
writeFile(file)

val fileOps = ParquetTypeFileOperations[Username]()
SerializableUtils.ensureSerializable(fileOps) should equal(fileOps)
val actual = fileOps.iterator(file).asScala.toSeq

actual shouldBe users.map(u => Username(u.name))
@@ -79,12 +84,33 @@

val predicate = FilterApi.ltEq(FilterApi.intColumn("age"), java.lang.Integer.valueOf(5))
val fileOps = ParquetTypeFileOperations[User](predicate)
SerializableUtils.ensureSerializable(fileOps) should equal(fileOps)
val actual = fileOps.iterator(file).asScala.toSeq

actual shouldBe users.filter(_.age <= 5)
tmpDir.delete()
}

it should "compare Configuration values in equals() check" in {
val conf1 = new Configuration()
conf1.set("foo", "bar")
val fileOps1 = ParquetTypeFileOperations[User](CompressionCodecName.UNCOMPRESSED, conf1)

val conf2 = new Configuration()
conf2.set("foo", "bar")
val fileOps2 = ParquetTypeFileOperations[User](CompressionCodecName.UNCOMPRESSED, conf2)

// FileOperations with equal Configuration maps should be equal
SerializableUtils.ensureSerializable(fileOps2) should equal(fileOps1)

val conf3 = new Configuration()
conf3.set("bar", "baz")
val fileOps3 = ParquetTypeFileOperations[User](CompressionCodecName.UNCOMPRESSED, conf3)

// FileOperations with different Configuration maps should not be equal
SerializableUtils.ensureSerializable(fileOps3) shouldNot equal(fileOps1)
}

private def writeFile(file: ResourceId): Unit = {
val fileOps = ParquetTypeFileOperations[User](CompressionCodecName.GZIP)
val writer = fileOps.createWriter(file);