Skip to content

Commit

Permalink
KafkaToBigQueryFlex: Add support for UDFs. (#1857)
Browse files Browse the repository at this point in the history
This also includes a fix for the FailsafeElementCoder in StringMessageToTableRow to allow null keys inside KafkaRecord.
  • Loading branch information
an2x committed Sep 19, 2024
1 parent fe87109 commit 4f45d13
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.cloud.teleport.v2.kafka.dlq.BigQueryDeadLetterQueueOptions;
import com.google.cloud.teleport.v2.kafka.options.KafkaReadOptions;
import com.google.cloud.teleport.v2.kafka.options.SchemaRegistryOptions;
import com.google.cloud.teleport.v2.transforms.JavascriptTextTransformer.JavascriptTextTransformerOptions;
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
import org.apache.beam.sdk.options.Default;

Expand All @@ -31,7 +32,8 @@ public interface KafkaToBigQueryFlexOptions
KafkaReadOptions,
BigQueryStorageApiStreamingOptions,
SchemaRegistryOptions,
BigQueryDeadLetterQueueOptions {
BigQueryDeadLetterQueueOptions,
JavascriptTextTransformerOptions {
// This is a duplicate option that already exist in KafkaReadOptions but keeping it here
// so the KafkaTopic appears above the authentication enum on the Templates UI.
@TemplateParameter.KafkaReadTopic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.cloud.teleport.v2.utils.MetadataValidator;
import com.google.cloud.teleport.v2.utils.SchemaUtils;
import com.google.cloud.teleport.v2.values.FailsafeElement;
import com.google.common.base.Strings;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -459,6 +460,12 @@ public static Pipeline runAvroPipeline(
throw new IllegalArgumentException(
"Schema Registry Connection URL or Avro schema is needed in order to read confluent wire format messages.");
}
if (!Strings.isNullOrEmpty(options.getJavascriptTextTransformGcsPath())
&& !Strings.isNullOrEmpty(options.getJavascriptTextTransformFunctionName())) {
LOG.warn(
"JavaScript UDF parameters are set while using Avro message format. "
+ "UDFs are supported for JSON format only. No UDF transformation will be applied.");
}

PCollection<KafkaRecord<byte[], byte[]>> kafkaRecords;

Expand All @@ -474,8 +481,7 @@ public static Pipeline runAvroPipeline(
.setCoder(
KafkaRecordCoder.of(NullableCoder.of(ByteArrayCoder.of()), ByteArrayCoder.of()));

WriteResult writeResult = null;
writeResult = processKafkaRecords(kafkaRecords, options);
WriteResult writeResult = processKafkaRecords(kafkaRecords, options);
return pipeline;
}

Expand Down Expand Up @@ -514,7 +520,16 @@ public static Pipeline runJsonPipeline(
/*
* Step #2: Transform the Kafka Messages into TableRows
*/
.apply("ConvertMessageToTableRow", new StringMessageToTableRow());
.apply(
"ConvertMessageToTableRow",
StringMessageToTableRow.newBuilder()
.setFileSystemPath(options.getJavascriptTextTransformGcsPath())
.setFunctionName(options.getJavascriptTextTransformFunctionName())
.setReloadIntervalMinutes(
options.getJavascriptTextTransformReloadIntervalMinutes())
.setSuccessTag(TRANSFORM_OUT)
.setFailureTag(TRANSFORM_DEADLETTER_OUT)
.build());
/*
* Step #3: Write the successful records out to BigQuery
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,74 +16,148 @@
package com.google.cloud.teleport.v2.transforms;

import com.google.api.services.bigquery.model.TableRow;
import com.google.auto.value.AutoValue;
import com.google.cloud.teleport.v2.coders.FailsafeElementCoder;
import com.google.cloud.teleport.v2.templates.KafkaToBigQueryFlex;
import com.google.cloud.teleport.v2.transforms.BigQueryConverters.FailsafeJsonToTableRow;
import com.google.cloud.teleport.v2.transforms.JavascriptTextTransformer.FailsafeJavascriptUdf;
import com.google.cloud.teleport.v2.values.FailsafeElement;
import com.google.common.base.Strings;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.kafka.KafkaRecord;
import org.apache.beam.sdk.io.kafka.KafkaRecordCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;

/**
* The {@link StringMessageToTableRow} class is a {@link PTransform} which transforms incoming Kafka
* Message objects into {@link TableRow} objects for insertion into BigQuery while applying a UDF to
* the input. The executions of the UDF and transformation to {@link TableRow} objects is done in a
* fail-safe way by wrapping the element with it's original payload inside the {@link
* fail-safe way by wrapping the element with its original payload inside the {@link
* FailsafeElement} class. The {@link StringMessageToTableRow} transform will output a {@link
* PCollectionTuple} which contains all output and dead-letter {@link PCollection}.
*
* <p>The {@link PCollectionTuple} output will contain the following {@link PCollection}:
*
* <ul>
* <li>{@link KafkaToBigQuery#TRANSFORM_OUT} - Contains all records successfully converted from
* JSON to {@link TableRow} objects.
* <li>{@link KafkaToBigQuery#TRANSFORM_DEADLETTER_OUT} - Contains all {@link FailsafeElement}
* records which couldn't be converted to table rows.
* <li>{@link #successTag()} - Contains all records successfully converted from JSON to {@link
* TableRow} objects.
* <li>{@link #failureTag()} - Contains all {@link FailsafeElement} records which couldn't be
* converted to table rows.
* </ul>
*/
public class StringMessageToTableRow
@AutoValue
public abstract class StringMessageToTableRow
extends PTransform<PCollection<KafkaRecord<String, String>>, PCollectionTuple> {

public abstract @Nullable String fileSystemPath();

public abstract @Nullable String functionName();

public abstract @Nullable Integer reloadIntervalMinutes();

public abstract TupleTag<TableRow> successTag();

public abstract TupleTag<FailsafeElement<KafkaRecord<String, String>, String>> failureTag();

public static Builder newBuilder() {
return new AutoValue_StringMessageToTableRow.Builder();
}

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setFileSystemPath(@Nullable String fileSystemPath);

public abstract Builder setFunctionName(@Nullable String functionName);

public abstract Builder setReloadIntervalMinutes(@Nullable Integer value);

public abstract Builder setSuccessTag(TupleTag<TableRow> successTag);

public abstract Builder setFailureTag(
TupleTag<FailsafeElement<KafkaRecord<String, String>, String>> failureTag);

public abstract StringMessageToTableRow build();
}

private static final Coder<String> NULLABLE_STRING_CODER = NullableCoder.of(StringUtf8Coder.of());
private static final Coder<KafkaRecord<String, String>> NULLABLE_KAFKA_RECORD_CODER =
NullableCoder.of(KafkaRecordCoder.of(NULLABLE_STRING_CODER, NULLABLE_STRING_CODER));
private static final Coder<FailsafeElement<KafkaRecord<String, String>, String>> FAILSAFE_CODER =
FailsafeElementCoder.of(NULLABLE_KAFKA_RECORD_CODER, NULLABLE_STRING_CODER);

private static final TupleTag<FailsafeElement<KafkaRecord<String, String>, String>> UDF_OUT =
new TupleTag<FailsafeElement<KafkaRecord<String, String>, String>>() {};

private static final TupleTag<FailsafeElement<KafkaRecord<String, String>, String>>
UDF_DEADLETTER_OUT = new TupleTag<FailsafeElement<KafkaRecord<String, String>, String>>() {};

private static final TupleTag<TableRow> TABLE_ROW_OUT = new TupleTag<TableRow>() {};

/** The tag for the dead-letter output of the json to table row transform. */
private static final TupleTag<FailsafeElement<KafkaRecord<String, String>, String>>
TABLE_ROW_DEADLETTER_OUT =
new TupleTag<FailsafeElement<KafkaRecord<String, String>, String>>() {};

@Override
public PCollectionTuple expand(PCollection<KafkaRecord<String, String>> input) {

PCollectionTuple jsonToTableRowOut =
PCollection<FailsafeElement<KafkaRecord<String, String>, String>> failsafeElements =
input
// Map the incoming messages into FailsafeElements so we can recover from failures
// across multiple transforms.
.apply("MapToRecord", ParDo.of(new StringMessageToFailsafeElementFn()))
.setCoder(
FailsafeElementCoder.of(
NullableCoder.of(
KafkaRecordCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())),
NullableCoder.of(StringUtf8Coder.of())))
.apply(
"JsonToTableRow",
FailsafeJsonToTableRow.<KafkaRecord<String, String>>newBuilder()
.setSuccessTag(KafkaToBigQueryFlex.TRANSFORM_OUT)
.setFailureTag(KafkaToBigQueryFlex.TRANSFORM_DEADLETTER_OUT)
.build());
.setCoder(FAILSAFE_CODER);

PCollectionTuple udfOut = null;

if (!Strings.isNullOrEmpty(fileSystemPath()) && !Strings.isNullOrEmpty(functionName())) {
// Apply UDF transform only if UDF options are enabled, otherwise skip it completely.

udfOut =
failsafeElements.apply(
"InvokeUDF",
FailsafeJavascriptUdf.<KafkaRecord<String, String>>newBuilder()
.setFileSystemPath(fileSystemPath())
.setFunctionName(functionName())
.setReloadIntervalMinutes(reloadIntervalMinutes())
.setSuccessTag(UDF_OUT)
.setFailureTag(UDF_DEADLETTER_OUT)
.build());
failsafeElements = udfOut.get(UDF_OUT).setCoder(FAILSAFE_CODER);
}

PCollectionTuple tableRowOut =
failsafeElements.apply(
"JsonToTableRow",
FailsafeJsonToTableRow.<KafkaRecord<String, String>>newBuilder()
.setSuccessTag(TABLE_ROW_OUT)
.setFailureTag(TABLE_ROW_DEADLETTER_OUT)
.build());

PCollection<FailsafeElement<KafkaRecord<String, String>, String>> badRecords =
jsonToTableRowOut
.get(KafkaToBigQueryFlex.TRANSFORM_DEADLETTER_OUT)
.setCoder(
FailsafeElementCoder.of(
NullableCoder.of(
KafkaRecordCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())),
NullableCoder.of(StringUtf8Coder.of())));
tableRowOut.get(TABLE_ROW_DEADLETTER_OUT).setCoder(FAILSAFE_CODER);

if (udfOut != null) {
// If UDF is enabled, combine TableRow transform DLQ output with UDF DLQ output.

PCollection<FailsafeElement<KafkaRecord<String, String>, String>> udfBadRecords =
udfOut.get(UDF_DEADLETTER_OUT).setCoder(FAILSAFE_CODER);

badRecords = PCollectionList.of(badRecords).and(udfBadRecords).apply(Flatten.pCollections());
}

// Re-wrap the PCollections so we can return a single PCollectionTuple
return PCollectionTuple.of(
KafkaToBigQueryFlex.TRANSFORM_OUT,
jsonToTableRowOut.get(KafkaToBigQueryFlex.TRANSFORM_OUT))
.and(KafkaToBigQueryFlex.TRANSFORM_DEADLETTER_OUT, badRecords);
return PCollectionTuple.of(successTag(), tableRowOut.get(TABLE_ROW_OUT))
.and(failureTag(), badRecords);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.cloud.bigquery.TableId;
import com.google.cloud.bigquery.TableResult;
import com.google.cloud.teleport.metadata.TemplateIntegrationTest;
import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -89,6 +90,28 @@ public void testKafkaToBigQuery() throws IOException {
.addParameter("kafkaReadAuthenticationMode", "NONE"));
}

@Test
public void testKafkaToBigQueryWithUdfFunction() throws RestClientException, IOException {
String udfFileName = "input/transform.js";
gcsClient.createArtifact(
udfFileName,
"function transform(value) {\n"
+ " const data = JSON.parse(value);\n"
+ " data.name = data.name.toUpperCase();\n"
+ " return JSON.stringify(data);\n"
+ "}");

baseKafkaToBigQuery(
b ->
b.addParameter("messageFormat", "JSON")
.addParameter("writeMode", "SINGLE_TABLE_NAME")
.addParameter("useBigQueryDLQ", "false")
.addParameter("kafkaReadAuthenticationMode", "NONE")
.addParameter("javascriptTextTransformGcsPath", getGcsPath(udfFileName))
.addParameter("javascriptTextTransformFunctionName", "transform"),
s -> s == null ? null : s.toUpperCase());
}

@Test
@TemplateIntegrationTest(value = KafkaToBigQueryFlex.class, template = "Kafka_to_BigQuery_Flex")
public void testKafkaToBigQueryWithExistingDLQ() throws IOException {
Expand Down Expand Up @@ -138,6 +161,13 @@ private Schema getDeadletterSchema() {

public void baseKafkaToBigQuery(Function<LaunchConfig.Builder, LaunchConfig.Builder> paramsAdder)
throws IOException {
baseKafkaToBigQuery(paramsAdder, Function.identity());
}

public void baseKafkaToBigQuery(
Function<LaunchConfig.Builder, LaunchConfig.Builder> paramsAdder,
Function<String, String> namePostProcessor)
throws IOException {
// Arrange
String topicName = kafkaResourceManager.createTopic(testName, 5);

Expand Down Expand Up @@ -198,7 +228,9 @@ public void baseKafkaToBigQuery(Function<LaunchConfig.Builder, LaunchConfig.Buil

assertThatBigQueryRecords(tableRows)
.hasRecordsUnordered(
List.of(Map.of("id", 11, "name", "Dataflow"), Map.of("id", 12, "name", "Pub/Sub")));
List.of(
Map.of("id", 11, "name", namePostProcessor.apply("Dataflow")),
Map.of("id", 12, "name", namePostProcessor.apply("Pub/Sub"))));
assertThatBigQueryRecords(dlqRows).hasRecordsWithStrings(List.of("bad json string"));
}

Expand Down

0 comments on commit 4f45d13

Please sign in to comment.