Skip to content

Commit

Permalink
Add more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 31, 2022
1 parent 2006811 commit b83a64f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
4 changes: 2 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"

def _get_sqla_row_level_filters(
def get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[TextClause]:
"""
Expand Down Expand Up @@ -1394,7 +1394,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
_("Invalid filter operation type: %(op)s", op=op)
)
if is_feature_enabled("ROW_LEVEL_SECURITY"):
where_clause_and += self._get_sqla_row_level_filters(template_processor)
where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
Expand Down
8 changes: 4 additions & 4 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get_statements(self) -> List[str]:
return statements

@staticmethod
def _get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
Expand Down Expand Up @@ -325,7 +325,7 @@ def _process_tokenlist(self, token_list: TokenList) -> None:
"""
# exclude subselects
if "(" not in str(token_list):
table = self._get_table(token_list)
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
Expand Down Expand Up @@ -555,7 +555,7 @@ def get_rls_for_table(
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

table = ParsedQuery._get_table(candidate) # pylint: disable=protected-access
table = ParsedQuery.get_table(candidate)
if not table:
return None

Expand All @@ -577,7 +577,7 @@ def get_rls_for_table(
# pylint: disable=protected-access
predicate = " AND ".join(
str(filter_)
for filter_ in dataset._get_sqla_row_level_filters(template_processor)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
)
if not predicate:
return None
Expand Down
36 changes: 33 additions & 3 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, too-many-lines
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines

import unittest
from typing import Optional, Set

import pytest
import sqlparse
from pytest_mock import MockerFixture
from sqlparse.sql import Token, TokenList
from sqlalchemy import text
from sqlparse.sql import Identifier, Token, TokenList
from sqlparse.tokens import Name

from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
get_rls_for_table,
has_table_query,
insert_rls,
ParsedQuery,
Expand Down Expand Up @@ -1438,3 +1440,31 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
condition = sqlparse.parse(rls)[0]
add_table_name(condition, table)
assert str(condition) == expected


def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
"""
Tests for ``get_rls_for_table``.
"""
candidate = Identifier([Token(Name, "some_table")])
db = mocker.patch("superset.db")
dataset = db.session.query().filter().one_or_none()
dataset.__str__.return_value = "some_table"

dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1"
)

dataset.get_sqla_row_level_filters.return_value = [
text("organization_id = 1"),
text("foo = 'bar'"),
]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1 AND some_table.foo = 'bar'"
)

dataset.get_sqla_row_level_filters.return_value = []
assert get_rls_for_table(candidate, 1, "public") is None

0 comments on commit b83a64f

Please sign in to comment.