
Commit

update test
Lily Kuang committed Apr 9, 2021
1 parent 6ac4742 commit 009c43a
Showing 9 changed files with 124 additions and 25 deletions. The production changes downgrade Celery soft-timeout logging from error to warning; the test changes replace the deprecated pytest.yield_fixture decorator with pytest.fixture and add coverage for SoftTimeLimitExceeded handling in report-log pruning, SQL Lab queries, and async chart/explore loading.
2 changes: 1 addition & 1 deletion superset/reports/commands/alert.py
@@ -156,7 +156,7 @@ def _execute_query(self) -> pd.DataFrame:
)
return df
except SoftTimeLimitExceeded as ex:
logger.error("A timeout occurred while executing the alert query: %s", ex)
logger.warning("A timeout occurred while executing the alert query: %s", ex)
raise AlertQueryTimeout()
except Exception as ex:
raise AlertQueryError(message=str(ex))
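
The same change repeats across the production files below (execute.py, sql_lab.py, async_queries.py, scheduler.py): Celery soft-timeout exceptions are now logged at warning level rather than error before being re-raised as domain-specific errors. A minimal sketch of the pattern, assuming a Celery app with a broker configured; the task, query, and exception names are hypothetical, not Superset's:

# Minimal sketch of the shared pattern: Celery raises SoftTimeLimitExceeded
# inside the task body once the soft time limit elapses; the handler logs a
# warning and re-raises a domain-specific error for the caller.
import logging
import time

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

logger = logging.getLogger(__name__)
app = Celery(__name__)


class QueryTimeoutError(Exception):
    """Hypothetical domain error used in place of e.g. AlertQueryTimeout."""


@app.task(soft_time_limit=30)
def run_query(sql: str) -> None:
    try:
        time.sleep(60)  # stand-in for a long-running query
    except SoftTimeLimitExceeded as ex:
        # A soft timeout is an expected operational condition rather than a
        # code defect, hence warning rather than error.
        logger.warning("A timeout occurred while executing the query: %s", ex)
        raise QueryTimeoutError() from ex
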
2 changes: 1 addition & 1 deletion superset/reports/commands/execute.py
@@ -181,7 +181,7 @@ def _get_screenshot(self) -> bytes:
try:
image_data = screenshot.get_screenshot(user=user)
except SoftTimeLimitExceeded:
logger.error("A timeout occurred while taking a screenshot.")
logger.warning("A timeout occurred while taking a screenshot.")
raise ReportScheduleScreenshotTimeout()
except Exception as ex:
raise ReportScheduleScreenshotFailedError(
2 changes: 1 addition & 1 deletion superset/sql_lab.py
@@ -160,7 +160,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
log_params=log_params,
)
except SoftTimeLimitExceeded as ex:
logger.error("Query %d: Time limit exceeded", query_id)
logger.warning("Query %d: Time limit exceeded", query_id)
logger.debug("Query %d: %s", query_id, ex)
raise SqlLabTimeoutException(
_(
8 changes: 6 additions & 2 deletions superset/tasks/async_queries.py
@@ -62,7 +62,9 @@ def load_chart_data_into_cache(
job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
)
except SoftTimeLimitExceeded as exc:
logger.error("A timeout occurred while loading chart data, error: %s", exc)
logger.warning(
"A timeout occurred while loading chart data, error: %s", exc
)
raise exc
except Exception as exc:
# TODO: QueryContext should support SIP-40 style errors
@@ -109,7 +111,9 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.error("A timeout occurred while loading explore json, error: %s", ex)
logger.warning(
"A timeout occurred while loading explore json, error: %s", ex
)
raise ex
except Exception as exc:
if isinstance(exc, SupersetVizException):
2 changes: 1 addition & 1 deletion superset/tasks/scheduler.py
@@ -93,6 +93,6 @@ def prune_log() -> None:
try:
AsyncPruneReportScheduleLogCommand().run()
except SoftTimeLimitExceeded as ex:
logger.error("A timeout occurred while pruning report schedule logs: %s", ex)
logger.warning("A timeout occurred while pruning report schedule logs: %s", ex)
except CommandException as ex:
logger.error("An exception occurred while pruning report schedule logs: %s", ex)
28 changes: 22 additions & 6 deletions tests/reports/commands_tests.py
@@ -47,11 +47,13 @@
ReportScheduleNotFoundError,
ReportScheduleNotificationError,
ReportSchedulePreviousWorkingError,
ReportSchedulePruneLogError,
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
ReportScheduleWorkingTimeoutError,
)
from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
from superset.utils.core import get_example_database
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from tests.fixtures.world_bank_dashboard import (
@@ -186,7 +188,7 @@ def create_test_table_context(database: Database):
database.get_sqla_engine().execute("DROP TABLE test_table")


-@pytest.yield_fixture()
+@pytest.fixture()
def create_report_email_chart():
with app.app_context():
chart = db.session.query(Slice).first()
@@ -198,7 +200,7 @@ def create_report_email_chart():
cleanup_report_schedule(report_schedule)


-@pytest.yield_fixture()
+@pytest.fixture()
def create_report_email_dashboard():
with app.app_context():
dashboard = db.session.query(Dashboard).first()
@@ -210,7 +212,7 @@ def create_report_email_dashboard():
cleanup_report_schedule(report_schedule)


-@pytest.yield_fixture()
+@pytest.fixture()
def create_report_slack_chart():
with app.app_context():
chart = db.session.query(Slice).first()
@@ -222,7 +224,7 @@ def create_report_slack_chart():
cleanup_report_schedule(report_schedule)


-@pytest.yield_fixture()
+@pytest.fixture()
def create_report_slack_chart_working():
with app.app_context():
chart = db.session.query(Slice).first()
@@ -248,7 +250,7 @@ def create_report_slack_chart_working():
cleanup_report_schedule(report_schedule)


-@pytest.yield_fixture()
+@pytest.fixture()
def create_alert_slack_chart_success():
with app.app_context():
chart = db.session.query(Slice).first()
@@ -274,7 +276,7 @@ def create_alert_slack_chart_success():
cleanup_report_schedule(report_schedule)


-@pytest.yield_fixture()
+@pytest.fixture()
def create_alert_slack_chart_grace():
with app.app_context():
chart = db.session.query(Slice).first()
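
The fixture hunks above all make the same substitution: the deprecated pytest.yield_fixture decorator is replaced with plain pytest.fixture, which has supported yield-based setup and teardown since pytest 3.0. A minimal sketch of such a fixture, with a hypothetical resource:

# Minimal sketch: a plain pytest.fixture supports yield-based setup/teardown,
# so the separate yield_fixture decorator is unnecessary (and deprecated).
import pytest


@pytest.fixture()
def resource():
    handle = {"connected": True}  # hypothetical setup
    yield handle
    handle["connected"] = False  # teardown runs after the test finishes
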
@@ -1109,3 +1111,17 @@ def test_grace_period_error_flap(
assert (
get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 2
)


@pytest.mark.usefixtures(
"load_birth_names_dashboard_with_slices", "create_report_email_dashboard"
)
@patch("superset.reports.dao.ReportScheduleDAO.bulk_delete_logs")
def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard):
from celery.exceptions import SoftTimeLimitExceeded
from datetime import datetime, timedelta

bulk_delete_logs.side_effect = SoftTimeLimitExceeded()
with pytest.raises(SoftTimeLimitExceeded) as excinfo:
AsyncPruneReportScheduleLogCommand().run()
assert str(excinfo.value) == "SoftTimeLimitExceeded()"
30 changes: 29 additions & 1 deletion tests/sqllab_tests.py
@@ -33,7 +33,12 @@
from superset.models.core import Database
from superset.models.sql_lab import Query, SavedQuery
from superset.result_set import SupersetResultSet
-from superset.sql_lab import execute_sql_statements, SqlLabException
+from superset.sql_lab import (
+    execute_sql_statements,
+    get_sql_results,
+    SqlLabException,
+    SqlLabTimeoutException,
+)
from superset.sql_parse import CtasMethod
from superset.utils.core import (
datetime_to_epoch,
@@ -793,3 +798,26 @@ def test_execute_sql_statements_ctas(
"sure your query has only a SELECT statement. Then, "
"try running your query again."
)

@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_get_sql_results_soft_time_limit(
self, mock_execute_sql_statement, mock_get_query
):
from celery.exceptions import SoftTimeLimitExceeded

sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_get_query.side_effect = SoftTimeLimitExceeded()
with pytest.raises(SqlLabTimeoutException) as excinfo:
get_sql_results(
1, sql, return_results=True, store_results=False,
)
assert (
str(excinfo.value)
== "SQL Lab timeout. This environment's policy is to kill queries after 21600 seconds."
)
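
The 21600 seconds in the expected message is consistent with a six-hour limit, which appears to be Superset's default SQLLAB_ASYNC_TIME_LIMIT_SEC; treating that attribution as an assumption, the arithmetic checks out:

from datetime import timedelta

# Assumption: the test environment uses the default six-hour async SQL Lab
# time limit; six hours is 21600 seconds, matching the asserted message.
assert int(timedelta(hours=6).total_seconds()) == 21600
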
51 changes: 51 additions & 0 deletions tests/tasks/async_queries_tests.py
@@ -20,6 +20,7 @@
from uuid import uuid4

import pytest
from celery.exceptions import SoftTimeLimitExceeded

from superset import db
from superset.charts.commands.data import ChartDataCommand
@@ -94,6 +95,31 @@ def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_comman
errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_chart_data_into_cache(
self, mock_update_job, mock_run_command
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": user.id,
"status": "pending",
"errors": [],
}
errors = ["A timeout occurred while loading chart data"]

with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries, "ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
load_chart_data_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
@@ -151,3 +177,28 @@ def test_load_explore_json_into_cache_error(self, mock_update_job):

errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_explore_json_into_cache(
self, mock_update_job, mock_run_command
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": user.id,
"status": "pending",
"errors": [],
}
errors = ["A timeout occurred while loading explore json, error"]

with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries, "ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
24 changes: 12 additions & 12 deletions tests/thumbnails_tests.py
@@ -48,7 +48,7 @@ def url_open_auth(self, username: str, url: str):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_screenshot(self):
"""
-Thumbnails: Simple get async dashboard screenshot
+Thumbnails: Simple get async dashboard screenshot
"""
dashboard = db.session.query(Dashboard).all()[0]
with patch("superset.dashboards.api.DashboardRestApi.get") as mock_get:
@@ -65,7 +65,7 @@ class TestThumbnails(SupersetTestCase):

def test_dashboard_thumbnail_disabled(self):
"""
-Thumbnails: Dashboard thumbnail disabled
+Thumbnails: Dashboard thumbnail disabled
"""
if is_feature_enabled("THUMBNAILS"):
return
@@ -77,7 +77,7 @@ def test_dashboard_thumbnail_disabled(self):

def test_chart_thumbnail_disabled(self):
"""
-Thumbnails: Chart thumbnail disabled
+Thumbnails: Chart thumbnail disabled
"""
if is_feature_enabled("THUMBNAILS"):
return
@@ -90,7 +90,7 @@ def test_chart_thumbnail_disabled(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_screenshot(self):
"""
-Thumbnails: Simple get async dashboard screenshot
+Thumbnails: Simple get async dashboard screenshot
"""
dashboard = db.session.query(Dashboard).all()[0]
self.login(username="admin")
@@ -105,7 +105,7 @@ def test_get_async_dashboard_screenshot(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_notfound(self):
"""
-Thumbnails: Simple get async dashboard not found
+Thumbnails: Simple get async dashboard not found
"""
max_id = db.session.query(func.max(Dashboard.id)).scalar()
self.login(username="admin")
@@ -116,7 +116,7 @@ def test_get_async_dashboard_notfound(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_not_allowed(self):
"""
-Thumbnails: Simple get async dashboard not allowed
+Thumbnails: Simple get async dashboard not allowed
"""
dashboard = db.session.query(Dashboard).all()[0]
self.login(username="gamma")
@@ -127,7 +127,7 @@ def test_get_async_dashboard_not_allowed(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_chart_screenshot(self):
"""
-Thumbnails: Simple get async chart screenshot
+Thumbnails: Simple get async chart screenshot
"""
chart = db.session.query(Slice).all()[0]
self.login(username="admin")
@@ -142,7 +142,7 @@ def test_get_async_chart_screenshot(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_chart_notfound(self):
"""
-Thumbnails: Simple get async chart not found
+Thumbnails: Simple get async chart not found
"""
max_id = db.session.query(func.max(Slice.id)).scalar()
self.login(username="admin")
@@ -153,7 +153,7 @@ def test_get_async_chart_notfound(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_chart_wrong_digest(self):
"""
-Thumbnails: Simple get chart with wrong digest
+Thumbnails: Simple get chart with wrong digest
"""
chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@@ -169,7 +169,7 @@ def test_get_cached_chart_wrong_digest(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_dashboard_screenshot(self):
"""
-Thumbnails: Simple get cached dashboard screenshot
+Thumbnails: Simple get cached dashboard screenshot
"""
dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
@@ -185,7 +185,7 @@ def test_get_cached_dashboard_screenshot(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_chart_screenshot(self):
"""
-Thumbnails: Simple get cached chart screenshot
+Thumbnails: Simple get cached chart screenshot
"""
chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@@ -201,7 +201,7 @@ def test_get_cached_chart_screenshot(self):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_dashboard_wrong_digest(self):
"""
-Thumbnails: Simple get dashboard with wrong digest
+Thumbnails: Simple get dashboard with wrong digest
"""
dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
