Skip to content
This repository was archived by the owner on Mar 13, 2020. It is now read-only.

Commit dc8205f

Browse files
authored
Merge pull request #74 from pageuppeople-opensource/feature/SP-333_assume-role-support
[SP-333] Assume Role Support
2 parents 75b86dc + ba42aed commit dc8205f

File tree

1 file changed

+87
-14
lines changed

1 file changed

+87
-14
lines changed

rdl/data_sources/AWSLambdaDataSource.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import boto3
55
import time
6+
import datetime
67

78
from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo
89
from rdl.data_sources.SourceTableInfo import SourceTableInfo
@@ -11,15 +12,24 @@
1112

1213

1314
class AWSLambdaDataSource(object):
14-
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
15+
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName;'
1516
CONNECTION_STRING_PREFIX = "aws-lambda://"
1617
CONNECTION_STRING_GROUP_SEPARATOR = ";"
1718
CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="
1819

20+
CONNECTION_DATA_ROLE_KEY = "role"
21+
CONNECTION_DATA_FUNCTION_KEY = "function"
22+
CONNECTION_DATA_TENANT_KEY = "tenant"
23+
24+
AWS_SERVICE_LAMBDA = "lambda"
25+
AWS_SERVICE_S3 = "s3"
26+
1927
def __init__(self, connection_string, logger=None):
2028
self.logger = logger or logging.getLogger(__name__)
29+
2130
if not AWSLambdaDataSource.can_handle_connection_string(connection_string):
2231
raise ValueError(connection_string)
32+
2333
self.connection_string = connection_string
2434
self.connection_data = dict(
2535
kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR)
@@ -29,8 +39,19 @@ def __init__(self, connection_string, logger=None):
2939
.rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
3040
.split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
3141
)
32-
self.aws_lambda_client = boto3.client("lambda")
33-
self.aws_s3_client = boto3.client("s3")
42+
43+
self.aws_sts_client = boto3.client("sts")
44+
role_credentials = self.__assume_role(
45+
self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
46+
f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
47+
)
48+
49+
self.aws_lambda_client = self.__get_aws_client(
50+
self.AWS_SERVICE_LAMBDA, role_credentials
51+
)
52+
self.aws_s3_client = self.__get_aws_client(
53+
self.AWS_SERVICE_S3, role_credentials
54+
)
3455

