Adding app context wrapper to Celery tasks #8653

Merged
merged 3 commits on Nov 27, 2019
16 changes: 16 additions & 0 deletions superset/app.py
@@ -96,6 +96,22 @@ def post_init(self) -> None:
     def configure_celery(self) -> None:
         celery_app.config_from_object(self.config["CELERY_CONFIG"])
         celery_app.set_default()
+        flask_app = self.flask_app
+
+        # Here, we want to ensure that every call into a Celery task has an
+        # app context set up properly
+        task_base = celery_app.Task
+
+        class AppContextTask(task_base):  # type: ignore
+            # pylint: disable=too-few-public-methods
+            abstract = True
+
+            # Wrap each call into the task in an app context
+            def __call__(self, *args, **kwargs):
+                with flask_app.app_context():
+                    return task_base.__call__(self, *args, **kwargs)
+
+        celery_app.Task = AppContextTask
 
     @staticmethod
     def init_views() -> None:
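With AppContextTask installed as the base class for every Celery task, each task invocation now runs inside a Flask app context. A minimal sketch of what that buys a task author (the read_timeout task below is hypothetical, for illustration only, and not part of this PR):

from flask import current_app

from superset.extensions import celery_app


@celery_app.task()
def read_timeout():  # hypothetical task, for illustration only
    # AppContextTask.__call__ wraps this body in flask_app.app_context(),
    # so context-bound proxies such as current_app resolve directly,
    # without an explicit `with app.app_context():` inside the task.
    return current_app.config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]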
37 changes: 18 additions & 19 deletions superset/sql_lab.py
@@ -54,6 +54,7 @@
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 log_query = config["QUERY_LOGGER"]
+logger = logging.getLogger(__name__)
 
 
 class SqlLabException(Exception):
@@ -84,9 +85,9 @@ def handle_query_error(msg, query, session, payload=None):
 
 def get_query_backoff_handler(details):
     query_id = details["kwargs"]["query_id"]
-    logging.error(f"Query with id `{query_id}` could not be retrieved")
+    logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
-    logging.error(f"Query {query_id}: Sleeping for a sec before retrying...")
+    logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
 
 
 def get_query_giveup_handler(details):
@@ -128,7 +129,7 @@ def session_scope(nullpool):
         session.commit()
     except Exception as e:
         session.rollback()
-        logging.exception(e)
+        logger.exception(e)
         raise
     finally:
         session.close()
@@ -166,7 +167,7 @@ def get_sql_results(
             expand_data=expand_data,
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         stats_logger.incr("error_sqllab_unhandled")
         query = get_query(query_id, session)
         return handle_query_error(str(e), query, session)
@@ -224,38 +225,38 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         query.executed_sql = sql
         session.commit()
         with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            logging.info(f"Query {query_id}: Running query: \n{sql}")
+            logger.info(f"Query {query_id}: Running query: \n{sql}")
             db_engine_spec.execute(cursor, sql, async_=True)
-            logging.info(f"Query {query_id}: Handling cursor")
+            logger.info(f"Query {query_id}: Handling cursor")
             db_engine_spec.handle_cursor(cursor, query, session)
 
         with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logging.debug(
+            logger.debug(
                 "Query {}: Fetching data for query object: {}".format(
                     query_id, query.to_dict()
                 )
             )
             data = db_engine_spec.fetch_data(cursor, query.limit)
 
     except SoftTimeLimitExceeded as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             "after {} seconds.".format(SQLLAB_TIMEOUT)
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabException(db_engine_spec.extract_error_message(e))
 
-    logging.debug(f"Query {query_id}: Fetching cursor description")
+    logger.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
     return SupersetDataFrame(data, cursor_description, db_engine_spec)
 
 
 def _serialize_payload(
     payload: dict, use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
-    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
         return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
     else:
@@ -324,9 +325,9 @@ def execute_sql_statements(
     # Breaking down into multiple statements
     parsed_query = ParsedQuery(rendered_query)
     statements = parsed_query.get_statements()
-    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
+    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
 
-    logging.info(f"Query {query_id}: Set query to 'running'")
+    logger.info(f"Query {query_id}: Set query to 'running'")
     query.status = QueryStatus.RUNNING
     query.start_running_time = now_as_float()
     session.commit()
@@ -350,7 +351,7 @@ def execute_sql_statements(
 
         # Run statement
        msg = f"Running statement {i+1} out of {statement_count}"
-        logging.info(f"Query {query_id}: {msg}")
+        logger.info(f"Query {query_id}: {msg}")
         query.set_extra_json_key("progress", msg)
         session.commit()
         try:
@@ -396,9 +397,7 @@ def execute_sql_statements(
 
     if store_results and results_backend:
         key = str(uuid.uuid4())
-        logging.info(
-            f"Query {query_id}: Storing results in results backend, key: {key}"
-        )
+        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
         with stats_timing("sqllab.query.results_backend_write", stats_logger):
             with stats_timing(
                 "sqllab.query.results_backend_write_serialization", stats_logger
@@ -411,10 +410,10 @@ def execute_sql_statements(
                 cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]
 
             compressed = zlib_compress(serialized_payload)
-            logging.debug(
+            logger.debug(
                 f"*** serialized payload size: {getsizeof(serialized_payload)}"
             )
-            logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+            logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
             results_backend.set(key, compressed, cache_timeout)
             query.results_key = key
 
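A side note on the logging change in this file: logging.getLogger(__name__) names the logger superset.sql_lab, so records no longer go through the root logger and deployments can tune SQL Lab's verbosity independently of the rest of the package. A minimal sketch of that kind of targeted configuration (standard library only; the level choices are illustrative assumptions):

import logging

# Keep the rest of the superset package at WARNING, but surface SQL Lab's
# per-query progress and payload-size messages logged via logger.debug().
logging.getLogger("superset").setLevel(logging.WARNING)
logging.getLogger("superset.sql_lab").setLevel(logging.DEBUG)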
23 changes: 22 additions & 1 deletion tests/celery_tests.py
@@ -23,10 +23,14 @@
 import unittest
 import unittest.mock as mock
 
-from tests.test_app import app  # isort:skip
+import flask
+from flask import current_app
+
+from tests.test_app import app
 from superset import db, sql_lab
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.extensions import celery_app
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -69,6 +73,23 @@ def test_create_table_as(self):
         )
 
 
+class AppContextTests(SupersetTestCase):
+    def test_in_app_context(self):
+        @celery_app.task()
+        def my_task():
+            self.assertTrue(current_app)
+
+        # Make sure we can call tasks with an app already set up
+        my_task()
+
+        # Make sure the app gets pushed onto the stack properly
+        try:
+            popped_app = flask._app_ctx_stack.pop()
+            my_task()
+        finally:
+            flask._app_ctx_stack.push(popped_app)
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session
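flask._app_ctx_stack, used in the test above to simulate a worker with no active context, is a private Flask API. A sketch of an equivalent check using only public surface (probe_task is hypothetical; it assumes the same app fixture from tests.test_app):

from flask import current_app

from tests.test_app import app
from superset.extensions import celery_app


@celery_app.task()
def probe_task():  # hypothetical task, for illustration only
    # AppContextTask pushes its own context, so current_app must resolve
    # whether or not the caller already holds one.
    assert current_app


with app.app_context():
    probe_task()  # a context is already active on the stack

probe_task()  # no ambient context at all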