import com.amazonaws.services.glue.GlueContext
import com.amazonaws.services.glue.util.GlueArgParser
import com.amazonaws.services.glue.util.Job
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.streaming.Trigger
import scala.collection.JavaConverters._
import com.datastax.spark.connector._
import org.apache.spark.sql.cassandra._
import org.apache.spark.sql.SaveMode._
import com.datastax.spark.connector.cql._
import com.datastax.oss.driver.api.core._
import org.apache.spark.sql.functions.rand
import com.amazonaws.services.glue.log.GlueLogger
import java.time.ZonedDateTime
import java.time.ZoneOffset
import java.time.temporal.ChronoUnit
import java.time.format.DateTimeFormatter
import org.apache.spark.sql.functions._

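// AWS Glue job that updates the TTL of existing rows in an Amazon Keyspaces /
// Apache Cassandra table. It scans the source table with the Spark Cassandra Connector,
// reads each row's remaining TTL, adds (or subtracts) TTL_TIME_TO_ADD seconds, and
// rewrites the row with the new TTL using a lightweight transaction (LWT).
// Note: the UPDATE statement in updateRowWithLWT is written against the
// tlp_stress.keyvalue example table; adjust it to match KEYSPACE_NAME and TABLE_NAME.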
object GlueApp {

  //currentTTL is the time left on the record, in seconds
  //timeToAdd is the delta to add or subtract, in seconds. Use a negative number for subtraction.
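  //For example, addTimeToExistingTTL(3600, 7200) returns 10800, while
  //addTimeToExistingTTL(3600, -7200) is clamped to 1.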
  def addTimeToExistingTTL(currentTTL: Int, timeToAdd: Int): Int = {

    val finalTTLValue = currentTTL + timeToAdd

    // Guard against the adjusted TTL dropping to zero or below, e.g. when more time is
    // subtracted than remains (such as reducing a 90-day TTL to 60 days). Clamp to 1 second.
    // TODO: May be more efficient to just delete the row than to modify/expire it.
    Math.max(1, finalTTLValue)
  }

  //Update the row with the new TTL using a lightweight transaction (LWT).
  //To update the TTL we must overwrite the row using its existing values.
  //The LWT condition checks that the value has not changed since the row was read for the current TTL.
  def updateRowWithLWT(row: Row, connector: CassandraConnector): Unit = {
    //openSession() creates a session or updates the reference counter on the shared session.
    val session = connector.openSession()

    val query =
      """UPDATE tlp_stress.keyvalue
        |USING TTL ?
        |SET value = ?
        |WHERE key = ?
        |IF value = ?""".stripMargin

    //prepared statements are cached by the driver, so preparing on every call is not an issue.
    val prepared = session.prepare(query)

    val key = row.getAs[String]("key")
    val value = row.getAs[String]("value")
    val expectedValue = row.getAs[String]("value")
    val ttl = row.getAs[Int]("ttlCol")

    //bind the values to the prepared statement.
    val bound = prepared.bind(
      java.lang.Integer.valueOf(ttl),
      value, key, expectedValue)

    val result = session.execute(bound)

    // Optional: check whether the LWT succeeded
    if (!result.wasApplied()) {
      println(s"Conditional update failed for key=$key")
      // Here you may want to:
      //1. read the latest row and ttl
      //2. apply the correct ttl
      //3. use LWT to avoid conflicts
    }
    session.close()
  }
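
  // A minimal sketch (not wired into main) of the retry path described in the comment
  // above: re-read the row and its remaining TTL, recompute the target TTL, and retry
  // the conditional update once. The helper name and its timeToAdd parameter are
  // illustrative; the statements assume the same tlp_stress.keyvalue schema.
  def retryWithFreshTTL(key: String, timeToAdd: Int, connector: CassandraConnector): Unit = {
    val session = connector.openSession()

    // Re-read the current value together with its remaining TTL in seconds.
    val select = session.prepare(
      "SELECT value, TTL(value) AS remaining_ttl FROM tlp_stress.keyvalue WHERE key = ?")
    val current = session.execute(select.bind(key)).one()

    if (current != null) {
      val currentValue = current.getString("value")
      val newTTL = addTimeToExistingTTL(current.getInt("remaining_ttl"), timeToAdd)

      val update = session.prepare(
        """UPDATE tlp_stress.keyvalue
          |USING TTL ?
          |SET value = ?
          |WHERE key = ?
          |IF value = ?""".stripMargin)

      val applied = session.execute(update.bind(
        java.lang.Integer.valueOf(newTTL), currentValue, key, currentValue)).wasApplied()

      if (!applied) {
        println(s"Conditional update failed again for key=$key")
      }
    }
    session.close()
  }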

  def main(sysArgs: Array[String]): Unit = {

    val args = GlueArgParser.getResolvedOptions(sysArgs, Seq("JOB_NAME", "KEYSPACE_NAME", "TABLE_NAME", "DRIVER_CONF", "TTL_FIELD", "TTL_TIME_TO_ADD").toArray)

    val driverConfFileName = args("DRIVER_CONF")

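    // The connector settings below follow the write pattern typically used with Amazon
    // Keyspaces: LOCAL_QUORUM consistency for writes, unbatched single-row writes (batch
    // grouping off, batch size 1), and a driver configuration profile supplied via DRIVER_CONF.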
    val conf = new SparkConf()
      .setAll(
        Seq(
          ("spark.task.maxFailures", "100"),

          ("spark.cassandra.connection.config.profile.path", driverConfFileName),
          ("spark.sql.extensions", "com.datastax.spark.connector.CassandraSparkExtensions"),
          ("directJoinSetting", "on"),

          ("spark.cassandra.output.consistency.level", "LOCAL_QUORUM"), //WRITES
          ("spark.cassandra.input.consistency.level", "LOCAL_ONE"), //READS

          ("spark.cassandra.sql.inClauseToJoinConversionThreshold", "0"),
          ("spark.cassandra.sql.inClauseToFullScanConversionThreshold", "0"),
          ("spark.cassandra.concurrent.reads", "50"),

          ("spark.cassandra.output.concurrent.writes", "5"),
          ("spark.cassandra.output.batch.grouping.key", "none"),
          ("spark.cassandra.output.batch.size.rows", "1"),
          ("spark.cassandra.output.ignoreNulls", "true")
        ))

    val spark: SparkContext = new SparkContext(conf)
    val glueContext: GlueContext = new GlueContext(spark)
    val sparkSession: SparkSession = glueContext.getSparkSession

    import sparkSession.implicits._

    Job.init(args("JOB_NAME"), glueContext, args.asJava)

    val logger = new GlueLogger

    //validation steps for peers and partitioner
    val connector = CassandraConnector.apply(conf)
    val session = connector.openSession()
    val peersCount = session.execute("SELECT * FROM system.peers").all().size()

    val partitioner = session.execute("SELECT partitioner from system.local").one().getString("partitioner")

    logger.info("Total number of peers: " + peersCount)
    logger.info("Configured partitioner: " + partitioner)

    if(peersCount == 0){
      throw new Exception("No system peers found. Check required permissions to read from the system.peers table. If using VPC endpoints, check permissions for describing VPC endpoints. https://docs.aws.amazon.com/keyspaces/latest/devguide/vpc-endpoints.html")
    }

    if(partitioner.equals("com.amazonaws.cassandra.DefaultPartitioner")){
      throw new Exception("Spark requires the use of RandomPartitioner or Murmur3Partitioner. See 'Working with partitioners' in the Amazon Keyspaces documentation. https://docs.aws.amazon.com/keyspaces/latest/devguide/working-with-partitioners.html")
    }

    val tableName = args("TABLE_NAME")
    val keyspaceName = args("KEYSPACE_NAME")

    val tableDf = sparkSession.read
      .format("org.apache.spark.sql.cassandra")
      .options(Map("table" -> tableName,
                   "keyspace" -> keyspaceName,
                   "pushdown" -> "false")) //set to true when executing against Apache Cassandra, false when working with Keyspaces
      .load()
      //.filter("my_column=='somevalue' AND my_othercolumn=='someothervalue'")

    // Register the UDF for calculating TTL
    val calculateTTLUDF = udf((currentTTL: Int, timeToAdd: Int) => addTimeToExistingTTL(currentTTL, timeToAdd))

    val timeToAdd = args("TTL_TIME_TO_ADD").toInt
    val ttlField = args("TTL_FIELD")
    //val timeToAdd = 5 * 365 * 24 * 60 * 60  //add 5 years
    //val timeToAdd = -1 * 365 * 24 * 60 * 60 //subtract 1 year
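
    // ttl(...) below is the Spark Cassandra Connector's metadata function (brought in by
    // the org.apache.spark.sql.cassandra import); it returns the remaining TTL of the
    // given column, in seconds.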
    // Calculate TTL values
    val tableDfWithTTL = tableDf
      .withColumn("ttlCol", calculateTTLUDF(ttl(col(ttlField)), lit(timeToAdd)))

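    // updateRowWithLWT runs on the executors; CassandraConnector is serializable, and
    // openSession() reuses the shared, reference-counted session, so a new connection
    // is not opened for every row.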
    tableDfWithTTL.foreachPartition { partition: Iterator[Row] =>
      partition.foreach { row => updateRowWithLWT(row, connector) }
    }

    Job.commit()
  }
}
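
// Example job parameters (values are illustrative; the names match the options resolved
// by GlueArgParser in main):
//   --JOB_NAME          UpdateTTLJob
//   --DRIVER_CONF       keyspaces-application.conf
//   --KEYSPACE_NAME     tlp_stress
//   --TABLE_NAME        keyvalue
//   --TTL_FIELD         value
//   --TTL_TIME_TO_ADD   7776000   (adds 90 days, in seconds)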