Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "lazy-load connections (#1584)" #1991

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle
Connection, Identifier, ConnectionState, AdapterRequiredConfig
)
from dbt.adapters.base.query_headers import (
QueryStringSetter, MacroQueryStringSetter,
Expand Down Expand Up @@ -61,14 +61,6 @@ def get_thread_connection(self) -> Connection:
)
return self.thread_connections[key]

def set_thread_connection(self, conn):
key = self.get_thread_identifier()
if key in self.thread_connections:
raise dbt.exceptions.InternalException(
'In set_thread_connection, existing connection exists for {}'
)
self.thread_connections[key] = conn

def get_if_exists(self) -> Optional[Connection]:
key = self.get_thread_identifier()
with self.lock:
Expand Down Expand Up @@ -117,6 +109,8 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
conn_name = name

conn = self.get_if_exists()
thread_id_key = self.get_thread_identifier()

if conn is None:
conn = Connection(
type=Identifier(self.TYPE),
Expand All @@ -126,7 +120,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
handle=None,
credentials=self.profile.credentials
)
self.set_thread_connection(conn)
self.thread_connections[thread_id_key] = conn

if conn.name == conn_name and conn.state == 'open':
return conn
Expand All @@ -144,7 +138,7 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
'Opening a new connection, currently in state {}'
.format(conn.state)
)
conn.handle = LazyHandle(type(self))
self.open(conn)

conn.name = conn_name
return conn
Expand Down
12 changes: 6 additions & 6 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,21 @@ def nice_connection_name(self):
@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
) -> Iterator[None]:
):
try:
self.connections.query_header.set(name, node)
self.acquire_connection(name)
yield
conn = self.acquire_connection(name)
yield conn
finally:
self.release_connection()
self.connections.query_header.reset()

@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[None]:
with self.connection_named(node.unique_id, node):
yield
) -> Iterator[Connection]:
with self.connection_named(node.unique_id, node) as conn:
yield conn

@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
Expand Down
29 changes: 1 addition & 28 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Type
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List
)
from typing_extensions import Protocol

Expand All @@ -12,7 +12,6 @@
)

from dbt.contracts.util import Replaceable
from dbt.exceptions import InternalException
from dbt.utils import translate_aliases


Expand All @@ -27,23 +26,6 @@ class ConnectionState(StrEnum):
FAIL = 'fail'


class ConnectionOpenerProtocol(Protocol):
@classmethod
def open(cls, connection: 'Connection') -> Any:
raise NotImplementedError(f'open() not implemented for {cls.__name__}')


class LazyHandle:
"""Opener must be a callable that takes a Connection object and opens the
connection, updating the handle on the Connection.
"""
def __init__(self, opener: Type[ConnectionOpenerProtocol]):
self.opener = opener

def resolve(self, connection: 'Connection') -> Any:
return self.opener.open(connection)


@dataclass(init=False)
class Connection(ExtensibleJsonSchemaMixin, Replaceable):
type: Identifier
Expand Down Expand Up @@ -80,15 +62,6 @@ def credentials(self, value):

@property
def handle(self):
if isinstance(self._handle, LazyHandle):
try:
# this will actually change 'self._handle'.
self._handle.resolve(self)
except RecursionError as exc:
raise InternalException(
"A connection's open() method attempted to read the "
"handle value"
) from exc
return self._handle

@handle.setter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ def project_config(self):
def run_select_and_check(self, rel, sql):
connection_name = '__test_{}'.format(id(threading.current_thread()))
try:
with self._secret_adapter.connection_named(connection_name):
conn = self._secret_adapter.connections.get_thread_connection()
with self._secret_adapter.connection_named(connection_name) as conn:
res = self.run_sql_common(self.transform_sql(sql), 'one', conn)

# The result is the output of f_sleep(), which is True
if res[0]:
if res[0] == True:
self.query_state[rel] = 'good'
else:
self.query_state[rel] = 'bad'
Expand Down
3 changes: 1 addition & 2 deletions test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,8 +753,7 @@ def get_connection(self, name=None):
if name is None:
name = '__test'
with patch.object(common, 'get_adapter', return_value=self.adapter):
with self.adapter.connection_named(name):
conn = self.adapter.connections.get_thread_connection()
with self.adapter.connection_named(name) as conn:
yield conn

