Skip to content

Commit

Permalink
Use Taps in SortMergeBucketExample
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty committed Jan 8, 2024
1 parent 95e0005 commit eaf197c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ package com.spotify.scio.examples.extra
import com.spotify.scio.{Args, ContextAndArgs, ScioContext}
import com.spotify.scio.avro._
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.ClosedTap
import com.spotify.scio.values.SCollection
import org.apache.avro.Schema
import org.apache.avro.file.CodecFactory
Expand Down Expand Up @@ -70,10 +71,9 @@ object SortMergeBucketWriteExample {
implicit val coder: Coder[GenericRecord] =
avroGenericRecordCoder(SortMergeBucketExample.UserDataSchema)

def pipeline(cmdLineArgs: Array[String]): ScioContext = {
val (sc, args) = ContextAndArgs(cmdLineArgs)

sc.parallelize(0 until 500)
def pipeline(sc: ScioContext, args: Args): (ClosedTap[GenericRecord], ClosedTap[Account]) = {
val userTap = sc
.parallelize(0 until 500)
.map(i => SortMergeBucketExample.user(i.toString, i % 100))
.saveAsSortedBucket(
AvroSortedBucketIO
Expand All @@ -88,7 +88,8 @@ object SortMergeBucketWriteExample {
)

// #SortMergeBucketExample_sink
sc.parallelize(250 until 750)
val accountTap = sc
.parallelize(250 until 750)
.map { i =>
Account
.newBuilder()
Expand All @@ -111,7 +112,7 @@ object SortMergeBucketWriteExample {
.withNumShards(1)
)
// #SortMergeBucketExample_sink
sc
(userTap, accountTap)
}

def secondaryKeyExample(
Expand All @@ -137,8 +138,10 @@ object SortMergeBucketWriteExample {
}

def main(cmdLineArgs: Array[String]): Unit = {
val sc = pipeline(cmdLineArgs)
val (sc, args) = ContextAndArgs(cmdLineArgs)
pipeline(sc, args)
sc.run().waitUntilDone()
()
}
}

Expand All @@ -152,9 +155,7 @@ object SortMergeBucketJoinExample {
override def toString: String = s"$userId\t$age\t$balance"
}

def pipeline(cmdLineArgs: Array[String]): ScioContext = {
val (sc, args) = ContextAndArgs(cmdLineArgs)

def pipeline(sc: ScioContext, args: Args): ClosedTap[String] = {
val mapFn: ((String, (GenericRecord, Account))) => UserAccountData = {
case (userId, (userData, account)) =>
UserAccountData(userId, userData.get("age").toString.toInt, account.getAmount)
Expand All @@ -176,12 +177,11 @@ object SortMergeBucketJoinExample {
).map(mapFn) // Apply mapping function
.saveAsTextFile(args("output"))
// #SortMergeBucketExample_join

sc
}

def main(cmdLineArgs: Array[String]): Unit = {
val sc = pipeline(cmdLineArgs)
val (sc, args) = ContextAndArgs(cmdLineArgs)
pipeline(sc, args)
sc.run().waitUntilDone()
()
}
Expand All @@ -190,9 +190,7 @@ object SortMergeBucketJoinExample {
object SortMergeBucketTransformExample {
import com.spotify.scio.smb._

def pipeline(cmdLineArgs: Array[String]): ScioContext = {
val (sc, args) = ContextAndArgs(cmdLineArgs)

def pipeline(sc: ScioContext, args: Args): ClosedTap[Account] = {
// #SortMergeBucketExample_transform
val (readLhs, readRhs) = (
AvroSortedBucketIO
Expand Down Expand Up @@ -226,12 +224,9 @@ object SortMergeBucketTransformExample {
}
}
// #SortMergeBucketExample_transform
sc
}

def secondaryReadExample(cmdLineArgs: Array[String]): Unit = {
val (sc, args) = ContextAndArgs(cmdLineArgs)

def secondaryReadExample(sc: ScioContext, args: Args): Unit = {
// #SortMergeBucketExample_secondary_read
sc.sortMergeGroupByKey(
classOf[String], // primary key class
Expand All @@ -246,7 +241,8 @@ object SortMergeBucketTransformExample {
}

def main(cmdLineArgs: Array[String]): Unit = {
val sc = pipeline(cmdLineArgs)
val (sc, args) = ContextAndArgs(cmdLineArgs)
pipeline(sc, args)
sc.run().waitUntilDone()
()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package com.spotify.scio.examples.extra

import com.spotify.scio.ContextAndArgs

import java.io.File
import java.nio.file.Files
import com.spotify.scio.avro.{Account, AvroIO, GenericRecordTap, SpecificRecordTap}
import com.spotify.scio.io.{TextIO, TextTap}
import org.scalatest.matchers.should.Matchers
import org.scalatest.flatspec.AnyFlatSpec

Expand All @@ -38,59 +38,63 @@ class SortMergeBucketExampleTest extends AnyFlatSpec with Matchers {

"SortMergeBucketExample" should "join user and account data" in withTempFolders {
(userDir, accountDir, joinOutputDir) =>
SortMergeBucketWriteExample.main(
// Write data
val (writeContext, writeArgs) = ContextAndArgs(
Array(
s"--users=$userDir",
s"--accounts=$accountDir"
)
)

GenericRecordTap(
path = userDir.getAbsolutePath,
schema = SortMergeBucketExample.UserDataSchema,
params = AvroIO.ReadParam(".avro")
).value.size shouldBe 500
val (userTap, accountTap) = SortMergeBucketWriteExample.pipeline(writeContext, writeArgs)
val writeResult = writeContext.run().waitUntilDone()

SpecificRecordTap[Account](
path = accountDir.getAbsolutePath,
params = AvroIO.ReadParam(".avro")
).value.size shouldBe 500
userTap.get(writeResult).value.size shouldBe 500
accountTap.get(writeResult).value.size shouldBe 500

SortMergeBucketJoinExample.main(
// Read SMB data
val (readContext, readArgs) = ContextAndArgs(
Array(
s"--users=$userDir",
s"--accounts=$accountDir",
s"--output=$joinOutputDir"
)
)

TextTap(
path = joinOutputDir.getAbsolutePath,
params = TextIO.ReadParam(suffix = ".txt")
).value.size shouldBe 100
val joinTap = SortMergeBucketJoinExample.pipeline(readContext, readArgs)
val readResult = readContext.run().waitUntilDone()

joinTap.get(readResult).value.size shouldBe 100
}

it should "transform user and account data" in withTempFolders {
(userDir, accountDir, joinOutputDir) =>
SortMergeBucketWriteExample.main(
// Write data
val (writeContext, writeArgs) = ContextAndArgs(
Array(
s"--users=$userDir",
s"--accounts=$accountDir"
)
)

SortMergeBucketTransformExample.main(
SortMergeBucketWriteExample.pipeline(writeContext, writeArgs)
writeContext.run().waitUntilDone()

// Transform Data
val (transformContext, transformArgs) = ContextAndArgs(
Array(
s"--users=$userDir",
s"--accounts=$accountDir",
s"--output=$joinOutputDir"
)
)

SpecificRecordTap[Account](
joinOutputDir.getAbsolutePath,
AvroIO.ReadParam(".avro")
).value
val transformTap = SortMergeBucketTransformExample.pipeline(transformContext, transformArgs)
val transformResult = transformContext.run().waitUntilDone()

transformTap
.get(transformResult)
.value
.map(account => (account.getId, account.getType.toString))
.toList should contain theSameElementsAs (0 until 500).map((_, "combinedAmount"))
()
Expand Down

0 comments on commit eaf197c

Please sign in to comment.