Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated SigV4 signing library in Gremlin connection for Neptune connector #1698

Merged
merged 5 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion athena-neptune/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
<artifactId>athena-neptune</artifactId>
<version>2022.47.1</version>
<properties>
<gremlinDriverVersion>3.7.2</gremlinDriverVersion>
<!-- make sure gremlin driver version stays within the Neptune supported range -->
<gremlinDriverVersion>3.6.5</gremlinDriverVersion>
<neptune.sigv4.signer.version>2.4.0</neptune.sigv4.signer.version>
</properties>
<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@

import com.amazonaws.athena.connectors.neptune.propertygraph.NeptuneGremlinConnection;
import com.amazonaws.athena.connectors.neptune.rdf.NeptuneSparqlConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Cluster;
import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer;
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneConnection
{
private static Cluster cluster = null;

private static final Logger logger = LoggerFactory.getLogger(NeptuneConnection.class);

private String neptuneEndpoint;
private String neptunePort;
private boolean enabledIAM;
Expand All @@ -45,7 +51,22 @@ protected NeptuneConnection(String neptuneEndpoint, String neptunePort, boolean
.enableSsl(true);

if (enabledIAM) {
builder = builder.channelizer(SigV4WebSocketChannelizer.class);
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
}
);
}

cluster = builder.create();
Expand Down Expand Up @@ -77,7 +98,7 @@ public static NeptuneConnection createConnection(java.util.Map<String, String> c
throw new IllegalArgumentException("Unsupported graphType: " + graphType);
}
}

public String getNeptuneEndpoint()
{
return this.neptuneEndpoint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@
package com.amazonaws.athena.connectors.neptune.propertygraph;

import com.amazonaws.athena.connectors.neptune.NeptuneConnection;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.neptune.auth.NeptuneNettyHttpSigV4Signer;
import com.amazonaws.neptune.auth.NeptuneSigV4SignerException;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Cluster;
import org.apache.tinkerpop.gremlin.driver.SigV4WebSocketChannelizer;
import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection;
import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource;
import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NeptuneGremlinConnection extends NeptuneConnection
{
private static final Logger logger = LoggerFactory.getLogger(NeptuneGremlinConnection.class);
private static Cluster cluster = null;

public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, boolean enabledIAM, String region)
Expand All @@ -40,7 +46,22 @@ public NeptuneGremlinConnection(String neptuneEndpoint, String neptunePort, bool
.enableSsl(true);

if (enabledIAM) {
builder = builder.channelizer(SigV4WebSocketChannelizer.class);
logger.info("Connecting with IAM auth to https://" + neptuneEndpoint + ":" + neptunePort + " in " + region);
final AWSCredentialsProvider awsCredentialsProvider = new DefaultAWSCredentialsProviderChain();
builder.handshakeInterceptor(r ->
{
try {
NeptuneNettyHttpSigV4Signer sigV4Signer =
new NeptuneNettyHttpSigV4Signer(region, awsCredentialsProvider);
sigV4Signer.signRequest(r);
}
catch (NeptuneSigV4SignerException e) {
logger.error("SIGV4 exception", e);
throw new RuntimeException("Exception occurred while signing the request", e);
}
return r;
}
);
}

cluster = builder.create();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.tinkerpop.gremlin.structure.T;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Date;
Expand All @@ -50,19 +52,22 @@
* This class is a Utility class to create Extractors for each field type as per
* Schema
*/
public final class CustomSchemaRowWriter
public final class CustomSchemaRowWriter
{
private CustomSchemaRowWriter()
private static final Logger logger = LoggerFactory.getLogger(CustomSchemaRowWriter.class);
private CustomSchemaRowWriter()
{
// Empty private constructor
}

public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map<String, String> configOptions)
public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field field, java.util.Map<String, String> configOptions)
{
ArrowType arrowType = field.getType();
Types.MinorType minorType = Types.getMinorTypeForArrowType(arrowType);
logger.debug("writeRowTemplate*" + field.getName() + "*" + minorType + "*");
Boolean enableCaseinsensitivematch = (configOptions.get(Constants.SCHEMA_CASE_INSEN) == null) ? true : Boolean.parseBoolean(configOptions.get(Constants.SCHEMA_CASE_INSEN));

try {
switch (minorType) {
case BIT:
rowWriterBuilder.withExtractor(field.getName(),
Expand All @@ -72,19 +77,22 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Boolean.class)) {
logger.debug("writeRowTemplate BIT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue.getClass().equals(Boolean.class)) {
Boolean booleanValue = Boolean.parseBoolean(fieldValue.toString());
value.value = booleanValue ? 1 : 0;
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) obj.get(field.getName());
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
Boolean booleanValue = Boolean.parseBoolean(objValues.get(0).toString());
value.value = booleanValue ? 1 : 0;
value.isSet = 1;
}
}
}
});
break;

