3
3
import json
4
4
import boto3
5
5
import time
6
+ import datetime
6
7
7
8
from rdl .data_sources .ChangeTrackingInfo import ChangeTrackingInfo
8
9
from rdl .data_sources .SourceTableInfo import SourceTableInfo
11
12
12
13
13
14
class AWSLambdaDataSource (object ):
14
- # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
15
+ # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName; '
15
16
CONNECTION_STRING_PREFIX = "aws-lambda://"
16
17
CONNECTION_STRING_GROUP_SEPARATOR = ";"
17
18
CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="
18
19
20
+ CONNECTION_DATA_ROLE_KEY = "role"
21
+ CONNECTION_DATA_FUNCTION_KEY = "function"
22
+ CONNECTION_DATA_TENANT_KEY = "tenant"
23
+
24
+ AWS_SERVICE_LAMBDA = "lambda"
25
+ AWS_SERVICE_S3 = "s3"
26
+
19
27
def __init__ (self , connection_string , logger = None ):
20
28
self .logger = logger or logging .getLogger (__name__ )
29
+
21
30
if not AWSLambdaDataSource .can_handle_connection_string (connection_string ):
22
31
raise ValueError (connection_string )
32
+
23
33
self .connection_string = connection_string
24
34
self .connection_data = dict (
25
35
kv .split (AWSLambdaDataSource .CONNECTION_STRING_KEY_VALUE_SEPARATOR )
@@ -29,8 +39,19 @@ def __init__(self, connection_string, logger=None):
29
39
.rstrip (AWSLambdaDataSource .CONNECTION_STRING_GROUP_SEPARATOR )
30
40
.split (AWSLambdaDataSource .CONNECTION_STRING_GROUP_SEPARATOR )
31
41
)
32
- self .aws_lambda_client = boto3 .client ("lambda" )
33
- self .aws_s3_client = boto3 .client ("s3" )
42
+
43
+ self .aws_sts_client = boto3 .client ("sts" )
44
+ role_credentials = self .__assume_role (
45
+ self .connection_data [self .CONNECTION_DATA_ROLE_KEY ],
46
+ f"dwp_{ self .connection_data [self .CONNECTION_DATA_TENANT_KEY ]} " ,
47
+ )
48
+
49
+ self .aws_lambda_client = self .__get_aws_client (
50
+ self .AWS_SERVICE_LAMBDA , role_credentials
51
+ )
52
+ self .aws_s3_client = self .__get_aws_client (
53
+ self .AWS_SERVICE_S3 , role_credentials
54
+ )
34
55
35
56
@staticmethod
36
57
def can_handle_connection_string (connection_string ):
@@ -87,7 +108,7 @@ def get_table_data_frame(
87
108
def __get_table_info (self , table_config , last_known_sync_version ):
88
109
pay_load = {
89
110
"Command" : "GetTableInfo" ,
90
- "TenantId" : int (self .connection_data ["tenant" ]),
111
+ "TenantId" : int (self .connection_data [self . CONNECTION_DATA_TENANT_KEY ]),
91
112
"Table" : {"Schema" : table_config ["schema" ], "Name" : table_config ["name" ]},
92
113
"CommandPayload" : {"LastSyncVersion" : last_known_sync_version },
93
114
}
@@ -113,7 +134,7 @@ def __get_table_data(
113
134
):
114
135
pay_load = {
115
136
"Command" : "GetTableData" ,
116
- "TenantId" : int (self .connection_data ["tenant" ]),
137
+ "TenantId" : int (self .connection_data [self . CONNECTION_DATA_TENANT_KEY ]),
117
138
"Table" : {"Schema" : table_config ["schema" ], "Name" : table_config ["name" ]},
118
139
"CommandPayload" : {
119
140
"AuditColumnNameForChangeVersion" : Providers .AuditColumnsNames .CHANGE_VERSION ,
@@ -125,7 +146,7 @@ def __get_table_data(
125
146
{
126
147
"Name" : col ["source_name" ],
127
148
"DataType" : col ["destination" ]["type" ],
128
- "IsPrimaryKey" : col ["destination" ]["primary_key" ]
149
+ "IsPrimaryKey" : col ["destination" ]["primary_key" ],
129
150
}
130
151
for col in columns_config
131
152
],
@@ -148,41 +169,93 @@ def __get_table_data(
148
169
def __get_data_frame(self, data: [[]], column_names: []):
    """Wrap row data in a pandas DataFrame.

    :param data: rows handed to ``pandas.DataFrame`` as its ``data`` argument.
    :param column_names: labels handed to ``pandas.DataFrame`` as ``columns``.
    :return: the resulting ``pandas.DataFrame``.
    """
    frame = pandas.DataFrame(columns=column_names, data=data)
    return frame
150
171
172
def __assume_role(self, role_arn, session_name):
    """Assume an IAM role through STS and return its temporary credentials.

    Calls ``assume_role`` on ``self.aws_sts_client`` and records the
    ``Expiration`` value of the returned credentials on
    ``self.role_session_expiry``.

    :param role_arn: ARN of the role to assume (sent as ``RoleArn``).
    :param session_name: value sent as ``RoleSessionName``.
    :return: the ``Credentials`` entry of the STS response.
    """
    self.logger.debug(f"\n Assuming role with ARN: { role_arn } ")

    sts_response = self.aws_sts_client.assume_role(
        RoleSessionName=session_name,
        RoleArn=role_arn,
    )
    credentials = sts_response["Credentials"]

    # remembered so callers can tell when the session is about to lapse
    self.role_session_expiry = credentials["Expiration"]
    return credentials
184
+
185
def __get_aws_client(self, service, credentials):
    """Create a boto3 client for ``service`` from temporary credentials.

    :param service: service name passed to ``boto3.client`` as ``service_name``.
    :param credentials: mapping providing ``AccessKeyId``, ``SecretAccessKey``
        and ``SessionToken``.
    :return: whatever ``boto3.client`` returns for those arguments.
    """
    session_kwargs = {
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }
    return boto3.client(service_name=service, **session_kwargs)
192
+
193
def __refresh_aws_clients_if_expired(self):
    """Re-assume the role and rebuild the AWS clients when the session is stale.

    The session counts as stale once the current UTC time is within five
    minutes of ``self.role_session_expiry`` or past it. A session that has
    already expired is refreshed too; otherwise the clients would keep using
    dead credentials indefinitely.

    On refresh, ``self.aws_lambda_client`` and ``self.aws_s3_client`` are
    replaced with clients built from the new credentials. The role ARN and
    tenant are read from ``self.connection_data``.
    """
    refresh_margin = datetime.timedelta(minutes=5)

    # AWS returns the expiry date in UTC, so compare against an aware UTC "now"
    current_datetime = datetime.datetime.now(datetime.timezone.utc)

    if current_datetime < self.role_session_expiry - refresh_margin:
        # still comfortably valid, keep the existing clients
        return

    role_credentials = self.__assume_role(
        self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
        f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
    )

    self.aws_lambda_client = self.__get_aws_client(
        self.AWS_SERVICE_LAMBDA, role_credentials
    )
    self.aws_s3_client = self.__get_aws_client(
        self.AWS_SERVICE_S3, role_credentials
    )
212
+
151
213
def __invoke_lambda (self , pay_load ):
152
214
max_attempts = Constants .MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS
153
215
retry_delay = Constants .AWS_LAMBDA_RETRY_DELAY_SECONDS
154
216
response_payload = None
155
217
156
- for current_attempt in list (range (1 , max_attempts + 1 , 1 )):
218
+ for current_attempt in list (range (1 , max_attempts + 1 , 1 )):
219
+
220
+ self .__refresh_aws_clients_if_expired ()
221
+
157
222
if current_attempt > 1 :
158
- self .logger .debug (f"\n Delaying retry for { (current_attempt - 1 ) ^ retry_delay } seconds" )
223
+ self .logger .debug (
224
+ f"\n Delaying retry for { (current_attempt - 1 ) ^ retry_delay } seconds"
225
+ )
159
226
time .sleep ((current_attempt - 1 ) ^ retry_delay )
160
227
161
- self .logger .debug (f"\n Request being sent to Lambda, attempt { current_attempt } of { max_attempts } :" )
228
+ self .logger .debug (
229
+ f"\n Request being sent to Lambda, attempt { current_attempt } of { max_attempts } :"
230
+ )
162
231
self .logger .debug (pay_load )
163
232
164
233
lambda_response = self .aws_lambda_client .invoke (
165
- FunctionName = self .connection_data ["function" ],
234
+ FunctionName = self .connection_data [self . CONNECTION_DATA_FUNCTION_KEY ],
166
235
InvocationType = "RequestResponse" ,
167
236
LogType = "None" , # |'Tail', Set to Tail to include the execution log in the response
168
237
Payload = json .dumps (pay_load ).encode (),
169
238
)
170
239
171
240
response_status_code = int (lambda_response ["StatusCode" ])
172
241
response_function_error = lambda_response .get ("FunctionError" )
173
- self .logger .debug (f"\n Response received from Lambda, attempt { current_attempt } of { max_attempts } :" )
242
+ self .logger .debug (
243
+ f"\n Response received from Lambda, attempt { current_attempt } of { max_attempts } :"
244
+ )
174
245
self .logger .debug (f'Response - StatusCode = "{ response_status_code } "' )
175
246
self .logger .debug (f'Response - FunctionError = "{ response_function_error } "' )
176
247
177
248
response_payload = json .loads (lambda_response ["Payload" ].read ())
178
249
179
250
if response_status_code != 200 or response_function_error :
180
251
self .logger .error (
181
- f' Error in response from aws lambda \ '{ self .connection_data ["function" ] } \ ' , '
182
- f' attempt { current_attempt } of { max_attempts } '
252
+ f" Error in response from aws lambda '{ self .connection_data [self . CONNECTION_DATA_FUNCTION_KEY ] } ', "
253
+ f" attempt { current_attempt } of { max_attempts } "
183
254
)
184
255
self .logger .error (f"Response - Status Code = { response_status_code } " )
185
- self .logger .error (f"Response - Error Function = { response_function_error } " )
256
+ self .logger .error (
257
+ f"Response - Error Function = { response_function_error } "
258
+ )
186
259
self .logger .error (f"Response - Error Details:" )
187
260
# the below is risky as it may contain actual data if this line is reached in case of success
188
261
# however, the same Payload field is used to return actual error details in case of failure
0 commit comments