From b78253b8aa1edd8ba1e821d6a0c74aa278ec087a Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Wed, 10 Aug 2016 09:36:25 -0700 Subject: [PATCH] Refactor the query runner to enable async mode. --- bogdan.todo | 4 + caravel/config.py | 16 - caravel/extract_table_names.py | 60 ++++ .../versions/ad82a75afd82_add_query_model.py | 15 +- caravel/models.py | 28 +- caravel/tasks.py | 302 ++++++++++-------- caravel/utils.py | 9 +- caravel/views.py | 123 +++++-- setup.py | 3 + tests/celery_tests.py | 218 ++++++------- tests/core_tests.py | 10 +- 11 files changed, 496 insertions(+), 292 deletions(-) create mode 100644 bogdan.todo create mode 100644 caravel/extract_table_names.py diff --git a/bogdan.todo b/bogdan.todo new file mode 100644 index 0000000000000..552a259279fad --- /dev/null +++ b/bogdan.todo @@ -0,0 +1,4 @@ +1. [] implement the polling of the query results +2. [] implement the retrieving of the CTA results +3. [] implement parsing of the query to retrieve the table names + diff --git a/caravel/config.py b/caravel/config.py index f922b9db02aaa..79c87d0a9215a 100644 --- a/caravel/config.py +++ b/caravel/config.py @@ -179,22 +179,7 @@ # Set this API key to enable Mapbox visualizations MAPBOX_API_KEY = "" -# Maximum number of rows returned in the SQL editor -SQL_MAX_ROW = 1000 -# Default celery config is to use SQLA as a broker, in a production setting -# you'll want to use a proper broker as specified here: -# http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html -""" -# Example: -class CeleryConfig(object): - BROKER_URL = 'sqla+sqlite:///celerydb.sqlite' - CELERY_IMPORTS = ('caravel.tasks', ) - CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite' - CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} -CELERY_CONFIG = CeleryConfig -""" -CELERY_CONFIG = None try: from caravel_config import * # noqa @@ -203,4 +188,3 @@ class CeleryConfig(object): if not CACHE_DEFAULT_TIMEOUT: CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT') - diff --git a/caravel/extract_table_names.py b/caravel/extract_table_names.py new file mode 100644 index 0000000000000..4bc57074290a0 --- /dev/null +++ b/caravel/extract_table_names.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com +# +# This example is part of python-sqlparse and is released under +# the BSD License: http://www.opensource.org/licenses/bsd-license.php +# +# This example illustrates how to extract table names from nested +# SELECT statements. +# +# See: +# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895 + +import sqlparse +from sqlparse.sql import IdentifierList, Identifier +from sqlparse.tokens import Keyword, DML + + +def is_subselect(parsed): + if not parsed.is_group(): + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() == 'SELECT': + return True + return False + + +def extract_from_part(parsed): + from_seen = False + for item in parsed.tokens: + if from_seen: + if is_subselect(item): + for x in extract_from_part(item): + yield x + elif item.ttype is Keyword: + raise StopIteration + else: + yield item + elif item.ttype is Keyword and item.value.upper() == 'FROM': + from_seen = True + + +def extract_table_identifiers(token_stream): + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + yield identifier.get_name() + elif isinstance(item, Identifier): + yield item.get_name() + # It's a bug to check for Keyword here, but in the example + # above some tables names are identified as keywords... + elif item.ttype is Keyword: + yield item.value + + +# TODO(bkyryliuk): add logic to support joins and unions. +def extract_tables(sql): + stream = extract_from_part(sqlparse.parse(sql)[0]) + return list(extract_table_identifiers(stream)) diff --git a/caravel/migrations/versions/ad82a75afd82_add_query_model.py b/caravel/migrations/versions/ad82a75afd82_add_query_model.py index 4794f416de07f..4a53c4309e077 100644 --- a/caravel/migrations/versions/ad82a75afd82_add_query_model.py +++ b/caravel/migrations/versions/ad82a75afd82_add_query_model.py @@ -13,17 +13,26 @@ from alembic import op import sqlalchemy as sa + def upgrade(): op.create_table('query', sa.Column('id', sa.Integer(), nullable=False), sa.Column('database_id', sa.Integer(), nullable=False), - sa.Column('tmp_table_name', sa.String(length=64), nullable=True), + sa.Column('tmp_table_name', sa.String(length=256), nullable=True), + sa.Column('tab_name', sa.String(length=256),nullable=True), sa.Column('user_id', sa.Integer(), nullable=True), sa.Column('status', sa.String(length=16), nullable=True), - sa.Column('name', sa.String(length=64), nullable=True), - sa.Column('sql', sa.Text, nullable=True), + sa.Column('name', sa.String(length=256), nullable=True), + sa.Column('schema', sa.String(length=256), nullable=True), + sa.Column('sql', sa.Text(), nullable=True), + sa.Column('select_sql', sa.Text(), nullable=True), + sa.Column('executed_sql', sa.Text(), nullable=True), sa.Column('limit', sa.Integer(), nullable=True), + sa.Column('limit_used', sa.Boolean(), nullable=True), + sa.Column('select_as_cta', sa.Boolean(), nullable=True), + sa.Column('select_as_cta_used', sa.Boolean(), nullable=True), sa.Column('progress', sa.Integer(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), sa.Column('start_time', sa.DateTime(), nullable=True), sa.Column('end_time', sa.DateTime(), nullable=True), sa.ForeignKeyConstraint(['database_id'], [u'dbs.id'], ), diff --git a/caravel/models.py b/caravel/models.py index 083f0b3c1ea07..feb2f6d13f011 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -379,7 +379,7 @@ class Database(Model, AuditMixinNullable): sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) cache_timeout = Column(Integer) - select_as_create_table_as = Column(Boolean, default=True) + select_as_create_table_as = Column(Boolean, default=False) extra = Column(Text, default=textwrap.dedent("""\ { "metadata_params": {}, @@ -1711,6 +1711,16 @@ class FavStar(Model): class QueryStatus: + def from_presto_states(self, presto_status): + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + SCHEDULED = 'SCHEDULED' CANCELLED = 'CANCELLED' IN_PROGRESS = 'IN_PROGRESS' @@ -1729,18 +1739,28 @@ class Query(Model): database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) # Store the tmp table into the DB only if the user asks for it. - tmp_table_name = Column(String(64)) + tmp_table_name = Column(String(256)) user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True) # models.QueryStatus status = Column(String(16)) - name = Column(String(64)) + name = Column(String(256)) + tab_name = Column(String(256)) + schema = Column(String(256)) sql = Column(Text) - # Could be configured in the caravel config + # Query to retrieve the results, + # used only in case of select_as_cta_used is true. + select_sql = Column(Text) + executed_sql = Column(Text) + # Could be configured in the caravel config. limit = Column(Integer) + limit_used = Column(Boolean) + select_as_cta = Column(Boolean) + select_as_cta_used = Column(Boolean) # 1..100 progress = Column(Integer) + error_message = Column(Text) start_time = Column(DateTime) end_time = Column(DateTime) diff --git a/caravel/tasks.py b/caravel/tasks.py index c48e66997456a..38f4b0edd8a72 100644 --- a/caravel/tasks.py +++ b/caravel/tasks.py @@ -1,7 +1,7 @@ import celery from caravel import models, app, utils from datetime import datetime -import logging + from sqlalchemy import create_engine, select, text from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.sql.expression import TextAsFrom @@ -11,6 +11,173 @@ celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) +@celery_app.task +def get_sql_results(query_id): + """Executes the sql query returns the results.""" + # Create a separate session, reusing the db.session leads to the + # concurrency issues. + session = get_session() + query = session.query(models.Query).filter_by(id=query_id).first() + result = None + try: + db_to_query = ( + session.query(models.Database).filter_by(id=query.database_id) + .first() + ) + except Exception as e: + result = fail_query(query, utils.error_msg_from_exception(e)) + + if not db_to_query: + result = fail_query(query, "Database with id {0} is missing.".format( + query.database_id)) + + if not result: + result = get_sql_results_as_dict(db_to_query, query, session) + query.end_time = datetime.now() + session.flush() + return result + + +# TODO(bkyryliuk): dump results somewhere for the webserver. +def get_sql_results_as_dict(db_to_query, query, orm_session): + """Get the SQL query results from the give session and db connection.""" + engine = db_to_query.get_sqla_engine(schema=query.schema) + query.executed_sql = query.sql.strip().strip(';') + + # Limit enforced only for retrieving the data, not for the CTA queries. + query.select_as_cta_used = False + query.limit_used = False + if is_query_select(query.executed_sql): + if query.select_as_cta: + if not query.tmp_table_name: + query.tmp_table_name = 'tmp_{}_table_{}'.format( + query.user_id, + query.start_time.strftime('%Y_%m_%d_%H_%M_%S')) + query.executed_sql = create_table_as( + query.executed_sql, query.tmp_table_name) + query.select_as_cta_used = True + elif query.limit: + query.executed_sql = add_limit_to_the_sql( + query.executed_sql, query.limit, engine) + query.limit_used = True + + # TODO(bkyryliuk): ensure that tmp table was created. + # Do not set tmp table name if table wasn't created. + if not query.select_as_cta_used: + query.tmp_table_name = None + + backend = engine.url.get_backend_name() + if backend in ('presto', 'hive'): + result = get_sql_results_async(engine, query, orm_session) + else: + result = get_sql_results_sync(engine, query) + + orm_session.flush() + return result + + +def get_sql_results_async(engine, query, orm_session): + try: + result_proxy = engine.execute(query.executed_sql, schema=query.schema) + except Exception as e: + return fail_query(query, utils.error_msg_from_exception(e)) + + cursor = result_proxy.cursor + query_stats = cursor.poll() + query.status = models.QueryStatus.IN_PROGRESS + orm_session.flush() + # poll returns dict -- JSON status information or ``None`` + # if the query is done + # https://github.com/dropbox/PyHive/blob/ + # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178 + while query_stats: + # Update the object and wait for the kill signal. + orm_session.refresh(query) + completed_splits = int(query_stats['stats']['completedSplits']) + total_splits = int(query_stats['stats']['totalSplits']) + progress = 100 * completed_splits / total_splits + if progress > query.progress: + query.progress = progress + + orm_session.flush() + query_stats = cursor.poll() + # TODO(b.kyryliuk): check for the kill signal. + + if query.select_as_cta_used: + select_star = ( + select('*').select_from(query.tmp_table_name). + limit(query.limit) + ) + # SQL code to preview the results + query.select_sql = str(select_star.compile( + engine, compile_kwargs={"literal_binds": True})) + try: + # override cursor value to reuse the data extraction down below. + result_proxy = engine.execute( + query.select_sql, schema=query.schema) + cursor = result_proxy.cursor + while cursor.poll(): + # TODO: wait till the data is fetched + pass + except Exception as e: + return fail_query(query, utils.error_msg_from_exception(e)) + + response = fetch_response_from_cursor(result_proxy, query) + query.status = models.QueryStatus.FINISHED + orm_session.flush() + return response + + +def get_sql_results_sync(engine, query): + # TODO(bkyryliuk): rewrite into eng.execute as queries different from + # select should be permitted too. + query.select_sql = query.sql + if query.select_as_cta_used: + try: + engine.execute(query.executed_sql, schema=query.schema) + except Exception as e: + return fail_query(query, utils.error_msg_from_exception(e)) + select_star = ( + select('*').select_from(query.tmp_table_name). + limit(query.limit) + ) + query.select_sql = str(select_star.compile( + engine, compile_kwargs={"literal_binds": True})) + try: + result_proxy = engine.execute( + query.select_sql, schema=query.schema) + except Exception as e: + return fail_query(query, utils.error_msg_from_exception(e)) + response = fetch_response_from_cursor(result_proxy, query) + query.status = models.QueryStatus.FINISHED + return response + + +def fail_query(query, message): + query.error_message = message + query.status = models.QueryStatus.FAILED + return { + 'error': query.error_message, + 'status': query.status, + } + + +# TODO(b.kyryliuk): find better way to pass the data. +def fetch_response_from_cursor(result_proxy, query): + cols = [col[0] for col in result_proxy.cursor.description] + data = result_proxy.fetchall() + print("DELETEME") + print(data) + df = pd.DataFrame(data, columns=cols) + df = df.fillna(0) + return { + 'query_id': query.id, + 'columns': [c for c in df.columns], + 'data': df.to_dict(orient='records'), + 'status': models.QueryStatus.FINISHED, + } + + def is_query_select(sql): try: return sqlparse.parse(sql)[0].get_type() == 'SELECT' @@ -35,13 +202,12 @@ def get_tables(): pass -def add_limit_to_the_query(sql, limit, eng): +def add_limit_to_the_sql(sql, limit, eng): # Treat as single sql statement in case of failure. - sql_statements = [sql] try: sql_statements = [s for s in sqlparse.split(sql) if s] except Exception as e: - logging.info( + app.logger.info( "Statement " + sql + "failed to be transformed to have the limit " "with the exception" + e.message) return sql @@ -56,6 +222,8 @@ def add_limit_to_the_query(sql, limit, eng): # create table works only for the single statement. +# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it +# for the CTA operation. def create_table_as(sql, table_name, override=False): """Reformats the query into the create table as query. @@ -69,12 +237,11 @@ def create_table_as(sql, table_name, override=False): # TODO(bkyryliuk): drop table if allowed, check the namespace and # the permissions. # Treat as single sql statement in case of failure. - sql_statements = [sql] try: # Filter out empty statements. sql_statements = [s for s in sqlparse.split(sql) if s] except Exception as e: - logging.info( + app.logger.info( "Statement " + sql + "failed to be transformed as create table as " "with the exception" + e.message) return sql @@ -95,125 +262,4 @@ def get_session(): engine = create_engine( app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True) return scoped_session(sessionmaker( - autocommit=False, autoflush=False, bind=engine)) - - -@celery_app.task -def get_sql_results(database_id, sql, user_id, tmp_table_name="", schema=None): - """Executes the sql query returns the results. - - :param database_id: integer - :param sql: string, query that will be executed - :param user_id: integer - :param tmp_table_name: name of the table for CTA - :param schema: string, name of the schema (used in presto) - :return: dataframe, query result - """ - # Create a separate session, reusing the db.session leads to the - # concurrency issues. - session = get_session() - try: - db_to_query = ( - session.query(models.Database).filter_by(id=database_id).first() - ) - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - if not db_to_query: - return { - 'error': "Database with id {0} is missing.".format(database_id), - 'success': False, - } - - # TODO(bkyryliuk): provide a way for the user to name the query. - # TODO(bkyryliuk): run explain query to derive the tables and fill in the - # table_ids - # TODO(bkyryliuk): check the user permissions - # TODO(bkyryliuk): store the tab name in the query model - limit = app.config.get('SQL_MAX_ROW', None) - start_time = datetime.now() - if not tmp_table_name: - tmp_table_name = 'tmp.{}_table_{}'.format(user_id, start_time) - query = models.Query( - user_id=user_id, - database_id=database_id, - limit=limit, - name='{}'.format(start_time), - sql=sql, - start_time=start_time, - tmp_table_name=tmp_table_name, - status=models.QueryStatus.IN_PROGRESS, - ) - session.add(query) - session.commit() - query_result = get_sql_results_as_dict( - db_to_query, sql, query.tmp_table_name, schema=schema) - query.end_time = datetime.now() - if query_result['success']: - query.status = models.QueryStatus.FINISHED - else: - query.status = models.QueryStatus.FAILED - session.commit() - # TODO(bkyryliuk): return the tmp table / query_id - return query_result - - -# TODO(bkyryliuk): merge the changes made in the carapal first -# before merging this PR. -def get_sql_results_as_dict(db_to_query, sql, tmp_table_name, schema=None): - """Get the SQL query results from the give session and db connection. - - :param sql: string, query that will be executed - :param db_to_query: models.Database to query, cannot be None - :param tmp_table_name: name of the table for CTA - :param schema: string, name of the schema (used in presto) - :return: (dataframe, boolean), results and the status - """ - eng = db_to_query.get_sqla_engine(schema=schema) - sql = sql.strip().strip(';') - # TODO(bkyryliuk): fix this case for multiple statements - if app.config.get('SQL_MAX_ROW'): - sql = add_limit_to_the_query( - sql, app.config.get("SQL_MAX_ROW"), eng) - - cta_used = False - if (app.config.get('SQL_SELECT_AS_CTA') and - db_to_query.select_as_create_table_as and is_query_select(sql)): - # TODO(bkyryliuk): figure out if the query is select query. - sql = create_table_as(sql, tmp_table_name) - cta_used = True - - if cta_used: - try: - eng.execute(sql) - return { - 'tmp_table': tmp_table_name, - 'success': True, - } - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - - # otherwise run regular SQL query. - # TODO(bkyryliuk): rewrite into eng.execute as queries different from - # select should be permitted too. - try: - df = db_to_query.get_df(sql, schema) - df = df.fillna(0) - return { - 'columns': [c for c in df.columns], - 'data': df.to_dict(orient='records'), - 'success': True, - } - - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - - + autocommit=True, autoflush=False, bind=engine)) diff --git a/caravel/utils.py b/caravel/utils.py index 9b784517c573f..668c80f493674 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -339,10 +339,17 @@ def error_msg_from_exception(e): Database have different ways to handle exception. This function attempts to make sense of the exception object and construct a human readable sentence. + + TODO(bkyryliuk): parse the Presto error message from the connection + created via create_engine. + engine = create_engine('presto://localhost:3506/silver') - + gives an e.message as the str(dict) + presto.connect("localhost", port=3506, catalog='silver') - as a dict. + The latter version is parsed correctly by this function. """ msg = '' if hasattr(e, 'message'): - if (type(e.message) is dict): + if type(e.message) is dict: msg = e.message.get('message') elif e.message: msg = "{}".format(e.message) diff --git a/caravel/views.py b/caravel/views.py index 66e3697f48502..3cfbf338bd594 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -430,11 +430,13 @@ class DatabaseAsync(DatabaseView): appbuilder.add_view_no_menu(DatabaseAsync) + class DatabaseTablesAsync(DatabaseView): list_columns = ['id', 'all_table_names', 'all_schema_names'] appbuilder.add_view_no_menu(DatabaseTablesAsync) + class TableModelView(CaravelModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.SqlaTable) list_columns = [ @@ -592,7 +594,8 @@ def add(self): url = "/druiddatasourcemodelview/list/" msg = _( "Click on a datasource link to create a Slice, " - "or click on a table link here " + "or click on a table link " + "here " "to create a Slice for a table" ) else: @@ -866,7 +869,8 @@ def explore(self, datasource_type, datasource_id): datasource_access = self.can_access( 'datasource_access', datasource.perm) if not (all_datasource_access or datasource_access): - flash(__("You don't seem to have access to this datasource"), "danger") + flash(__("You don't seem to have access to this datasource"), + "danger") return redirect(error_redirect) action = request.args.get('action') @@ -943,7 +947,8 @@ def save_or_overwrite_slice( del d['action'] del d['previous_viz_type'] - as_list = ('metrics', 'groupby', 'columns', 'all_columns', 'mapbox_label', 'order_by_cols') + as_list = ('metrics', 'groupby', 'columns', 'all_columns', + 'mapbox_label', 'order_by_cols') for k in d: v = d.get(k) if k in as_list and not isinstance(v, list): @@ -1054,7 +1059,8 @@ def activity_per_day(self): .group_by(Log.dt) .all() ) - payload = {str(time.mktime(dt.timetuple())): ccount for dt, ccount in qry if dt} + payload = {str(time.mktime(dt.timetuple())): + ccount for dt, ccount in qry if dt} return Response(json.dumps(payload), mimetype="application/json") @api @@ -1110,9 +1116,11 @@ def add_slices(self, dashboard_id): data = json.loads(request.form.get('data')) session = db.session() Slice = models.Slice # noqa - dash = session.query(models.Dashboard).filter_by(id=dashboard_id).first() + dash = ( + session.query(models.Dashboard).filter_by(id=dashboard_id).first()) check_ownership(dash, raise_if_false=True) - new_slices = session.query(Slice).filter(Slice.id.in_(data['slice_ids'])) + new_slices = session.query(Slice).filter( + Slice.id.in_(data['slice_ids'])) dash.slices += new_slices session.merge(dash) session.commit() @@ -1146,13 +1154,18 @@ def favstar(self, class_name, obj_id, action): FavStar = models.FavStar # noqa count = 0 favs = session.query(FavStar).filter_by( - class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()).all() + class_name=class_name, obj_id=obj_id, + user_id=g.user.get_id()).all() if action == 'select': if not favs: session.add( FavStar( - class_name=class_name, obj_id=obj_id, user_id=g.user.get_id(), - dttm=datetime.now())) + class_name=class_name, + obj_id=obj_id, + user_id=g.user.get_id(), + dttm=datetime.now() + ) + ) count = 1 elif action == 'unselect': for fav in favs: @@ -1358,9 +1371,22 @@ def sql_json(self): sql = request.form.get('sql') database_id = request.form.get('database_id') schema = request.form.get('schema') + tab_name = request.form.get('tab_name') + tmp_table_name = request.form.get('tmp_table_name') + select_as_cta = request.form.get('select_as_cta') == 'True' + session = db.session() mydb = session.query(models.Database).filter_by(id=database_id).first() + if not mydb: + return Response( + json.dumps({ + 'error': 'Database with id 0 is missing.', + 'status': models.QueryStatus.FAILED, + }), + status=500, + mimetype="application/json") + if not (self.can_access( 'all_datasource_access', 'all_datasource_access') or self.can_access('database_access', mydb.perm)): @@ -1368,19 +1394,78 @@ def sql_json(self): "SQL Lab requires the `all_datasource_access` or " "specific DB permission")) - data = tasks.get_sql_results(database_id, sql, g.user.get_id(), - schema=schema) - if 'error' in data: + # DB select_as_create_table_as forces all queries to be + # select_as_cta. + if select_as_cta or mydb.select_as_create_table_as: + select_as_cta = True + start_time = datetime.now() + query_name = '{}_{}_{}'.format( + g.user.get_id(), tab_name, start_time.strftime('%M:%S:%f')) + + query = models.Query( + database_id=database_id, + limit=app.config.get('SQL_MAX_ROW', None), + name=query_name, + sql=sql, + schema=schema, + # TODO(bkyryliuk): consider it being DB property. + select_as_cta=select_as_cta, + start_time=start_time, + status=models.QueryStatus.SCHEDULED, + tab_name=tab_name, + tmp_table_name=tmp_table_name, + user_id=g.user.get_id(), + ) + session.add(query) + session.commit() + + data = tasks.get_sql_results(query.id) + if data['status'] == models.QueryStatus.FAILED: return Response( - json.dumps(data), + json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False), status=500, mimetype="application/json") - if 'tmp_table' in data: - # TODO(bkyryliuk): add query id to the response and implement the - # endpoint to poll the status and results. - return None - return json.dumps( - data, default=utils.json_int_dttm_ser, allow_nan=False) + print("DELETEME") + print(data) + return Response( + json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False), + status=200, + mimetype="application/json") + + @has_access + @expose("/query_progress/", methods=['GET']) + @log_this + def query_progress(self): + """Runs arbitrary sql and returns and json""" + query_id = request.form.get('query_id') + s = db.session() + query = s.query(models.Query).filter_by(id=query_id).first() + mydb = s.query(models.Database).filter_by(id=query.database_id).first() + + if not (self.can_access( + 'all_datasource_access', 'all_datasource_access') or + self.can_access('database_access', mydb.perm)): + raise utils.CaravelSecurityException(_( + "SQL Lab requires the `all_datasource_access` or " + "specific DB permission")) + + if query: + return Response( + json.dumps({ + 'status': query.status, + 'progress': query.progress + }), + status=200, + mimetype="application/json") + + return Response( + json.dumps({ + 'error': "Query with id {} wasn't found".format(query_id), + }), + status=404, + mimetype="application/json") @has_access @expose("/refresh_datasources/") diff --git a/setup.py b/setup.py index ceb266d9fef18..07ae7b1750173 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ 'pandas==0.18.1', 'parsedatetime==2.0.0', 'pydruid==0.3.0', + 'PyHive>=0.2.1', 'python-dateutil==2.5.3', 'requests==2.10.0', 'simplejson==3.8.2', @@ -37,6 +38,8 @@ 'sqlalchemy==1.0.13', 'sqlalchemy-utils==0.32.7', 'sqlparse==0.1.19', + 'thrift>=0.9.3', + 'thrift-sasl>=0.2.1', 'werkzeug==0.11.10', ], extras_require={ diff --git a/tests/celery_tests.py b/tests/celery_tests.py index e88ae0fca1c5b..89208bb458bc3 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -1,10 +1,9 @@ """Unit tests for Caravel Celery worker""" -import datetime import imp +import json import subprocess import os import pandas as pd -import time import unittest import caravel @@ -116,36 +115,39 @@ class CeleryTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(CeleryTestCase, self).__init__(*args, **kwargs) self.client = app.test_client() - utils.init(caravel) - admin = appbuilder.sm.find_user('admin') - if not admin: - appbuilder.sm.add_user( - 'admin', 'admin', ' user', 'admin@fab.org', - appbuilder.sm.find_role('Admin'), - password='general') - utils.init(caravel) @classmethod def setUpClass(cls): try: os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH')) - except OSError: - pass + except OSError as e: + app.logger.warn(str(e)) try: os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) - except OSError: - pass + except OSError as e: + app.logger.warn(str(e)) + + utils.init(caravel) + admin = appbuilder.sm.find_user('admin') + if not admin: + appbuilder.sm.add_user( + 'admin', 'admin', ' user', 'admin@fab.org', + appbuilder.sm.find_role('Admin'), + password='general') + cli.load_examples(load_test_data=True) worker_command = BASE_DIR + '/bin/caravel worker' subprocess.Popen( worker_command, shell=True, stdout=subprocess.PIPE) - cli.load_examples(load_test_data=True) @classmethod def tearDownClass(cls): + main_db = db.session.query(models.Database).filter_by( + database_name="main").first() + main_db.get_sqla_engine().execute("DELETE FROM query;") + subprocess.call( - "ps auxww | grep 'celeryd' | awk '{print $2}' | " - "xargs kill -9", + "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9", shell=True ) subprocess.call( @@ -160,6 +162,30 @@ def setUp(self): def tearDown(self): pass + def login(self, username='admin', password='general'): + resp = self.client.post( + '/login/', + data=dict(username=username, password=password), + follow_redirects=True) + assert 'Welcome' in resp.data.decode('utf-8') + + def logout(self): + self.client.get('/logout/', follow_redirects=True) + + def run_sql(self, dbid, sql, select_as_cta='False', tmp_table_name='tmp'): + self.login() + resp = self.client.post( + '/caravel/sql_json/', + data=dict( + database_id=dbid, + sql=sql, + select_as_cta=select_as_cta, + tmp_table_name=tmp_table_name, + ), + ) + self.logout() + return json.loads(resp.data.decode('utf-8')) + def test_add_limit_to_the_query(self): query_session = tasks.get_session() db_to_query = query_session.query(models.Database).filter_by( @@ -167,7 +193,7 @@ def test_add_limit_to_the_query(self): eng = db_to_query.get_sqla_engine() select_query = "SELECT * FROM outer_space;" - updated_select_query = tasks.add_limit_to_the_query( + updated_select_query = tasks.add_limit_to_the_sql( select_query, 100, eng) # Different DB engines have their own spacing while compiling # the queries, that's why ' '.join(query.split()) is used. @@ -178,7 +204,7 @@ def test_add_limit_to_the_query(self): ) select_query_no_semicolon = "SELECT * FROM outer_space" - updated_select_query_no_semicolon = tasks.add_limit_to_the_query( + updated_select_query_no_semicolon = tasks.add_limit_to_the_sql( select_query_no_semicolon, 100, eng) self.assertTrue( "SELECT * FROM (SELECT * FROM outer_space) AS inner_qry " @@ -187,19 +213,19 @@ def test_add_limit_to_the_query(self): ) incorrect_query = "SMTH WRONG SELECT * FROM outer_space" - updated_incorrect_query = tasks.add_limit_to_the_query( + updated_incorrect_query = tasks.add_limit_to_the_sql( incorrect_query, 100, eng) self.assertEqual(incorrect_query, updated_incorrect_query) insert_query = "INSERT INTO stomach VALUES (beer, chips);" - updated_insert_query = tasks.add_limit_to_the_query( + updated_insert_query = tasks.add_limit_to_the_sql( insert_query, 100, eng) self.assertEqual(insert_query, updated_insert_query) multi_line_query = ( "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" ) - updated_multi_line_query = tasks.add_limit_to_the_query( + updated_multi_line_query = tasks.add_limit_to_the_sql( multi_line_query, 100, eng) self.assertTrue( "SELECT * FROM (SELECT * FROM planets WHERE " @@ -208,13 +234,13 @@ def test_add_limit_to_the_query(self): ) delete_query = "DELETE FROM planet WHERE name = 'Earth'" - updated_delete_query = tasks.add_limit_to_the_query( + updated_delete_query = tasks.add_limit_to_the_sql( delete_query, 100, eng) self.assertEqual(delete_query, updated_delete_query) create_table_as = ( "CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n") - updated_create_table_as = tasks.add_limit_to_the_query( + updated_create_table_as = tasks.add_limit_to_the_sql( create_table_as, 100, eng) self.assertEqual(create_table_as, updated_create_table_as) @@ -231,7 +257,7 @@ def test_add_limit_to_the_query(self): "(B.TECH ,BE ,Degree ,MCA ,MiBA)\n " "AND Having Brothers= Null AND Sisters = Null" ) - updated_sql_procedure = tasks.add_limit_to_the_query( + updated_sql_procedure = tasks.add_limit_to_the_sql( sql_procedure, 100, eng) self.assertEqual(sql_procedure, updated_sql_procedure) @@ -242,11 +268,15 @@ def test_run_async_query_delay_get(self): # Case 1. # DB #0 doesn't exist. - result1 = tasks.get_sql_results.delay( - 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_1').get() + result1 = self.run_sql( + 0, + 'SELECT * FROM dontexist', + tmp_table_name='tmp_table_1_a', + select_as_cta='True', + ) expected_result1 = { 'error': 'Database with id 0 is missing.', - 'success': False + 'status': models.QueryStatus.FAILED, } self.assertEqual( sorted(expected_result1.items()), @@ -255,18 +285,17 @@ def test_run_async_query_delay_get(self): session1 = db.create_scoped_session() query1 = session1.query(models.Query).filter_by( sql='SELECT * FROM dontexist').first() - session1.close() self.assertIsNone(query1) + session1.close() # Case 2. - session2 = db.create_scoped_session() - query2 = session2.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - session2.close() - - result2 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_1').get() + # Table doesn't exist. + result2 = self.run_sql( + 1, + 'SELECT * FROM dontexist1', + tmp_table_name='tmp_table_2_a', + select_as_cta='True', + ) self.assertTrue('error' in result2) session2 = db.create_scoped_session() query2 = session2.query(models.Query).filter_by( @@ -275,13 +304,20 @@ def test_run_async_query_delay_get(self): session2.close() # Case 3. + # Table and DB exists, CTA call to the backend. where_query = ( "SELECT name FROM ab_permission WHERE name='can_select_star'") - result3 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_3_1').get() + result3 = self.run_sql( + 1, + where_query, + tmp_table_name='tmp_table_3_a', + select_as_cta='True', + ) expected_result3 = { - 'tmp_table': 'tmp_3_1', - 'success': True + u'query_id': 2, + u'status': models.QueryStatus.FINISHED, + u'columns': [u'name'], + u'data': [{u'name': u'can_select_star'}], } self.assertEqual( sorted(expected_result3.items()), @@ -291,18 +327,24 @@ def test_run_async_query_delay_get(self): query3 = session3.query(models.Query).filter_by( sql=where_query).first() session3.close() - df3 = pd.read_sql_query(sql="SELECT * FROM tmp_3_1", con=eng) + df3 = pd.read_sql_query(sql="SELECT * FROM tmp_table_3_a", con=eng) data3 = df3.to_dict(orient='records') self.assertEqual(models.QueryStatus.FINISHED, query3.status) self.assertEqual([{'name': 'can_select_star'}], data3) # Case 4. - result4 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM ab_permission WHERE id=666', 1, - tmp_table_name='tmp_4_1').get() + # Table and DB exists, CTA call to the backend, no data. + result4 = self.run_sql( + 1, + 'SELECT * FROM ab_permission WHERE id=666', + tmp_table_name='tmp_table_4_a', + select_as_cta='True', + ) expected_result4 = { - 'tmp_table': 'tmp_4_1', - 'success': True + u'query_id': 3, + u'status': models.QueryStatus.FINISHED, + u'columns': [u'id', u'name'], + u'data': [], } self.assertEqual( sorted(expected_result4.items()), @@ -312,88 +354,30 @@ def test_run_async_query_delay_get(self): query4 = session4.query(models.Query).filter_by( sql='SELECT * FROM ab_permission WHERE id=666').first() session4.close() - df4 = pd.read_sql_query(sql="SELECT * FROM tmp_4_1", con=eng) + df4 = pd.read_sql_query(sql="SELECT * FROM tmp_table_4_a", con=eng) data4 = df4.to_dict(orient='records') self.assertEqual(models.QueryStatus.FINISHED, query4.status) self.assertEqual([], data4) # Case 5. - # Return the data directly if DB select_as_create_table_as is False. - main_db.select_as_create_table_as = False - db.session.commit() - result5 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_5_1').get() + # Table and DB exists, select without CTA. + result5 = self.run_sql( + 1, + where_query, + tmp_table_name='tmp_table_5_a', + select_as_cta='False', + ) expected_result5 = { - 'columns': ['name'], - 'data': [{'name': 'can_select_star'}], - 'success': True + u'query_id': 4, + u'columns': [u'name'], + u'data': [{u'name': u'can_select_star'}], + u'status': models.QueryStatus.FINISHED, } self.assertEqual( sorted(expected_result5.items()), sorted(result5.items()) ) - def test_run_async_query_delay(self): - celery_task1 = tasks.get_sql_results.delay( - 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_2') - celery_task2 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_2') - where_query = ( - "SELECT name FROM ab_permission WHERE name='can_select_star'") - celery_task3 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_3_2') - celery_task4 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM ab_permission WHERE id=666', 1, - tmp_table_name='tmp_4_2') - - time.sleep(1) - - # DB #0 doesn't exist. - expected_result1 = { - 'error': 'Database with id 0 is missing.', - 'success': False - } - self.assertEqual( - sorted(expected_result1.items()), - sorted(celery_task1.get().items()) - ) - session2 = db.create_scoped_session() - query2 = session2.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - self.assertTrue('error' in celery_task2.get()) - expected_result3 = { - 'tmp_table': 'tmp_3_2', - 'success': True - } - self.assertEqual( - sorted(expected_result3.items()), - sorted(celery_task3.get().items()) - ) - expected_result4 = { - 'tmp_table': 'tmp_4_2', - 'success': True - } - self.assertEqual( - sorted(expected_result4.items()), - sorted(celery_task4.get().items()) - ) - - session = db.create_scoped_session() - query1 = session.query(models.Query).filter_by( - sql='SELECT * FROM dontexist').first() - self.assertIsNone(query1) - query2 = session.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - query3 = session.query(models.Query).filter_by( - sql=where_query).first() - self.assertEqual(models.QueryStatus.FINISHED, query3.status) - query4 = session.query(models.Query).filter_by( - sql='SELECT * FROM ab_permission WHERE id=666').first() - self.assertEqual(models.QueryStatus.FINISHED, query4.status) - session.close() - if __name__ == '__main__': unittest.main() diff --git a/tests/core_tests.py b/tests/core_tests.py index 48d26c16e9606..87bce398b9e7c 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -321,7 +321,7 @@ def run_sql(self, sql, user_name): ) resp = self.client.post( '/caravel/sql_json/', - data=dict(database_id=dbid, sql=sql), + data=dict(database_id=dbid, sql=sql, select_as_create_as=False), ) self.logout() return json.loads(resp.data.decode('utf-8')) @@ -340,9 +340,9 @@ def test_sql_json_has_access(self): db.session.commit() main_db_permission_view = ( db.session.query(ab_models.PermissionView) - .join(ab_models.ViewMenu) - .filter(ab_models.ViewMenu.name == '[main].(id:1)') - .first() + .join(ab_models.ViewMenu) + .filter(ab_models.ViewMenu.name == '[main].(id:1)') + .first() ) astronaut = sm.add_role("Astronaut") sm.add_permission_role(astronaut, main_db_permission_view) @@ -361,6 +361,8 @@ def test_sql_json_has_access(self): def test_sql_json(self): data = self.run_sql("SELECT * FROM ab_user", 'admin') + print("self.run_sql") + print(str(data)) assert len(data['data']) > 0 data = self.run_sql("SELECT * FROM unexistant_table", 'admin')