Simplify BucketedInput serialization (#5270)
clairemcginty authored Feb 23, 2024
1 parent 5510e07 commit 3323eda
Showing 10 changed files with 46 additions and 249 deletions.
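The change moves BucketedInput's serialization out of custom writeObject/readObject hooks: the constructor now precomputes two small Serializable maps — one holding each distinct (file suffix, FileOperations) pair exactly once, keyed by an index, and one mapping every input directory to that index — while the full per-directory map is kept transient and rebuilt lazily by getInputs(). The sketch below illustrates the idea with hypothetical, simplified types (plain Strings instead of ResourceId and FileOperations); it is not the actual scio-smb code.

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for BucketedInput: every directory maps to a
// (fileSuffix, fileOperations) pair, and many directories share the same
// pair. Only the two index maps below are serialized.
final class EncodedInputs implements Serializable {
  // index -> distinct (suffix, fileOperations) entry, stored exactly once
  private final Map<Integer, String[]> fileOperationsEncoding = new HashMap<>();
  // directory -> index into fileOperationsEncoding
  private final Map<String, Integer> directoriesEncoding = new HashMap<>();
  // full view, rebuilt on demand; never written to the stream
  private transient Map<String, String[]> inputs;

  EncodedInputs(Map<String, String[]> rawInputs) {
    final Map<String, Integer> seen = new HashMap<>(); // dedup key -> index
    int i = 0;
    for (Map.Entry<String, String[]> e : rawInputs.entrySet()) {
      final String key = e.getValue()[0] + "/" + e.getValue()[1];
      Integer idx = seen.get(key);
      if (idx == null) {
        idx = i++;
        seen.put(key, idx);
        fileOperationsEncoding.put(idx, e.getValue());
      }
      directoriesEncoding.put(e.getKey(), idx);
    }
  }

  // Lazily reconstructs the directory -> (suffix, fileOperations) map,
  // mirroring the new getInputs() introduced in this commit.
  Map<String, String[]> getInputs() {
    if (inputs == null) {
      inputs = new HashMap<>();
      directoriesEncoding.forEach(
          (dir, idx) -> inputs.put(dir, fileOperationsEncoding.get(idx)));
    }
    return inputs;
  }
}

Because many directories typically share a single FileOperations configuration, only a handful of entries plus one index per directory are written to the stream, which is what lets default Java serialization replace the coder-based writeObject/readObject shown removed below.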
@@ -22,7 +22,6 @@
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.Map;
import java.util.Objects;
import org.apache.avro.Schema;
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileStream;
@@ -111,25 +110,6 @@ Schema getSchema() {
return schemaSupplier.get();
}

@Override
public int hashCode() {
return Objects.hash(getSchema(), codec.getCodec(), metadata, datumFactory);
}

@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (!(obj instanceof AvroFileOperations)) {
return false;
}
final AvroFileOperations<ValueT> that = (AvroFileOperations<ValueT>) obj;
return that.getSchema().equals(this.getSchema())
&& that.codec.getCodec().toString().equals(this.codec.getCodec().toString())
&& ((that.metadata == null && this.metadata == null)
|| (this.metadata.equals(that.metadata)))
&& that.datumFactory.equals(this.datumFactory);
}

private static class SerializableSchemaString implements Serializable {
private final String schema;

@@ -28,7 +28,6 @@
import java.nio.file.StandardOpenOption;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
@@ -142,18 +141,6 @@ public void populateDisplayData(Builder builder) {
builder.add(DisplayData.item("mimeType", mimeType));
}

@Override
public int hashCode() {
return Objects.hash(getClass().getName(), compression, mimeType);
}

@Override
public boolean equals(Object obj) {
return obj.getClass() == getClass()
&& this.compression == ((FileOperations<?>) obj).compression
&& this.mimeType.equals(((FileOperations<?>) obj).mimeType);
}

/** Per-element file reader. */
public abstract static class Reader<V> implements Serializable {
private transient Supplier<?> cleanupFn = null;
@@ -20,11 +20,7 @@
import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
@@ -129,37 +125,6 @@ Schema getSchema() {
return schemaSupplier.get();
}

@Override
public int hashCode() {
return Objects.hash(
getSchema(),
compression,
conf.get(),
projectionSupplier != null ? projectionSupplier.get() : null,
predicate);
}

@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (!(obj instanceof ParquetAvroFileOperations)) {
return false;
}
final ParquetAvroFileOperations<ValueT> that = (ParquetAvroFileOperations<ValueT>) obj;

return that.getSchema().equals(this.getSchema())
&& that.compression.name().equals(this.compression.name())
&& ((that.projectionSupplier == null && this.projectionSupplier == null)
|| (this.projectionSupplier.get().equals(that.projectionSupplier.get())))
&& ((that.predicate == null && this.predicate == null)
|| (that.predicate.equals(this.predicate)))
&& StreamSupport.stream(that.conf.get().spliterator(), false)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
.equals(
StreamSupport.stream(this.conf.get().spliterator(), false)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
}

////////////////////////////////////////
// Reader
////////////////////////////////////////
@@ -22,8 +22,6 @@
import com.google.common.base.Preconditions;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
@@ -45,17 +43,12 @@
import java.util.stream.IntStream;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.extensions.smb.BucketMetadataUtil.SourceMetadata;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.fs.ResourceIdCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -373,7 +366,7 @@ public PrimaryKeyedBucketedInput(

public SourceMetadata<V> getSourceMetadata() {
if (sourceMetadata == null)
sourceMetadata = BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(inputs);
sourceMetadata = BucketMetadataUtil.get().getPrimaryKeyedSourceMetadata(getInputs());
return sourceMetadata;
}
}
@@ -403,7 +396,8 @@ public PrimaryAndSecondaryKeyedBucktedInput(

public SourceMetadata<V> getSourceMetadata() {
if (sourceMetadata == null)
sourceMetadata = BucketMetadataUtil.get().getPrimaryAndSecondaryKeyedSourceMetadata(inputs);
sourceMetadata =
BucketMetadataUtil.get().getPrimaryAndSecondaryKeyedSourceMetadata(getInputs());
return sourceMetadata;
}
}
@@ -416,22 +410,15 @@ public SourceMetadata<V> getSourceMetadata() {
*/
public abstract static class BucketedInput<V> implements Serializable {
private static final Pattern BUCKET_PATTERN = Pattern.compile("(\\d+)-of-(\\d+)");

protected TupleTag<V> tupleTag;
protected Map<ResourceId, KV<String, FileOperations<V>>> inputs;
protected Predicate<V> predicate;
protected Keying keying;
// lazy, internal checks depend on what kind of iteration is requested
protected transient SourceMetadata<V> sourceMetadata = null; // lazy

// Used to efficiently serialize BucketedInput instances
private static Coder<Map<ResourceId, Integer>> directoriesEncodingCoder =
MapCoder.of(ResourceIdCoder.of(), VarIntCoder.of());

private static Coder<Map<Integer, KV<String, FileOperations>>> fileOperationsEncodingCoder =
MapCoder.of(
VarIntCoder.of(),
KvCoder.of(StringUtf8Coder.of(), SerializableCoder.of(FileOperations.class)));
private transient Map<ResourceId, KV<String, FileOperations<V>>> inputs;
private final Map<Integer, KV<String, FileOperations<V>>> fileOperationsEncoding;
private final Map<ResourceId, Integer> directoriesEncoding;

public static <V> BucketedInput<V> of(
Keying keying,
@@ -486,6 +473,26 @@ public BucketedInput(
.collect(
Collectors.toMap(
e -> FileSystems.matchNewResource(e.getKey(), true), Map.Entry::getValue));

// Map distinct FileOperations/FileSuffixes to indices in a map, for efficient encoding of
// large BucketedInputs
final Map<KV<String, String>, Integer> fileOperationsMetadata = new HashMap<>();
fileOperationsEncoding = new HashMap<>();
directoriesEncoding = new HashMap<>();

int i = 0;
for (Map.Entry<ResourceId, KV<String, FileOperations<V>>> entry : inputs.entrySet()) {
final KV<String, FileOperations<V>> fileOps = entry.getValue();
final KV<String, String> metadataKey =
KV.of(fileOps.getKey(), fileOps.getValue().getClass().getName());
if (!fileOperationsMetadata.containsKey(metadataKey)) {
fileOperationsMetadata.put(metadataKey, i);
fileOperationsEncoding.put(i, KV.of(fileOps.getKey(), fileOps.getValue()));
i++;
}
directoriesEncoding.put(entry.getKey(), fileOperationsMetadata.get(metadataKey));
}

this.predicate = predicate;
}

@@ -499,13 +506,28 @@ public Predicate<V> getPredicate() {
return predicate;
}

@SuppressWarnings("unchecked")
public Map<ResourceId, KV<String, FileOperations<V>>> getInputs() {
if (inputs == null) {
this.inputs =
directoriesEncoding.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
dirAndIndex -> {
final String dir =
fileOperationsEncoding.get(dirAndIndex.getValue()).getKey();
final FileOperations<V> fileOps =
fileOperationsEncoding.get(dirAndIndex.getValue()).getValue();
return KV.of(dir, fileOps);
}));
}
return inputs;
}

public Coder<V> getCoder() {
final KV<String, FileOperations<V>> sampledSource =
inputs.entrySet().iterator().next().getValue();
getInputs().entrySet().iterator().next().getValue();
return sampledSource.getValue().getCoder();
}

@@ -527,7 +549,7 @@ private static List<Metadata> sampleDirectory(ResourceId directory, String filep
}

long getOrSampleByteSize() {
return inputs
return getInputs()
.entrySet()
.parallelStream()
.mapToLong(
@@ -604,7 +626,8 @@ public KeyGroupIterator<V> createIterator(
try {
Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>> iterator =
Iterators.transform(
inputs.get(dir).getValue().iterator(file), v -> KV.of(keyFn.apply(v), v));
getInputs().get(dir).getValue().iterator(file),
v -> KV.of(keyFn.apply(v), v));
Iterator<KV<SortedBucketIO.ComparableKeyBytes, V>> out =
(bufferSize > 0) ? new BufferedIterator<>(iterator, bufferSize) : iterator;
iterators.add(out);
@@ -619,7 +642,7 @@ public KeyGroupIterator<V> createIterator(

@Override
public String toString() {
List<ResourceId> inputDirectories = new ArrayList<>(inputs.keySet());
List<ResourceId> inputDirectories = new ArrayList<>(getInputs().keySet());
return String.format(
"BucketedInput[tupleTag=%s, inputDirectories=[%s]]",
tupleTag.getId(),
@@ -629,62 +652,6 @@ public String toString() {
+ inputDirectories.get(inputDirectories.size() - 1)
: inputDirectories);
}

// Not all instance members can be natively serialized, so override writeObject/readObject
// using Coders for each type
@SuppressWarnings("unchecked")
private void writeObject(ObjectOutputStream outStream) throws IOException {
SerializableCoder.of(TupleTag.class).encode(tupleTag, outStream);
outStream.writeObject(predicate);
outStream.writeObject(keying);

// Map distinct FileOperations/FileSuffixes to indices in a map, for efficient encoding of
// large BucketedInputs
final Map<KV<String, String>, Integer> fileOperationsMetadata = new HashMap<>();
final Map<Integer, KV<String, FileOperations>> fileOperationsEncoding = new HashMap<>();
final Map<ResourceId, Integer> directoriesEncoding = new HashMap<>();
int i = 0;

for (Map.Entry<ResourceId, KV<String, FileOperations<V>>> entry : inputs.entrySet()) {
final KV<String, FileOperations<V>> fileOps = entry.getValue();
final KV<String, String> metadataKey =
KV.of(fileOps.getKey(), fileOps.getValue().getClass().getName());
if (!fileOperationsMetadata.containsKey(metadataKey)) {
fileOperationsMetadata.put(metadataKey, i);
fileOperationsEncoding.put(i, KV.of(fileOps.getKey(), fileOps.getValue()));
i++;
}
directoriesEncoding.put(entry.getKey(), fileOperationsMetadata.get(metadataKey));
}

fileOperationsEncodingCoder.encode(fileOperationsEncoding, outStream);
directoriesEncodingCoder.encode(directoriesEncoding, outStream);
}

@SuppressWarnings("unchecked")
private void readObject(ObjectInputStream inStream) throws ClassNotFoundException, IOException {
this.tupleTag = SerializableCoder.of(TupleTag.class).decode(inStream);
this.predicate = (Predicate<V>) inStream.readObject();
this.keying = (Keying) inStream.readObject();

final Map<Integer, KV<String, FileOperations>> fileOperationsEncoding =
fileOperationsEncodingCoder.decode(inStream);
final Map<ResourceId, Integer> directoriesEncoding =
directoriesEncodingCoder.decode(inStream);

this.inputs =
directoriesEncoding.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
dirAndIndex -> {
final String dir =
fileOperationsEncoding.get(dirAndIndex.getValue()).getKey();
final FileOperations<V> fileOps =
fileOperationsEncoding.get(dirAndIndex.getValue()).getValue();
return KV.of(dir, fileOps);
}));
}
}
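With the two encoding maps stored as ordinary final fields, the custom writeObject/readObject removed above is no longer needed: default Java serialization writes only the compact maps, and the transient inputs map is rebuilt on first access after deserialization. A usage sketch, continuing the hypothetical EncodedInputs example from the top of this diff:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

final class EncodedInputsRoundTrip {
  // Serialize and deserialize with the default mechanism; only the two
  // index maps travel through the stream.
  static EncodedInputs roundTrip(EncodedInputs original)
      throws IOException, ClassNotFoundException {
    final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
      out.writeObject(original);
    }
    try (ObjectInputStream in =
        new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
      // getInputs() on the returned copy lazily rebuilds the full directory map
      return (EncodedInputs) in.readObject();
    }
  }
}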

/**
@@ -30,10 +30,7 @@ import org.apache.parquet.filter2.predicate.FilterPredicate
import org.apache.parquet.hadoop.{ParquetReader, ParquetWriter}
import org.apache.parquet.hadoop.metadata.CompressionCodecName

import scala.jdk.CollectionConverters._

import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Objects

object ParquetTypeFileOperations {

@@ -94,23 +91,6 @@ case class ParquetTypeFileOperations[T](
override protected def createSink(): FileIO.Sink[T] = ParquetTypeSink(compression, conf)

override def getCoder: BCoder[T] = CoderMaterializer.beamWithDefault(coder)

override def hashCode(): Int = Objects.hash(compression.name(), conf.get(), predicate)

override def equals(obj: Any): Boolean = obj match {
case ParquetTypeFileOperations(compressionThat, confThat, predicateThat) =>
this.compression.name() == compressionThat.name() && this.predicate == predicateThat &&
conf
.get()
.iterator()
.asScala
.map(e => (e.getKey, e.getValue))
.toMap
.equals(
confThat.get().iterator().asScala.map(e => (e.getKey, e.getValue)).toMap
)
case _ => false
}
}

private case class ParquetTypeReader[T](