Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Tap for SMB writes (addresses #5080) #5144

Merged
merged 9 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/io/Tap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ trait Tap[T] extends Serializable { self =>
/** Open data set as an [[com.spotify.scio.values.SCollection SCollection]]. */
override def open(sc: ScioContext): SCollection[U] = self.open(sc).map(f)
}

def flatMap[U: Coder](f: T => TraversableOnce[U]): Tap[U] = new Tap[U] {

/** Parent of this Tap before [[flatMap]]. */
override val parent: Option[Tap[_]] = Option(self)

/** Read data set into memory. */
override def value: Iterator[U] = self.value.flatMap(f)

/** Open data set as an [[com.spotify.scio.values.SCollection SCollection]]. */
override def open(sc: ScioContext): SCollection[U] = self.open(sc).flatMap(f)
}
}

case object EmptyTap extends Tap[Nothing] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ public Write<K1, K2, T> withTempDirectory(String tempDirectory) {

@SuppressWarnings("unchecked")
@Override
FileOperations<T> getFileOperations() {
public FileOperations<T> getFileOperations() {
return getRecordClass() == null
? (AvroFileOperations<T>) AvroFileOperations.of(getSchema(), getCodec())
: (AvroFileOperations<T>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ public Write<K1, K2> withCompression(Compression compression) {
}

@Override
FileOperations<TableRow> getFileOperations() {
public FileOperations<TableRow> getFileOperations() {
return JsonFileOperations.of(getCompression());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ public Write<K1, K2, T> withTempDirectory(String tempDirectory) {
}

@Override
FileOperations<T> getFileOperations() {
public FileOperations<T> getFileOperations() {
return getRecordClass() == null
? ParquetAvroFileOperations.of(getSchema(), getCompression(), getConfiguration())
: ParquetAvroFileOperations.of(getRecordClass(), getCompression(), getConfiguration());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ ResourceId getTempDirectoryOrDefault(Pipeline pipeline) {
return FileSystems.matchNewResource(tempLocationOpt, true);
}

public FileOperations<V> getFileOperations() {
return fileOperations;
}

@Override
public WriteResult expand(PBegin input) {
Preconditions.checkNotNull(outputDirectory, "outputDirectory is not set");
Expand Down Expand Up @@ -600,7 +604,7 @@ public abstract static class Write<K1, K2, V> extends PTransform<PCollection<V>,

abstract int getSorterMemoryMb();

abstract FileOperations<V> getFileOperations();
public abstract FileOperations<V> getFileOperations();
Copy link
Contributor Author

@clairemcginty clairemcginty Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could get around the signature change by adding a new method to SortedBucketIO.Write like:

public Iterator<T> tapBucketFile(ResourceId resourceId) {
   return getFileOperations().open(resourceId);
}

but this seems overtly hacky to me - plus, we re-create a FileOperations instance for every bucket/shard combo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh it might be simplest to just make all these getters public 🤷‍♀️


abstract BucketMetadata<K1, K2, V> getBucketMetadata();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ public Write<K1, K2> withCompression(Compression compression) {
}

@Override
FileOperations<Example> getFileOperations() {
public FileOperations<Example> getFileOperations() {
return TensorFlowFileOperations.of(getCompression());
}

Expand Down
26 changes: 25 additions & 1 deletion scio-smb/src/main/scala/com/spotify/scio/smb/SmbIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@

package com.spotify.scio.smb

import com.spotify.scio.ScioContext
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.{KeyedIO, TapOf, TapT, TestIO}
import com.spotify.scio.io.{ClosedTap, KeyedIO, Tap, TapOf, TapT, TestIO}
import com.spotify.scio.util.ScioUtil
import org.apache.beam.sdk.extensions.smb.{BucketShardId, FileOperations}
import org.apache.beam.sdk.extensions.smb.SortedBucketSink.WriteResult
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.values.{KV, PCollection, TupleTag}

import scala.jdk.CollectionConverters._

final class SmbIO[K, T](path: String, override val keyBy: T => K)(implicit
override val keyCoder: Coder[K]
Expand All @@ -37,4 +44,21 @@ object SmbIO {
val normalizedPaths = paths.map(p => ScioUtil.strippedPath(p) + "/").mkString(",")
s"SortedBucketIO($normalizedPaths)"
}

private[scio] def tap[T: Coder](
fileOperations: FileOperations[T],
writeResult: WriteResult
): ScioContext => Tap[T] =
(sc: ScioContext) => {
val bucketFiles = sc
.wrap(
writeResult
.expand()
.get(new TupleTag("WrittenFiles"))
.asInstanceOf[PCollection[KV[BucketShardId, ResourceId]]]
)
.materialize

bucketFiles.underlying.flatMap(kv => fileOperations.iterator(kv.getValue).asScala)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.spotify.scio.smb

import com.spotify.scio.ScioContext
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.{ClosedTap, EmptyTap}
import com.spotify.scio.io.{ClosedTap, EmptyTap, TapOf}
import com.spotify.scio.testing.TestDataManager
import com.spotify.scio.values.{SCollection, SideInput, SideInputContext}
import org.apache.beam.sdk.extensions.smb.{SortedBucketIOUtil, SortedBucketTransform}
Expand Down Expand Up @@ -82,10 +82,10 @@ object SortMergeTransform {
*/
def via(
transformFn: (KeyType, R, SortedBucketTransform.SerializableConsumer[W]) => Unit
): ClosedTap[Nothing]
): ClosedTap[W]
}

private[smb] class WriteBuilderImpl[KeyType, R, W](
private[smb] class WriteBuilderImpl[KeyType, R, W: Coder](
@transient private val sc: ScioContext,
transform: AbsCoGbkTransform[KeyType, W],
fromResult: CoGbkResult => R
Expand All @@ -97,7 +97,7 @@ object SortMergeTransform {

override def via(
transformFn: (KeyType, R, SortedBucketTransform.SerializableConsumer[W]) => Unit
): ClosedTap[Nothing] = {
): ClosedTap[W] = {
val fn = new SortedBucketTransform.TransformFn[KeyType, W]() {
override def writeTransform(
keyGroup: KV[KeyType, CoGbkResult],
Expand All @@ -112,8 +112,9 @@ object SortMergeTransform {

val t = transform.via(fn)
val tfName = sc.tfName(Some("sortMergeTransform"))
sc.applyInternal(tfName, t)
ClosedTap[Nothing](EmptyTap)
val writeResult = sc.applyInternal(tfName, t)

ClosedTap(SmbIO.tap(t.getFileOperations, writeResult).apply(sc))
}
}

Expand All @@ -130,11 +131,10 @@ object SortMergeTransform {

override def via(
transformFn: (KeyType, R, SortedBucketTransform.SerializableConsumer[W]) => Unit
): ClosedTap[Nothing] = {
): ClosedTap[W] = {
val data = read.parDo(new ViaTransform(transformFn))
val testOutput = TestDataManager.getOutput(sc.testId.get)
testOutput(SortedBucketIOUtil.testId(output))(data)
ClosedTap[Nothing](EmptyTap)
TestDataManager.getOutput(sc.testId.get)(SortedBucketIOUtil.testId(output))
ClosedTap(TapOf[W].saveForTest(data))
}
}

Expand All @@ -159,10 +159,10 @@ object SortMergeTransform {
SideInputContext[_],
SortedBucketTransform.SerializableConsumer[W]
) => Unit
): ClosedTap[Nothing]
): ClosedTap[W]
}

private[smb] class WithSideInputsWriteBuilderImpl[KeyType, R, W](
private[smb] class WithSideInputsWriteBuilderImpl[KeyType, R, W: Coder](
@transient private val sc: ScioContext,
transform: AbsCoGbkTransform[KeyType, W],
toR: CoGbkResult => R,
Expand All @@ -175,7 +175,7 @@ object SortMergeTransform {
SideInputContext[_],
SortedBucketTransform.SerializableConsumer[W]
) => Unit
): ClosedTap[Nothing] = {
): ClosedTap[W] = {
val sideViews: java.lang.Iterable[PCollectionView[_]] = sides.map(_.view).asJava

val fn = new SortedBucketTransform.TransformFnWithSideInputContext[KeyType, W]() {
Expand All @@ -197,7 +197,9 @@ object SortMergeTransform {
}
val t = transform.via(fn, sideViews)
sc.applyInternal(t)
ClosedTap[Nothing](EmptyTap)

val writeResult = sc.applyInternal(t)
Comment on lines 199 to +201
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduction of duplicated transform

ClosedTap(SmbIO.tap(t.getFileOperations, writeResult).apply(sc))
}
}

Expand All @@ -213,11 +215,10 @@ object SortMergeTransform {
SideInputContext[_],
SortedBucketTransform.SerializableConsumer[W]
) => Unit
): ClosedTap[Nothing] = {
): ClosedTap[W] = {
val data = read.parDo(new ViaTransformWithSideOutput(transformFn))
val testOutput = TestDataManager.getOutput(sc.testId.get)
testOutput(SortedBucketIOUtil.testId(output))(data)
ClosedTap[Nothing](EmptyTap)
TestDataManager.getOutput(sc.testId.get)(SortedBucketIOUtil.testId(output))
ClosedTap(TapOf[W].saveForTest(data))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ package com.spotify.scio.smb.syntax

import com.spotify.scio.annotations.experimental
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.{ClosedTap, EmptyTap}
import com.spotify.scio.io.{ClosedTap, EmptyTap, TapOf}
import com.spotify.scio.smb.SmbIO
import com.spotify.scio.testing.TestDataManager
import com.spotify.scio.values._
import org.apache.beam.sdk.coders.KvCoder
import org.apache.beam.sdk.extensions.smb.{SortedBucketIO, SortedBucketIOUtil}
import org.apache.beam.sdk.values.KV
import org.apache.beam.sdk.extensions.smb.{BucketShardId, SortedBucketIO, SortedBucketIOUtil}
import org.apache.beam.sdk.io.fs.ResourceId
import org.apache.beam.sdk.values.{KV, PCollection, TupleTag}

import scala.jdk.CollectionConverters._

trait SortMergeBucketSCollectionSyntax {
implicit def toSortMergeBucketKeyedSCollection[K, V](
Expand All @@ -49,16 +53,20 @@ final class SortedBucketSCollection[T](private val self: SCollection[T]) {
* contains information about key function, bucket and shard size, etc.
*/
@experimental
def saveAsSortedBucket(write: SortedBucketIO.Write[_, _, T]): ClosedTap[Nothing] = {
def saveAsSortedBucket(
write: SortedBucketIO.Write[_, _, T]
): ClosedTap[T] = {
val beamValueCoder = self.internal.getCoder
implicit val valueCoder: Coder[T] = Coder.beam(beamValueCoder)
clairemcginty marked this conversation as resolved.
Show resolved Hide resolved

if (self.context.isTest) {
val testOutput = TestDataManager.getOutput(self.context.testId.get)
testOutput[T](SortedBucketIOUtil.testId(write))(self)
TestDataManager.getOutput(self.context.testId.get)(SortedBucketIOUtil.testId(write))
ClosedTap(TapOf[T].saveForTest(self))
} else {
self.applyInternal(write)
}
val writeResult = self.applyInternal(write)

// @Todo: Implement taps for metadata/bucket elements
ClosedTap[Nothing](EmptyTap)
ClosedTap(SmbIO.tap(write.getFileOperations, writeResult).apply(self.context))
RustedBones marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

Expand All @@ -82,17 +90,18 @@ final class SortedBucketPairSCollection[K, V](private val self: SCollection[KV[K
def saveAsPreKeyedSortedBucket(
write: SortedBucketIO.Write[K, Void, V],
verifyKeyExtraction: Boolean = true
): ClosedTap[Nothing] = {
): ClosedTap[V] = {
val beamValueCoder = self.internal.getCoder.asInstanceOf[KvCoder[K, V]].getValueCoder
implicit val valueCoder: Coder[V] = Coder.beam(beamValueCoder)

if (self.context.isTest) {
implicit val valueCoder: Coder[V] = Coder.beam(beamValueCoder)
val testOutput = TestDataManager.getOutput(self.context.testId.get)
testOutput(SortedBucketIOUtil.testId(write))(self.map(_.getValue))
TestDataManager.getOutput(self.context.testId.get)(SortedBucketIOUtil.testId(write))
ClosedTap(TapOf[V].saveForTest(self.map(_.getValue)))
} else {
self.applyInternal(write.onKeyedCollection(beamValueCoder, verifyKeyExtraction))
}
val writeResult =
self.applyInternal(write.onKeyedCollection(beamValueCoder, verifyKeyExtraction))

// @Todo: Implement taps for metadata/bucket elements
ClosedTap[Nothing](EmptyTap)
ClosedTap(SmbIO.tap(write.getFileOperations, writeResult).apply(self.context))
RustedBones marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Loading
Loading