From bc432f95844c986cba2713dc3af08c51f1a4900b Mon Sep 17 00:00:00 2001 From: Daniel Chalef Date: Thu, 10 May 2018 13:00:52 -0700 Subject: [PATCH 1/7] added iam authentication method for redshift adapter --- dbt/adapters/redshift/impl.py | 146 +++++++++++++++++++++++++--------- dbt/contracts/connection.py | 15 +++- requirements.txt | 1 + 3 files changed, 123 insertions(+), 39 deletions(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 311e09d3923..1c2f8b8a463 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -2,7 +2,9 @@ from dbt.adapters.postgres import PostgresAdapter from dbt.logger import GLOBAL_LOGGER as logger # noqa - +import dbt.exceptions +import boto3 +import psycopg2 drop_lock = multiprocessing.Lock() @@ -17,6 +19,76 @@ def type(cls): def date_function(cls): return 'getdate()' + @classmethod + def get_redshift_credentials(cls, config): + result = config.copy() + + method = result.get('method') + + if method == 'database': + return (result) + + elif method == 'iam': + cluster_id = result.get('cluster_id') + if not cluster_id: + error = '`cluster_id` must be set in profile if IAM authentication method selected' + raise dbt.exceptions.FailedToConnectException(error) + + client = boto3.client('redshift') + + # replace username and password with temporary redshift credentials + try: + cluster_creds = client.get_cluster_credentials(DbUser=result.get('user'), + DbName=result.get('dbname'), + ClusterIdentifier=result.get('cluster_id'), + AutoCreate=False) + result['user_tmp'] = cluster_creds.get('DbUser') + result['pass_tmp'] = cluster_creds.get('DbPassword') + except client.exceptions.ClientError as e: + error = ('Unable to get temporary Redshift cluster credentials: "{}"'.format(str(e))) + raise dbt.exceptions.FailedToConnectException(error) + + return result + + else: + error = ('Invalid `method` in profile: "{}"'.format(method)) + raise dbt.exceptions.FailedToConnectException(error) + + @classmethod + def open_connection(cls, connection): + if connection.get('state') == 'open': + logger.debug('Connection is already open, skipping open.') + return connection + + result = connection.copy() + + try: + credentials = cls.get_redshift_credentials(connection.get('credentials', {})) + user = credentials.get('user_tmp') if credentials.get('user_tmp') else credentials.get('user') + password = credentials.get('pass_tmp') if credentials.get('pass_tmp') else credentials.get('pass') + + handle = psycopg2.connect( + dbname=credentials.get('dbname'), + user=user, + host=credentials.get('host'), + password=password, + port=credentials.get('port'), + connect_timeout=10) + + result['handle'] = handle + result['state'] = 'open' + except psycopg2.Error as e: + logger.debug("Got an error when attempting to open a postgres " + "connection: '{}'" + .format(e)) + + result['handle'] = None + result['state'] = 'fail' + + raise dbt.exceptions.FailedToConnectException(str(e)) + + return result + @classmethod def _get_columns_in_table_sql(cls, schema_name, table_name, database): # Redshift doesn't support cross-database queries, @@ -27,65 +99,65 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): table_schema_filter = '1=1' else: table_schema_filter = "table_schema = '{schema_name}'".format( - schema_name=schema_name) + schema_name=schema_name) sql = """ - with bound_views as ( - select + WITH bound_views AS ( + SELECT ordinal_position, table_schema, column_name, data_type, character_maximum_length, - numeric_precision || ',' || numeric_scale as numeric_size + numeric_precision || ',' || numeric_scale AS numeric_size - from information_schema.columns - where table_name = '{table_name}' + FROM information_schema.columns + WHERE table_name = '{table_name}' ), - unbound_views as ( - select + unbound_views AS ( + SELECT ordinal_position, view_schema, col_name, - case - when col_type ilike 'character varying%' then + CASE + WHEN col_type ILIKE 'character varying%' THEN 'character varying' - when col_type ilike 'numeric%' then 'numeric' - else col_type - end as col_type, - case - when col_type like 'character%' - then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int - else null - end as character_maximum_length, - case - when col_type like 'numeric%' - then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '') - else null - end as numeric_size - - from pg_get_late_binding_view_cols() - cols(view_schema name, view_name name, col_name name, - col_type varchar, ordinal_position int) - where view_name = '{table_name}' + WHEN col_type ILIKE 'numeric%' THEN 'numeric' + ELSE col_type + END AS col_type, + CASE + WHEN col_type LIKE 'character%' + THEN nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::INT + ELSE NULL + END AS character_maximum_length, + CASE + WHEN col_type LIKE 'numeric%' + THEN nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '') + ELSE NULL + END AS numeric_size + + FROM pg_get_late_binding_view_cols() + cols(view_schema NAME, view_name NAME, col_name NAME, + col_type VARCHAR, ordinal_position INT) + WHERE view_name = '{table_name}' ), - unioned as ( - select * from bound_views - union all - select * from unbound_views + unioned AS ( + SELECT * FROM bound_views + UNION ALL + SELECT * FROM unbound_views ) - select + SELECT column_name, data_type, character_maximum_length, numeric_size - from unioned - where {table_schema_filter} - order by ordinal_position + FROM unioned + WHERE {table_schema_filter} + ORDER BY ordinal_position """.format(table_name=table_name, table_schema_filter=table_schema_filter).strip() return sql diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 2524b8e371e..7c478aab9e5 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -4,7 +4,6 @@ from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger # noqa - adapter_types = ['postgres', 'redshift', 'snowflake', 'bigquery'] connection_contract = Schema({ Required('type'): Any(*adapter_types), @@ -24,6 +23,18 @@ Required('schema'): basestring, }) +redshift_auth_methods = ['database', 'iam'] +redshift_credentials_contract = Schema({ + Required('method'): Any(*redshift_auth_methods), + Required('dbname'): basestring, + Required('host'): basestring, + Required('user'): basestring, + Optional('pass'): basestring, + Required('port'): Any(All(int, Range(min=0, max=65535)), basestring), + Required('schema'): basestring, + Optional('cluster_id'): basestring, # TODO: require if 'iam' method selected +}) + snowflake_credentials_contract = Schema({ Required('account'): basestring, Required('user'): basestring, @@ -46,7 +57,7 @@ credentials_mapping = { 'postgres': postgres_credentials_contract, - 'redshift': postgres_credentials_contract, + 'redshift': redshift_credentials_contract, 'snowflake': snowflake_credentials_contract, 'bigquery': bigquery_credentials_contract, } diff --git a/requirements.txt b/requirements.txt index 55c785bff56..e1a8e982c4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ google-cloud-bigquery==0.29.0 requests>=2.18.0 agate>=1.6,<2 jsonschema==2.6.0 +boto3>=1.6.23 From 688fa467b22a5669f04ae18ae9e86b30ed107d38 Mon Sep 17 00:00:00 2001 From: Daniel Chalef Date: Thu, 10 May 2018 14:24:02 -0700 Subject: [PATCH 2/7] back out SQL uppercasing --- dbt/adapters/redshift/impl.py | 70 +++++++++++++++++------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 1c2f8b8a463..2a0b1848acd 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -102,62 +102,62 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): schema_name=schema_name) sql = """ - WITH bound_views AS ( - SELECT + with bound_views as ( + select ordinal_position, table_schema, column_name, data_type, character_maximum_length, - numeric_precision || ',' || numeric_scale AS numeric_size + numeric_precision || ',' || numeric_scale as numeric_size - FROM information_schema.columns - WHERE table_name = '{table_name}' + from information_schema.columns + where table_name = '{table_name}' ), - unbound_views AS ( - SELECT + unbound_views as ( + select ordinal_position, view_schema, col_name, - CASE - WHEN col_type ILIKE 'character varying%' THEN + case + when col_type ilike 'character varying%' then 'character varying' - WHEN col_type ILIKE 'numeric%' THEN 'numeric' - ELSE col_type - END AS col_type, - CASE - WHEN col_type LIKE 'character%' - THEN nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::INT - ELSE NULL - END AS character_maximum_length, - CASE - WHEN col_type LIKE 'numeric%' - THEN nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '') - ELSE NULL - END AS numeric_size - - FROM pg_get_late_binding_view_cols() - cols(view_schema NAME, view_name NAME, col_name NAME, - col_type VARCHAR, ordinal_position INT) - WHERE view_name = '{table_name}' + when col_type ilike 'numeric%' then 'numeric' + else col_type + end as col_type, + case + when col_type like 'character%' + then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int + else null + end as character_maximum_length, + case + when col_type like 'numeric%' + then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '') + else null + end as numeric_size + + from pg_get_late_binding_view_cols() + cols(view_schema name, view_name name, col_name name, + col_type varchar, ordinal_position int) + where view_name = '{table_name}' ), - unioned AS ( - SELECT * FROM bound_views - UNION ALL - SELECT * FROM unbound_views + unioned as ( + select * from bound_views + union all + select * from unbound_views ) - SELECT + select column_name, data_type, character_maximum_length, numeric_size - FROM unioned - WHERE {table_schema_filter} - ORDER BY ordinal_position + from unioned + where {table_schema_filter} + order by ordinal_position """.format(table_name=table_name, table_schema_filter=table_schema_filter).strip() return sql From f65b3d677a647160d1d5dc609c32ef2f41207857 Mon Sep 17 00:00:00 2001 From: Daniel Chalef Date: Thu, 10 May 2018 15:32:44 -0700 Subject: [PATCH 3/7] move get_cluster_credentials into separate fn. Make method optional --- dbt/adapters/redshift/impl.py | 61 +++++++++++++++++++---------------- dbt/contracts/connection.py | 4 +-- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 2a0b1848acd..e01c49594fe 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -19,36 +19,43 @@ def type(cls): def date_function(cls): return 'getdate()' + @classmethod + def get_tmp_cluster_credentials(cls, config): + creds = config.copy() + + cluster_id = creds.get('cluster_id') + if not cluster_id: + error = '`cluster_id` must be set in profile if IAM authentication method selected' + raise dbt.exceptions.FailedToConnectException(error) + + client = boto3.client('redshift') + + # replace username and password with temporary redshift credentials + try: + cluster_creds = client.get_cluster_credentials(DbUser=creds.get('user'), + DbName=creds.get('dbname'), + ClusterIdentifier=creds.get('cluster_id'), + AutoCreate=False) + creds['user'] = cluster_creds.get('DbUser') + creds['pass'] = cluster_creds.get('DbPassword') + + return creds + + except client.exceptions.ClientError as e: + error = ('Unable to get temporary Redshift cluster credentials: "{}"'.format(str(e))) + raise dbt.exceptions.FailedToConnectException(error) + @classmethod def get_redshift_credentials(cls, config): - result = config.copy() + creds = config.copy() - method = result.get('method') + method = creds.get('method') - if method == 'database': - return (result) + if method == 'database' or method is None: # Support missing method for backwards compatibility + return creds elif method == 'iam': - cluster_id = result.get('cluster_id') - if not cluster_id: - error = '`cluster_id` must be set in profile if IAM authentication method selected' - raise dbt.exceptions.FailedToConnectException(error) - - client = boto3.client('redshift') - - # replace username and password with temporary redshift credentials - try: - cluster_creds = client.get_cluster_credentials(DbUser=result.get('user'), - DbName=result.get('dbname'), - ClusterIdentifier=result.get('cluster_id'), - AutoCreate=False) - result['user_tmp'] = cluster_creds.get('DbUser') - result['pass_tmp'] = cluster_creds.get('DbPassword') - except client.exceptions.ClientError as e: - error = ('Unable to get temporary Redshift cluster credentials: "{}"'.format(str(e))) - raise dbt.exceptions.FailedToConnectException(error) - - return result + return cls.get_tmp_cluster_credentials(creds) else: error = ('Invalid `method` in profile: "{}"'.format(method)) @@ -64,14 +71,12 @@ def open_connection(cls, connection): try: credentials = cls.get_redshift_credentials(connection.get('credentials', {})) - user = credentials.get('user_tmp') if credentials.get('user_tmp') else credentials.get('user') - password = credentials.get('pass_tmp') if credentials.get('pass_tmp') else credentials.get('pass') handle = psycopg2.connect( dbname=credentials.get('dbname'), - user=user, + user=credentials.get('user'), host=credentials.get('host'), - password=password, + password=credentials.get('pass'), port=credentials.get('port'), connect_timeout=10) diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 7c478aab9e5..a5ef7735fb9 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -25,11 +25,11 @@ redshift_auth_methods = ['database', 'iam'] redshift_credentials_contract = Schema({ - Required('method'): Any(*redshift_auth_methods), + Optional('method'): Any(*redshift_auth_methods), Required('dbname'): basestring, Required('host'): basestring, Required('user'): basestring, - Optional('pass'): basestring, + Optional('pass'): basestring, # TODO: require if 'database' method selected Required('port'): Any(All(int, Range(min=0, max=65535)), basestring), Required('schema'): basestring, Optional('cluster_id'): basestring, # TODO: require if 'iam' method selected From 58c184a1f464f78c6c9c3e2039134a1ab06463d7 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Mon, 2 Jul 2018 16:58:40 -0400 Subject: [PATCH 4/7] cleanup; update json schemas --- dbt/adapters/postgres/impl.py | 8 ++- dbt/adapters/redshift/impl.py | 97 +++++++++++++---------------------- dbt/contracts/connection.py | 23 ++++----- 3 files changed, 54 insertions(+), 74 deletions(-) diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 50c05b0f199..f9a69ad386c 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -53,6 +53,10 @@ def date_function(cls): def get_status(cls, cursor): return cursor.statusmessage + @classmethod + def get_credentials(cls, credentials): + return credentials + @classmethod def open_connection(cls, connection): if connection.get('state') == 'open': @@ -61,8 +65,10 @@ def open_connection(cls, connection): result = connection.copy() + base_credentials = connection.get('credentials', {}) + credentials = cls.get_credentials(base_credentials.copy()) + try: - credentials = connection.get('credentials', {}) handle = psycopg2.connect( dbname=credentials.get('dbname'), user=credentials.get('user'), diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index e01c49594fe..d0a51324392 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -4,13 +4,11 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa import dbt.exceptions import boto3 -import psycopg2 drop_lock = multiprocessing.Lock() class RedshiftAdapter(PostgresAdapter): - @classmethod def type(cls): return 'redshift' @@ -20,79 +18,58 @@ def date_function(cls): return 'getdate()' @classmethod - def get_tmp_cluster_credentials(cls, config): - creds = config.copy() + def get_tmp_iam_cluster_credentials(cls, credentials): + cluster_id = credentials.get('cluster_id') + + # default via: + # boto3.readthedocs.io/en/latest/reference/services/redshift.html + iam_duration_s = credentials.get('iam_duration_seconds', 900) - cluster_id = creds.get('cluster_id') if not cluster_id: - error = '`cluster_id` must be set in profile if IAM authentication method selected' - raise dbt.exceptions.FailedToConnectException(error) + raise dbt.exceptions.FailedToConnectException( + "'cluster_id' must be provided in profile if IAM " + "authentication method selected") - client = boto3.client('redshift') + boto_client = boto3.client('redshift') # replace username and password with temporary redshift credentials + to_update = {} try: - cluster_creds = client.get_cluster_credentials(DbUser=creds.get('user'), - DbName=creds.get('dbname'), - ClusterIdentifier=creds.get('cluster_id'), - AutoCreate=False) - creds['user'] = cluster_creds.get('DbUser') - creds['pass'] = cluster_creds.get('DbPassword') + cluster_creds = boto_client.get_cluster_credentials( + DbUser=credentials.get('user'), + DbName=credentials.get('dbname'), + ClusterIdentifier=credentials.get('cluster_id'), + DurationSeconds=iam_duration_s, + AutoCreate=False) - return creds + to_update = { + 'user': cluster_creds.get('DbUser'), + 'pass': cluster_creds.get('DbPassword') + } - except client.exceptions.ClientError as e: - error = ('Unable to get temporary Redshift cluster credentials: "{}"'.format(str(e))) - raise dbt.exceptions.FailedToConnectException(error) + except boto_client.exceptions.ClientError as e: + raise dbt.exceptions.FailedToConnectException( + "Unable to get temporary Redshift cluster credentials: " + "{}".format(e)) - @classmethod - def get_redshift_credentials(cls, config): - creds = config.copy() + return dbt.utils.merge(credentials, to_update) - method = creds.get('method') + @classmethod + def get_credentials(cls, credentials): + method = credentials.get('method') - if method == 'database' or method is None: # Support missing method for backwards compatibility - return creds + # Support missing 'method' for backwards compatibility + if method == 'database' or method is None: + logger.debug("Connecting to Redshift using 'database' credentials") + return credentials elif method == 'iam': - return cls.get_tmp_cluster_credentials(creds) + logger.debug("Connecting to Redshift using 'IAM' credentials") + return cls.get_tmp_iam_cluster_credentials(credentials) else: - error = ('Invalid `method` in profile: "{}"'.format(method)) - raise dbt.exceptions.FailedToConnectException(error) - - @classmethod - def open_connection(cls, connection): - if connection.get('state') == 'open': - logger.debug('Connection is already open, skipping open.') - return connection - - result = connection.copy() - - try: - credentials = cls.get_redshift_credentials(connection.get('credentials', {})) - - handle = psycopg2.connect( - dbname=credentials.get('dbname'), - user=credentials.get('user'), - host=credentials.get('host'), - password=credentials.get('pass'), - port=credentials.get('port'), - connect_timeout=10) - - result['handle'] = handle - result['state'] = 'open' - except psycopg2.Error as e: - logger.debug("Got an error when attempting to open a postgres " - "connection: '{}'" - .format(e)) - - result['handle'] = None - result['state'] = 'fail' - - raise dbt.exceptions.FailedToConnectException(str(e)) - - return result + raise dbt.exceptions.FailedToConnectException( + 'Invalid `method` in profile: "{}"'.format(method)) @classmethod def _get_columns_in_table_sql(cls, schema_name, table_name, database): diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 8adf7ec97be..98432399a25 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -41,8 +41,7 @@ 'additionalProperties': False, 'properties': { 'method': { - 'enum': ['credentials', 'iam'], - 'default': 'credentials' # TODO : Can I do this? + 'enum': ['database', 'iam'] }, 'dbname': { 'type': 'string', @@ -74,15 +73,13 @@ 'cluster_id': { 'type': 'string' }, - }, - # Confirm this works as intended. cluster_id only required if method == iam - "anyOf": [{ - "properties": { - "method": { "enum": ["iam"] } - }, - "required": ["cluster_id"] - }], - 'required': ['dbname', 'host', 'user', 'pass', 'port', 'schema', 'method'], + 'iam_duration_seconds': { + 'type': ['null', 'integer'], + 'minimum': 900, + 'maximum': 3600 + }, + 'required': ['dbname', 'host', 'user', 'port', 'schema'] + } } SNOWFLAKE_CREDENTIALS_CONTRACT = { @@ -162,11 +159,11 @@ }, 'credentials': { 'description': ( - 'The credentials object here should match the connection ' - 'type. Redshift uses the Postgres connection model.' + 'The credentials object here should match the connection type.' ), 'oneOf': [ POSTGRES_CREDENTIALS_CONTRACT, + REDSHIFT_CREDENTIALS_CONTRACT, SNOWFLAKE_CREDENTIALS_CONTRACT, BIGQUERY_CREDENTIALS_CONTRACT, ], From a894ca9e65fb9da015b896eb20515499c48f2040 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Mon, 2 Jul 2018 17:03:47 -0400 Subject: [PATCH 5/7] add boto3 dep to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 04c4c443fe1..b6e69f6749e 100644 --- a/setup.py +++ b/setup.py @@ -52,5 +52,6 @@ def read(fname): 'google-cloud-bigquery==0.29.0', 'agate>=1.6,<2', 'jsonschema==2.6.0', + 'boto3>=1.6.23' ] ) From bf7608550daee5e6f0e612c4fcc825f456006677 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Mon, 2 Jul 2018 22:23:12 -0400 Subject: [PATCH 6/7] add unit tests, refactor --- dbt/adapters/redshift/impl.py | 52 +++++++++------- dbt/contracts/connection.py | 19 ++++-- test/unit/test_redshift_adapter.py | 98 ++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 27 deletions(-) create mode 100644 test/unit/test_redshift_adapter.py diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index d0a51324392..7f6f13195ab 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -17,6 +17,25 @@ def type(cls): def date_function(cls): return 'getdate()' + @classmethod + def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, duration_s): + """Fetches temporary login credentials from AWS. The specified user + must already exist in the database, or else an error will occur""" + boto_client = boto3.client('redshift') + + try: + return boto_client.get_cluster_credentials( + DbUser=db_user, + DbName=db_name, + ClusterIdentifier=cluster_id, + DurationSeconds=duration_s, + AutoCreate=False) + + except boto_client.exceptions.ClientError as e: + raise dbt.exceptions.FailedToConnectException( + "Unable to get temporary Redshift cluster credentials: " + "{}".format(e)) + @classmethod def get_tmp_iam_cluster_credentials(cls, credentials): cluster_id = credentials.get('cluster_id') @@ -30,29 +49,18 @@ def get_tmp_iam_cluster_credentials(cls, credentials): "'cluster_id' must be provided in profile if IAM " "authentication method selected") - boto_client = boto3.client('redshift') + cluster_creds = cls.fetch_cluster_credentials( + credentials.get('user'), + credentials.get('dbname'), + credentials.get('cluster_id'), + iam_duration_s, + ) # replace username and password with temporary redshift credentials - to_update = {} - try: - cluster_creds = boto_client.get_cluster_credentials( - DbUser=credentials.get('user'), - DbName=credentials.get('dbname'), - ClusterIdentifier=credentials.get('cluster_id'), - DurationSeconds=iam_duration_s, - AutoCreate=False) - - to_update = { - 'user': cluster_creds.get('DbUser'), - 'pass': cluster_creds.get('DbPassword') - } - - except boto_client.exceptions.ClientError as e: - raise dbt.exceptions.FailedToConnectException( - "Unable to get temporary Redshift cluster credentials: " - "{}".format(e)) - - return dbt.utils.merge(credentials, to_update) + return dbt.utils.merge(credentials, { + 'user': cluster_creds.get('DbUser'), + 'pass': cluster_creds.get('DbPassword') + }) @classmethod def get_credentials(cls, credentials): @@ -69,7 +77,7 @@ def get_credentials(cls, credentials): else: raise dbt.exceptions.FailedToConnectException( - 'Invalid `method` in profile: "{}"'.format(method)) + "Invalid 'method' in profile: '{}'".format(method)) @classmethod def _get_columns_in_table_sql(cls, schema_name, table_name, database): diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 98432399a25..2183c5d891b 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -41,7 +41,10 @@ 'additionalProperties': False, 'properties': { 'method': { - 'enum': ['database', 'iam'] + 'enum': ['database', 'iam'], + 'description': ( + 'database: use user/pass creds; iam: use temporary creds' + ), }, 'dbname': { 'type': 'string', @@ -71,12 +74,18 @@ 'type': 'string', }, 'cluster_id': { - 'type': 'string' + 'type': 'string', + 'description': ( + 'If using IAM auth, the name of the cluster' + ) }, 'iam_duration_seconds': { - 'type': ['null', 'integer'], + 'type': 'integer', 'minimum': 900, - 'maximum': 3600 + 'maximum': 3600, + 'description': ( + 'If using IAM auth, the ttl for the temporary credentials' + ) }, 'required': ['dbname', 'host', 'user', 'port', 'schema'] } @@ -161,7 +170,7 @@ 'description': ( 'The credentials object here should match the connection type.' ), - 'oneOf': [ + 'anyOf': [ POSTGRES_CREDENTIALS_CONTRACT, REDSHIFT_CREDENTIALS_CONTRACT, SNOWFLAKE_CREDENTIALS_CONTRACT, diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py new file mode 100644 index 00000000000..505ae4b536d --- /dev/null +++ b/test/unit/test_redshift_adapter.py @@ -0,0 +1,98 @@ +import unittest +import mock + +import dbt.flags as flags +import dbt.utils + +from dbt.adapters.redshift import RedshiftAdapter +from dbt.exceptions import ValidationException, FailedToConnectException +from dbt.logger import GLOBAL_LOGGER as logger # noqa + +@classmethod +def fetch_cluster_credentials(*args, **kwargs): + return { + 'DbUser': 'root', + 'DbPassword': 'tmp_password' + } + +class TestRedshiftAdapter(unittest.TestCase): + + def setUp(self): + flags.STRICT_MODE = True + + def test_implicit_database_conn(self): + implicit_database_profile = { + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + + creds = RedshiftAdapter.get_credentials(implicit_database_profile) + self.assertEquals(creds, implicit_database_profile) + + def test_explicit_database_conn(self): + explicit_database_profile = { + 'method': 'database', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + + creds = RedshiftAdapter.get_credentials(explicit_database_profile) + self.assertEquals(creds, explicit_database_profile) + + def test_explicit_iam_conn(self): + explicit_iam_profile = { + 'method': 'iam', + 'cluster_id': 'my_redshift', + 'iam_duration_s': 1200, + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'port': 5439, + 'schema': 'public', + } + + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + creds = RedshiftAdapter.get_credentials(explicit_iam_profile) + + expected_creds = dbt.utils.merge(explicit_iam_profile, {'pass': 'tmp_password'}) + self.assertEquals(creds, expected_creds) + + def test_invalid_auth_method(self): + invalid_profile = { + 'method': 'badmethod', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + + with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + RedshiftAdapter.get_credentials(invalid_profile) + + self.assertTrue('badmethod' in context.exception.msg) + + def test_invalid_iam_no_cluster_id(self): + invalid_profile = { + 'method': 'iam', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'port': 5439, + 'schema': 'public' + } + with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): + RedshiftAdapter.get_credentials(invalid_profile) + + self.assertTrue("'cluster_id' must be provided" in context.exception.msg) From e7abe27bfa9fc4aa64abf55ef63c73716bd259a0 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Tue, 3 Jul 2018 14:02:26 -0400 Subject: [PATCH 7/7] pep8 --- dbt/adapters/redshift/impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 7f6f13195ab..319edf2bf45 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -18,7 +18,8 @@ def date_function(cls): return 'getdate()' @classmethod - def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, duration_s): + def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, + duration_s): """Fetches temporary login credentials from AWS. The specified user must already exist in the database, or else an error will occur""" boto_client = boto3.client('redshift')