Skip to content

Commit

Permalink
Merge pull request taoyds#3 from ElementAI/dima-improve-parsing
Browse files Browse the repository at this point in the history
support "distinct(bla)" and "FROM foo, bar"
  • Loading branch information
rizar authored Jan 4, 2021
2 parents 4f9c20f + 2623570 commit a9b7211
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 28 deletions.
99 changes: 71 additions & 28 deletions process_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@
)
JOIN_KEYWORDS = ("join", "on", "as")

WHERE_OPS = (
"not",
"between",
"=",
">",
"<",
">=",
"<=",
"!=",
"in",
"like",
"is",
"exists",
)
WHERE_OPS = {
"not": 0,
"between": 1,
"=": 2,
">": 3,
"<": 4,
">=": 5,
"<=": 6,
"!=": 7,
"in": 8,
"like": 9,
"is": 10,
"exists": 11,
"<>": 7
}
UNIT_OPS = ("none", "-", "+", "*", "/")
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
TABLE_TYPE = {
Expand All @@ -67,6 +68,19 @@
ORDER_OPS = ("desc", "asc")


class DerivedFieldAliasError(ValueError):
pass

class DerivedTableAliasError(ValueError):
pass

class ParenthesesInConditionError(ValueError):
pass

class ValueListError(ValueError):
pass


class Schema:
"""
Simple schema which maps table&column to a unique identifier
Expand Down Expand Up @@ -162,14 +176,14 @@ def tokenize(string):
if toks[i] in vals:
toks[i] = vals[toks[i]]

# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
# find if there exists !=, >=, <=, <>`
eq_idxs = [idx for idx, tok in enumerate(toks) if tok in ("=", ">")]
eq_idxs.reverse()
prefix = ("!", ">", "<")
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx - 1]
if pre_tok in prefix:
toks = toks[: eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1 :]
toks = toks[: eq_idx - 1] + [pre_tok + toks[eq_idx]] + toks[eq_idx + 1 :]

return toks

Expand All @@ -195,14 +209,23 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
idx = start_idx
tok = toks[idx]
in_parentheses = False
col_id = None

if tok == '(':
in_parentheses = True
idx += 1
tok = toks[idx]

if tok in ['1', '*']:
col_id = schema.idMap['*']

if "." in tok: # if token is a composite
alias, col = tok.split(".")
key = tables_with_alias[alias] + "." + col
return start_idx + 1, schema.idMap[key]
col_id = schema.idMap[key]

assert (
default_tables is not None and len(default_tables) > 0
Expand All @@ -212,9 +235,20 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
return start_idx + 1, schema.idMap[key]
col_id = schema.idMap[key]

if col_id is None:
if tok == 'as':
raise DerivedFieldAliasError(toks[idx + 1])
else:
assert False, "Error col: {}".format(tok)

if in_parentheses:
assert toks[idx + 1] == ')'
return idx + 2, col_id
else:
return idx + 1, col_id

assert False, "Error col: {}".format(tok)


def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
Expand Down Expand Up @@ -336,6 +370,8 @@ def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None)
idx = end_idx

if isBlock:
if toks[idx] == ',':
raise ValueListError()
assert toks[idx] == ")"
idx += 1

Expand All @@ -348,6 +384,9 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N
conds = []

while idx < len_:
if toks[idx] == '(':
raise ParenthesesInConditionError()

idx, val_unit = parse_val_unit(
toks, idx, tables_with_alias, schema, default_tables
)
Expand All @@ -359,12 +398,11 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N
assert (
idx < len_ and toks[idx] in WHERE_OPS
), "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
op_id = WHERE_OPS[toks[idx]]
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index(
"between"
): # between..and... special case: dual values
if op_id == WHERE_OPS['between']:
# between..and... special case: dual values
idx, val1 = parse_value(
toks, idx, tables_with_alias, schema, default_tables
)
Expand Down Expand Up @@ -440,12 +478,17 @@ def parse_from(toks, start_idx, tables_with_alias, schema):
isBlock = True
idx += 1

if toks[idx] == 'as':
raise DerivedTableAliasError()

if toks[idx] == "select":
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE["sql"], sql))
else:
if idx < len_ and toks[idx] == "join":
if idx < len_ and toks[idx] in [",", "join"]:
idx += 1 # skip join
elif idx + 1 < len_ and toks[idx:idx + 2] == ["inner", "join"]:
idx += 2 # skip join
idx, table_unit, table_name = parse_table_unit(
toks, idx, tables_with_alias, schema
)
Expand Down
Binary file not shown.
Binary file added test/db.sqlite
Binary file not shown.
38 changes: 38 additions & 0 deletions test/test_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from process_sql import get_schema, Schema, get_sql


def test_schema():
return Schema(get_schema('test/db.sqlite'))


def test_parse_col():
ground_truth = (False, [(3, (0, (0, '__papers.id__', True), None))])
assert get_sql(test_schema(),
'SELECT COUNT(DISTINCT(papers.id)) FROM papers')['select'] == ground_truth
assert get_sql(test_schema(),
'SELECT COUNT(DISTINCT papers.id) FROM papers')['select'] == ground_truth

ground_truth = (True, [(0, (0, (0, '__papers.id__', False), None))])
assert get_sql(test_schema(),
'SELECT DISTINCT(papers.id) FROM papers')['select'] == ground_truth
assert get_sql(test_schema(),
'SELECT DISTINCT papers.id FROM papers')['select'] == ground_truth


def test_joins():
ground_truth = {'conds': [],
'table_units': [('table_unit', '__papers__'), ('table_unit', '__coauthored__')]}
assert get_sql(test_schema(),
'SELECT * FROM papers JOIN coauthored')['from'] == ground_truth
assert get_sql(test_schema(),
'SELECT * FROM papers INNER JOIN coauthored')['from'] == ground_truth
assert get_sql(test_schema(),
'SELECT * FROM papers, coauthored')['from'] == ground_truth


def test_different_not_equal_operators():
ground_truth = [(False, 7, (0, (0, '__papers.title__', False), None), '"bar"', None)]
assert get_sql(test_schema(),
'SELECT * FROM papers WHERE papers.title <> "bar"')['where'] == ground_truth
assert get_sql(test_schema(),
'SELECT * FROM papers WHERE papers.title != "bar"')['where'] == ground_truth

0 comments on commit a9b7211

Please sign in to comment.