def get_relation_columns(self, relation):
Expand Down
3 changes: 1 addition & 2 deletions test/rpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,7 @@ def __init__(self, profiles_dir, which='run-operation', kwargs={}):


def execute(adapter, sql):
with adapter.connection_named('rpc-tests'):
conn = adapter.connections.get_thread_connection()
with adapter.connection_named('rpc-tests') as conn:
with conn.handle.cursor() as cursor:
try:
cursor.execute(sql)
Expand Down
14 changes: 4 additions & 10 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
Expand All @@ -117,8 +115,6 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti
except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
Expand All @@ -132,8 +128,9 @@ def test_acquire_connection_priority(self, mock_open_connection):
except dbt.exceptions.ValidationException as e:
self.fail('got ValidationException: {}'.format(str(e)))

mock_open_connection.assert_not_called()
connection.handle
except BaseException as e:
raise

mock_open_connection.assert_called_once()

def test_cancel_open_connections_empty(self):
Expand Down Expand Up @@ -161,11 +158,8 @@ def test_location_value(self, mock_bq, mock_auth_default):
mock_auth_default.return_value = (creds, MagicMock())
adapter = self.get_adapter('loc')

connection = adapter.acquire_connection('dummy')
adapter.acquire_connection('dummy')
mock_client = mock_bq.Client

mock_client.assert_not_called()
connection.handle
mock_client.assert_called_once_with('dbt-unit-000000', creds,
location='Luna Station')

Expand Down
15 changes: 0 additions & 15 deletions test/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,12 @@ def test_acquire_connection_validations(self, psycopg2):
self.fail('acquiring connection failed with unknown exception: {}'
.format(str(e)))
self.assertEqual(connection.type, 'postgres')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once()

@mock.patch('dbt.adapters.postgres.connections.psycopg2')
def test_acquire_connection(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
self.assertEqual(connection.state, 'open')
self.assertNotEqual(connection.handle, None)
psycopg2.connect.assert_called_once()
Expand Down Expand Up @@ -106,8 +101,6 @@ def test_cancel_open_connections_single(self):
def test_default_keepalive(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -121,8 +114,6 @@ def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -137,8 +128,6 @@ def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -153,8 +142,6 @@ def test_schema_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand All @@ -169,8 +156,6 @@ def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='postgres',
user='root',
Expand Down
10 changes: 0 additions & 10 deletions test/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def test_cancel_open_connections_single(self):
def test_default_keepalive(self, psycopg2):
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -170,8 +168,6 @@ def test_changed_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=256)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -186,8 +182,6 @@ def test_search_path(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -203,8 +197,6 @@ def test_search_path_with_space(self, psycopg2):
self.config.credentials = self.config.credentials.replace(search_path="test test")
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand All @@ -220,8 +212,6 @@ def test_set_zero_keepalive(self, psycopg2):
self.config.credentials = self.config.credentials.replace(keepalives_idle=0)
connection = self.adapter.acquire_connection('dummy')

psycopg2.connect.assert_not_called()
connection.handle
psycopg2.connect.assert_called_once_with(
dbname='redshift',
user='root',
Expand Down
25 changes: 6 additions & 19 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,7 @@ def test_cancel_open_connections_single(self):
add_query.assert_called_once_with('select system$abort_session(42)')

def test_client_session_keep_alive_false_by_default(self):
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -246,10 +243,8 @@ def test_client_session_keep_alive_true(self):
self.config.credentials = self.config.credentials.replace(
client_session_keep_alive=True)
self.adapter = SnowflakeAdapter(self.config)
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -263,10 +258,8 @@ def test_user_pass_authentication(self):
password='test_password',
)
self.adapter = SnowflakeAdapter(self.config)
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -282,10 +275,8 @@ def test_authenticator_user_pass_authentication(self):
authenticator='test_sso_url',
)
self.adapter = SnowflakeAdapter(self.config)
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -301,10 +292,8 @@ def test_authenticator_externalbrowser_authentication(self):
authenticator='externalbrowser'
)
self.adapter = SnowflakeAdapter(self.config)
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand All @@ -322,10 +311,8 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key):
)

self.adapter = SnowflakeAdapter(self.config)
conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')
self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

self.snowflake.assert_not_called()
conn.handle
self.snowflake.assert_has_calls([
mock.call(
account='test_account', autocommit=False,
Expand Down