Expand All @@ -102,23 +110,29 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.value = fieldValue.toString();
value.isSet = 1;
}
}
}
else {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate VARCHAR*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");

if (fieldValue != null) {
if (fieldValue.getClass().equals(String.class)) {
value.value = fieldValue.toString();
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to apply this condition else if (fieldValue instanceof ArrayList) to rest of types? For example DATEMILLI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I did miss these. I'll add them.

ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null) {
value.value = objValues.get(0).toString();
value.isSet = 1;
}
}
}
else {
value.value = "" + fieldValue;
value.isSet = 1;
}
}
}
});
break;
Expand All @@ -131,11 +145,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Date.class)) {
logger.debug("writeRowTemplate DATEMILLI*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Date.class)) {
value.value = ((Date) fieldValue).getTime();
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && (objValues.get(0) != null) && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = ((Date) objValues.get(0)).getTime();
Expand All @@ -153,11 +169,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Integer.class)) {
logger.debug("writeRowTemplate INT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Integer.class)) {
value.value = Integer.parseInt(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Integer.parseInt(objValues.get(0).toString());
Expand All @@ -175,11 +193,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Long.class)) {
logger.debug("writeRowTemplate BIGINT*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Long.class)) {
value.value = Long.parseLong(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Long.parseLong(objValues.get(0).toString());
Expand All @@ -197,11 +217,13 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
value.isSet = 0;

Object fieldValue = obj.get(fieldName);
if (fieldValue.getClass().equals(Float.class)) {
logger.debug("writeRowTemplate FLOAT4*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Float.class)) {
value.value = Float.parseFloat(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Float.parseFloat(objValues.get(0).toString());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we parse the list of value instead? Why are we parsing only the first element in list

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can only return zero or one value for a column. Custom query may return a single value wrapped inside a list/set, so we needed to access the element. In corner cases multiple values are returned, we are also only returning the first value based on this constraint.

Expand All @@ -218,12 +240,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie
Map<String, Object> obj = (Map<String, Object>) contextAsMap(context, enableCaseinsensitivematch);
value.isSet = 0;

Object fieldValue = obj.get(field.getName());
if (fieldValue.getClass().equals(Double.class)) {
Object fieldValue = obj.get(fieldName);
logger.debug("writeRowTemplate FLOAT8*" + field.getName() + "*" + minorType + "*"
+ (fieldValue == null ? "" : fieldValue.getClass()) + "*");
if (fieldValue.getClass().equals(Double.class)) {
value.value = Double.parseDouble(fieldValue.toString());
value.isSet = 1;
}
else {
else if (fieldValue instanceof ArrayList) {
ArrayList<Object> objValues = (ArrayList) fieldValue;
if (objValues != null && objValues.get(0) != null && !(objValues.get(0).toString().trim().isEmpty())) {
value.value = Double.parseDouble(objValues.get(0).toString());
Expand All @@ -234,9 +258,14 @@ public static void writeRowTemplate(RowWriterBuilder rowWriterBuilder, Field fie

break;
}
}
catch (Throwable e) {
logger.error("writeRowTemplate exception for *" + field.getName() + "*" + minorType + "*", e);
throw new RuntimeException(e);
}
}

private static Map<String, Object> contextAsMap(Object context, boolean caseInsensitive)
private static Map<String, Object> contextAsMap(Object context, boolean caseInsensitive)
{
Map<String, Object> contextAsMap = (Map<String, Object>) context;
Object fieldValueID = contextAsMap.get(T.id);
Expand Down