Skip to content

Commit

Permalink
Add the last query id from presto when doing presto query onto cursor (
Browse files Browse the repository at this point in the history
…#313)

* Add the last query id from presto when doing presto query onto cursor

* Add some assertions to the integration tests for last_query_id

* Fix the testcase so it's actually doing the correct test
  • Loading branch information
mb-m authored Feb 27, 2020
1 parent 50c5ef7 commit 19fd140
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class will use the default requests behavior of making a new session per HTTP re
self._poll_interval = poll_interval
self._source = source
self._session_props = session_props if session_props is not None else {}
self.last_query_id = None

if protocol not in ('http', 'https'):
raise ValueError("Protocol must be http/https, was {!r}".format(protocol))
Expand Down Expand Up @@ -319,6 +320,8 @@ def _process_response(self, response):
assert self._state == self._STATE_RUNNING, "Should be running if processing response"
self._nextUri = response_json.get('nextUri')
self._columns = response_json.get('columns')
if 'id' in response_json:
self.last_query_id = response_json['id']
if 'X-Presto-Clear-Session' in response.headers:
propname = response.headers['X-Presto-Clear-Session']
self._session_props.pop(propname, None)
Expand Down
28 changes: 24 additions & 4 deletions pyhive/tests/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_bad_protocol(self):
def test_description(self, cursor):
cursor.execute('SELECT 1 AS foobar FROM one_row')
self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)])
self.assertIsNotNone(cursor.last_query_id)

@with_cursor
def test_complex(self, cursor):
Expand Down Expand Up @@ -99,6 +100,7 @@ def test_cancel(self, cursor):
self.assertIn(cursor.poll()['stats']['state'], (
'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED'))
cursor.cancel()
self.assertIsNotNone(cursor.last_query_id)
self.assertIsNone(cursor.poll())

def test_noops(self):
Expand All @@ -110,6 +112,7 @@ def test_noops(self):
self.assertEqual(cursor.rowcount, -1)
cursor.setinputsizes([])
cursor.setoutputsize(1, 'blah')
self.assertIsNone(cursor.last_query_id)
connection.commit()

@mock.patch('requests.post')
Expand Down Expand Up @@ -137,23 +140,40 @@ def fail(*args, **kwargs):

@with_cursor
def test_set_session(self, cursor):
id = None
self.assertIsNone(cursor.last_query_id)
cursor.execute("SET SESSION query_max_run_time = '1234m'")
self.assertIsNotNone(cursor.last_query_id)
id = cursor.last_query_id
cursor.fetchall()
self.assertEqual(id, cursor.last_query_id)

cursor.execute('SHOW SESSION')
self.assertIsNotNone(cursor.last_query_id)
self.assertNotEqual(id, cursor.last_query_id)
id = cursor.last_query_id
rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
assert len(rows) == 1
self.assertEqual(len(rows), 1)
session_prop = rows[0]
assert session_prop[1] == '1234m'
self.assertEqual(session_prop[1], '1234m')
self.assertEqual(id, cursor.last_query_id)

cursor.execute('RESET SESSION query_max_run_time')
self.assertIsNotNone(cursor.last_query_id)
self.assertNotEqual(id, cursor.last_query_id)
id = cursor.last_query_id
cursor.fetchall()
self.assertEqual(id, cursor.last_query_id)

cursor.execute('SHOW SESSION')
self.assertIsNotNone(cursor.last_query_id)
self.assertNotEqual(id, cursor.last_query_id)
id = cursor.last_query_id
rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
assert len(rows) == 1
self.assertEqual(len(rows), 1)
session_prop = rows[0]
assert session_prop[1] != '1234m'
self.assertNotEqual(session_prop[1], '1234m')
self.assertEqual(id, cursor.last_query_id)

def test_set_session_in_constructor(self):
conn = presto.connect(
Expand Down

0 comments on commit 19fd140

Please sign in to comment.