Skip to content

Commit

Permalink
Merge pull request #1941 from fishtown-analytics/feature/lazy-load-connections
Browse files Browse the repository at this point in the history

lazy-load connections (#1584)
  • Loading branch information
drewbanin committed Dec 9, 2019
2 parents 7814ea0 + a3d58f9 commit afd69ea
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 26 deletions.
16 changes: 11 additions & 5 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, AdapterRequiredConfig
Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle
)
from dbt.adapters.base.query_headers import (
QueryStringSetter, MacroQueryStringSetter,
Expand Down Expand Up @@ -61,6 +61,14 @@ def get_thread_connection(self) -> Connection:
)
return self.thread_connections[key]

def set_thread_connection(self, conn):
    """Register ``conn`` as the connection for the current thread.

    The connection is stored in ``self.thread_connections`` under the key
    returned by ``self.get_thread_identifier()``.

    :param conn: The connection to associate with the current thread.
    :raises dbt.exceptions.InternalException: If a connection is already
        registered under the current thread's key. Existing entries are
        never overwritten.
    """
    key = self.get_thread_identifier()
    if key in self.thread_connections:
        # Fill in the '{}' placeholder so the error identifies which
        # thread key already had a connection.
        raise dbt.exceptions.InternalException(
            'In set_thread_connection, existing connection exists for {}'
            .format(key)
        )
    self.thread_connections[key] = conn

def get_if_exists(self) -> Optional[Connection]:
key = self.get_thread_identifier()
with self.lock:
Expand Down Expand Up @@ -109,8 +117,6 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
conn_name = name

conn = self.get_if_exists()
thread_id_key = self.get_thread_identifier()

if conn is None:
conn = Connection(
type=Identifier(self.TYPE),
Expand All @@ -120,7 +126,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
handle=None,
credentials=self.profile.credentials
)
self.thread_connections[thread_id_key] = conn
self.set_thread_connection(conn)

if conn.name == conn_name and conn.state == 'open':
return conn
Expand All @@ -138,7 +144,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
'Opening a new connection, currently in state {}'
.format(conn.state)
)
self.open(conn)
conn.handle = LazyHandle(type(self))

conn.name = conn_name
return conn
Expand Down
12 changes: 6 additions & 6 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,21 @@ def nice_connection_name(self):
@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
) -> Iterator[None]:
try:
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
yield conn
self.acquire_connection(name)
yield
finally:
self.release_connection()
self.connections.query_header.reset()

@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[Connection]:
with self.connection_named(node.unique_id, node) as conn:
yield conn
) -> Iterator[None]:
with self.connection_named(node.unique_id, node):
yield

@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
Expand Down
29 changes: 28 additions & 1 deletion core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Type
)
from typing_extensions import Protocol

Expand All @@ -12,6 +12,7 @@
)

from dbt.contracts.util import Replaceable
from dbt.exceptions import InternalException
from dbt.utils import translate_aliases


Expand All @@ -26,6 +27,23 @@ class ConnectionState(StrEnum):
FAIL = 'fail'


class ConnectionOpenerProtocol(Protocol):
    """Structural type for classes that expose an ``open`` classmethod.

    The default implementation below only reports that ``open`` has not
    been provided by the class it was invoked on.
    """

    @classmethod
    def open(cls, connection: 'Connection') -> Any:
        """Open ``connection``; this default always raises.

        :param connection: The connection that should be opened.
        :raises NotImplementedError: Always, naming the class that was
            asked to open the connection.
        """
        message = f'open() not implemented for {cls.__name__}'
        raise NotImplementedError(message)


class LazyHandle:
    """Placeholder handle that defers opening a connection.

    ``opener`` is a class providing an ``open(connection)`` classmethod.
    Calling ``resolve`` forwards the connection to that classmethod, which
    is expected to open the connection and update the handle stored on the
    Connection object.
    """
    def __init__(self, opener: Type[ConnectionOpenerProtocol]):
        # Keep the opener class; nothing is opened at construction time.
        self.opener = opener

    def resolve(self, connection: 'Connection') -> Any:
        """Open ``connection`` via the stored opener and return its result."""
        open_connection = self.opener.open
        return open_connection(connection)


@dataclass(init=False)
class Connection(ExtensibleJsonSchemaMixin, Replaceable):
type: Identifier
Expand Down Expand Up @@ -62,6 +80,15 @@ def credentials(self, value):

