Skip to content

Commit

Permalink
Merge pull request #1232 from alexyer/feature/snowflake-ssh-login
Browse files Browse the repository at this point in the history
Add support for Snowflake Key Pair Authentication
  • Loading branch information
drewbanin committed Jan 21, 2019
2 parents a34ab9a + 438b352 commit b7d9eec
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
30 changes: 30 additions & 0 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import dbt.compat
import dbt.exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger
Expand All @@ -29,6 +31,12 @@
'type': 'string',
'description': "Either 'externalbrowser', or a valid Okta url"
},
'private_key_path': {
'type': 'string',
},
'private_key_passphrase': {
'type': 'string',
},
'database': {
'type': 'string',
},
Expand Down Expand Up @@ -104,6 +112,11 @@ def open(cls, connection):
auth_args = {auth_key: credentials[auth_key]
for auth_key in ['user', 'password', 'authenticator']
if auth_key in credentials}

auth_args['private_key'] = cls._get_private_key(
credentials.get('private_key_path'),
credentials.get('private_key_passphrase'))

handle = snowflake.connector.connect(
account=credentials.account,
database=credentials.database,
Expand Down Expand Up @@ -163,6 +176,23 @@ def _split_queries(cls, sql):
split_query = snowflake.connector.util_text.split_statements(sql_buf)
return [part[0] for part in split_query]

@classmethod
def _get_private_key(cls, private_key_path, private_key_passphrase):
"""Get Snowflake private key by path or None."""
if private_key_path is None or private_key_passphrase is None:
return None

with open(private_key_path, 'rb') as key:
p_key = serialization.load_pem_private_key(
key.read(),
password=private_key_passphrase.encode(),
backend=default_backend())

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption())

def add_query(self, sql, model_name=None, auto_begin=True,
bindings=None, abridge_sql_log=False):

Expand Down
30 changes: 25 additions & 5 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from mock import patch

import mock
import unittest

Expand Down Expand Up @@ -161,7 +163,7 @@ def test_client_session_keep_alive_false_by_default(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_database',
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
warehouse='test_warehouse', private_key=None)
])

def test_client_session_keep_alive_true(self):
Expand All @@ -175,7 +177,7 @@ def test_client_session_keep_alive_true(self):
account='test_account', autocommit=False,
client_session_keep_alive=True, database='test_database',
role=None, schema='public', user='test_user',
warehouse='test_warehouse')
warehouse='test_warehouse', private_key=None)
])

def test_user_pass_authentication(self):
Expand All @@ -189,7 +191,7 @@ def test_user_pass_authentication(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_database',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse')
user='test_user', warehouse='test_warehouse', private_key=None)
])

def test_authenticator_user_pass_authentication(self):
Expand All @@ -204,7 +206,7 @@ def test_authenticator_user_pass_authentication(self):
client_session_keep_alive=False, database='test_database',
password='test_password', role=None, schema='public',
user='test_user', warehouse='test_warehouse',
authenticator='test_sso_url')
authenticator='test_sso_url', private_key=None)
])

def test_authenticator_externalbrowser_authentication(self):
Expand All @@ -218,5 +220,23 @@ def test_authenticator_externalbrowser_authentication(self):
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_database',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', authenticator='externalbrowser')
warehouse='test_warehouse', authenticator='externalbrowser',
private_key=None)
])

@patch('dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key', return_value='test_key')
def test_authenticator_private_key_authentication(self, mock_get_private_key):
self.config.credentials = self.config.credentials.incorporate(
private_key_path='/tmp/test_key.p8',
private_key_passphrase='p@ssphr@se')

self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.get(name='new_connection_with_new_config')

self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
client_session_keep_alive=False, database='test_databse',
role=None, schema='public', user='test_user',
warehouse='test_warehouse', private_key='test_key')
])

0 comments on commit b7d9eec

Please sign in to comment.