3556
@staticmethod
3657
def can_handle_connection_string(connection_string):
@@ -87,7 +108,7 @@ def get_table_data_frame(
87108
def __get_table_info(self, table_config, last_known_sync_version):
88109
pay_load = {
89110
"Command": "GetTableInfo",
90-
"TenantId": int(self.connection_data["tenant"]),
111+
"TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]),
91112
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
92113
"CommandPayload": {"LastSyncVersion": last_known_sync_version},
93114
}
@@ -113,7 +134,7 @@ def __get_table_data(
113134
):
114135
pay_load = {
115136
"Command": "GetTableData",
116-
"TenantId": int(self.connection_data["tenant"]),
137+
"TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]),
117138
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
118139
"CommandPayload": {
119140
"AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION,
@@ -125,7 +146,7 @@ def __get_table_data(
125146
{
126147
"Name": col["source_name"],
127148
"DataType": col["destination"]["type"],
128-
"IsPrimaryKey": col["destination"]["primary_key"]
149+
"IsPrimaryKey": col["destination"]["primary_key"],
129150
}
130151
for col in columns_config
131152
],
@@ -148,41 +169,93 @@ def __get_table_data(
148169
def __get_data_frame(self, data: [[]], column_names: []):
    """Wrap row data in a pandas DataFrame.

    :param data: values passed straight through as ``pandas.DataFrame``'s
        ``data`` argument.
    :param column_names: labels passed straight through as the ``columns``
        argument.
    :return: the newly constructed ``pandas.DataFrame``.
    """
    frame = pandas.DataFrame(columns=column_names, data=data)
    return frame
150171

172+
def __assume_role(self, role_arn, session_name):
    """Assume an IAM role through the STS client and return its credentials.

    Side effect: the ``"Expiration"`` value of the returned credentials is
    stored on ``self.role_session_expiry``.

    :param role_arn: ARN of the role to assume; sent as ``RoleArn``.
    :param session_name: sent as ``RoleSessionName``.
    :return: the ``"Credentials"`` entry of the STS ``assume_role`` response.
    """
    self.logger.debug(f"\nAssuming role with ARN: {role_arn}")

    sts_response = self.aws_sts_client.assume_role(
        RoleSessionName=session_name,
        RoleArn=role_arn,
    )
    credentials = sts_response["Credentials"]

    # Keep the expiry on the instance so it can be consulted after this call.
    self.role_session_expiry = credentials["Expiration"]
    return credentials
184+
185+
def __get_aws_client(self, service, credentials):
    """Create a boto3 client for ``service`` using explicit session credentials.

    :param service: passed to ``boto3.client`` as ``service_name``.
    :param credentials: mapping providing the ``"AccessKeyId"``,
        ``"SecretAccessKey"`` and ``"SessionToken"`` keys.
    :return: the client object returned by ``boto3.client``.
    """
    session_auth = {
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }
    return boto3.client(service_name=service, **session_auth)
192+
193+
def __refresh_aws_clients_if_expired(self):
    """Re-assume the role and rebuild the AWS clients when the session is stale.

    The session counts as stale once the current time is later than five
    minutes before ``self.role_session_expiry``. That covers sessions about
    to expire and sessions that have already expired. When stale, the role
    from the connection data is assumed again and both
    ``self.aws_lambda_client`` and ``self.aws_s3_client`` are replaced with
    clients built from the new credentials. Otherwise nothing changes.
    """
    # this is due to AWS returning their expiry date in UTC: compare against
    # a timezone-aware "now" rather than a naive local time.
    current_datetime = datetime.datetime.now(datetime.timezone.utc)
    refresh_after = self.role_session_expiry - datetime.timedelta(minutes=5)

    # Still outside the five-minute safety margin: keep the current clients.
    if current_datetime <= refresh_after:
        return

    # Deliberately no upper bound on the check above. Requiring "now" to be
    # before the expiry would skip the refresh once the session had actually
    # lapsed, leaving every later call on expired credentials.
    role_credentials = self.__assume_role(
        self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
        f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
    )

    self.aws_lambda_client = self.__get_aws_client(
        self.AWS_SERVICE_LAMBDA, role_credentials
    )
    self.aws_s3_client = self.__get_aws_client(
        self.AWS_SERVICE_S3, role_credentials
    )
212+
151213
def __invoke_lambda(self, pay_load):
152214
max_attempts = Constants.MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS
153215
retry_delay = Constants.AWS_LAMBDA_RETRY_DELAY_SECONDS
154216
response_payload = None
155217

156-
for current_attempt in list(range(1, max_attempts+1, 1)):
218+
for current_attempt in list(range(1, max_attempts + 1, 1)):
219+
220+
self.__refresh_aws_clients_if_expired()
221+
157222
if current_attempt > 1:
158-
self.logger.debug(f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds")
223+
self.logger.debug(
224+
f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds"
225+
)
159226
time.sleep((current_attempt - 1) ^ retry_delay)
160227

161-
self.logger.debug(f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:")
228+
self.logger.debug(
229+
f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:"
230+
)
162231
self.logger.debug(pay_load)
163232

164233
lambda_response = self.aws_lambda_client.invoke(
165-
FunctionName=self.connection_data["function"],
234+
FunctionName=self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY],
166235
InvocationType="RequestResponse",
167236
LogType="None", # |'Tail', Set to Tail to include the execution log in the response
168237
Payload=json.dumps(pay_load).encode(),
169238
)
170239

171240
response_status_code = int(lambda_response["StatusCode"])
172241
response_function_error = lambda_response.get("FunctionError")
173-
self.logger.debug(f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:")
242+
self.logger.debug(
243+
f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:"
244+
)
174245
self.logger.debug(f'Response - StatusCode = "{response_status_code}"')
175246
self.logger.debug(f'Response - FunctionError = "{response_function_error}"')
176247

177248
response_payload = json.loads(lambda_response["Payload"].read())
178249

179250
if response_status_code != 200 or response_function_error:
180251
self.logger.error(
181-
f'Error in response from aws lambda \'{self.connection_data["function"]}\', '
182-
f'attempt {current_attempt} of {max_attempts}'
252+
f"Error in response from aws lambda '{self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY]}', "
253+
f"attempt {current_attempt} of {max_attempts}"
183254
)
184255
self.logger.error(f"Response - Status Code = {response_status_code}")
185-
self.logger.error(f"Response - Error Function = {response_function_error}")
256+
self.logger.error(
257+
f"Response - Error Function = {response_function_error}"
258+
)
186259
self.logger.error(f"Response - Error Details:")
187260
# the below is risky as it may contain actual data if this line is reached in case of success
188261
# however, the same Payload field is used to return actual error details in case of failure

0 commit comments

Comments
 (0)