@property
def handle(self):
    """Return the connection handle, resolving a lazy handle first.

    If the stored handle is a ``LazyHandle``, it is resolved against this
    connection before the stored handle is read back and returned.

    :raises InternalException: If resolving the lazy handle hits a
        ``RecursionError``, reported as the connection's ``open()`` having
        tried to read the handle value.
    """
    pending = self._handle
    if isinstance(pending, LazyHandle):
        try:
            # resolving is what actually replaces 'self._handle'.
            pending.resolve(self)
        except RecursionError as exc:
            raise InternalException(
                "A connection's open() method attempted to read the "
                "handle value"
            ) from exc
    # Re-read the attribute: resolution above may have swapped it out.
    return self._handle

@handle.setter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def project_config(self):
def run_select_and_check(self, rel, sql):
connection_name = '__test_{}'.format(id(threading.current_thread()))
try:
with self._secret_adapter.connection_named(connection_name) as conn:
with self._secret_adapter.connection_named(connection_name):
conn = self._secret_adapter.connections.get_thread_connection()
res = self.run_sql_common(self.transform_sql(sql), 'one', conn)

# The result is the output of f_sleep(), which is True
if res[0] == True:
if res[0]:
self.query_state[rel] = 'good'
else:
self.query_state[rel] = 'bad'
Expand Down
3 changes: 2 additions & 1 deletion test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,8 @@ def get_connection(self, name=None):
if name is None:
name = '__test'
with patch.object(common, 'get_adapter', return_value=self.adapter):
with self.adapter.connection_named(name) as conn:
with self.adapter.connection_named(name):
conn = self.adapter.connections.get_thread_connection()
yield conn

def get_relation_columns(self, relation):
Expand Down
3 changes: 2 additions & 1 deletion test/rpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ def __init__(self, profiles_dir, which='run-operation', kwargs={}):


def execute(adapter, sql):
with adapter.connection_named('rpc-tests') as conn:
with adapter.connection_named('rpc-tests'):
conn = adapter.connections.get_thread_connection()
with conn.handle.cursor() as cursor:
try:
cursor.execute(sql)
Expand Down
14 changes: 10 additions & 4 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
Expand All @@ -115,6 +117,8 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti
except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
Expand All @@ -128,9 +132,8 @@ def test_acquire_connection_priority(self, mock_open_connection):
except dbt.exceptions.ValidationException as e:
self.fail('got ValidationException: {}'.format(str(e)))

except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

def test_cancel_open_connections_empty(self):
Expand Down Expand Up @@ -158,8 +161,11 @@ def test_location_value(self, mock_bq, mock_auth_default):
mock_auth_default.return_value = (creds, MagicMock())
adapter = self.get_adapter('loc')

adapter.acquire_connection('dummy')
connection = adapter.acquire_connection('dummy')
mock_client = mock_bq.Client

mock_client.assert_not_called()
connection.handle
mock_client.assert_called_once_with('dbt-unit-000000', creds,
location='Luna Station')

Expand Down
15 changes: 15 additions & 0 deletions test/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@ def test_acquire_connection_validations(self, psycopg2):
self.fail('acquiring connection failed with unknown exception: {}'
.format(str(e)))
self.assertEqual(connection.type, 'postgres')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once()

@mock.patch('dbt.adapters.postgres.connections.psycopg2')
def test_acquire_connection(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
self.assertEqual(connection.state, 'open')
self.assertNotEqual(connection.handle, None)
psycopg2.connect.assert_called_once()
Expand Down Expand Up @@ -101,6 +106,8 @@ def test_cancel_open_connections_single(self):
def test_default_keepalive(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -114,6 +121,8 @@ def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -128,6 +137,8 @@ def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -142,6 +153,8 @@ def test_schema_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -156,6 +169,8 @@ def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand Down
10 changes: 10 additions & 0 deletions test/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def test_cancel_open_connections_single(self):
def test_default_keepalive(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -168,6 +170,8 @@ def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -182,6 +186,8 @@ def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -197,6 +203,8 @@ def test_search_path_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -212,6 +220,8 @@ def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand Down
25 changes: 19 additions & 6 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ def test_cancel_open_connections_single(self):
add_query.assert_called_once_with('select system$abort_session(42)')

def test_client_session_keep_alive_false_by_default(self):
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -243,8 +246,10 @@ def test_client_session_keep_alive_true(self):
self.config.credentials = self.config.credentials.replace(
client_session_keep_alive=True)
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -258,8 +263,10 @@ def test_user_pass_authentication(self):
password='test_password',
)
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -275,8 +282,10 @@ def test_authenticator_user_pass_authentication(self):
authenticator='test_sso_url',
)
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -292,8 +301,10 @@ def test_authenticator_externalbrowser_authentication(self):
authenticator='externalbrowser'
)
self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -311,8 +322,10 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key):
)

self.adapter = SnowflakeAdapter(self.config)
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand Down

0 comments on commit afd69ea

Please sign in to comment.