Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] enhanced memoized on get_sqla_engine and other functions #3530

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,8 @@ def get_effective_user(self, url, user_name=None):
effective_username = g.user.username
return effective_username

@utils.memoized(
watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
Expand Down Expand Up @@ -662,10 +664,10 @@ def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
return create_engine(url, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
return self.get_dialect().preparer.reserved_words

def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote
return self.get_dialect().identifier_preparer.quote

def get_df(self, sql, schema):
sql = sql.strip().strip(';')
Expand Down Expand Up @@ -813,6 +815,7 @@ def has_table(self, table):
return engine.has_table(
table.table_name, table.schema or None)

@utils.memoized
def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()
Expand Down
38 changes: 29 additions & 9 deletions superset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,38 +91,58 @@ def flasher(msg, severity=None):
logging.info(msg)


class memoized(object): # noqa
class _memoized(object): # noqa
"""Decorator that caches a function's return value each time it is called

If called later with the same arguments, the cached value is returned, and
not re-evaluated.

Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
"""

def __init__(self, func):
def __init__(self, func, watch=()):
self.func = func
self.cache = {}

def __call__(self, *args):
self.is_method = False
self.watch = watch

def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key)
if key in self.cache:
return self.cache[key]
try:
return self.cache[args]
except KeyError:
value = self.func(*args)
self.cache[args] = value
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
# Better to not cache than to blow up entirely.
return self.func(*args)
return self.func(*args, **kwargs)

def __repr__(self):
"""Return the function's docstring."""
return self.func.__doc__

def __get__(self, obj, objtype):
if not self.is_method:
self.is_method = True
"""Support instance methods."""
return functools.partial(self.__call__, obj)


def memoized(func=None, watch=None):
if func:
return _memoized(func)
else:
def wrapper(f):
return _memoized(f, watch)
return wrapper


def js_string_to_python(item):
return None if item in ('null', 'undefined') else item

Expand Down
76 changes: 75 additions & 1 deletion tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from superset.utils import (
base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser,
JSONEncodedDict, merge_extra_filters, parse_human_timedelta,
JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta,
SupersetException, validate_json, zlib_compress, zlib_decompress_to_string,
)

Expand Down Expand Up @@ -219,3 +219,77 @@ def test_validate_json(self):
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
with self.assertRaises(SupersetException):
validate_json(invalid)

def test_memoized_on_functions(self):
watcher = {'val': 0}

@memoized
def test_function(a, b, c):
watcher['val'] += 1
return a * b * c
result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(watcher['val'], 1)

def test_memoized_on_methods(self):

class test_class:
def __init__(self, num):
self.num = num
self.watcher = 0

@memoized
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.num

instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
instance.num = 10
self.assertEquals(result2, instance.test_method(1, 2, 3))

def test_memoized_on_methods_with_watches(self):

class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0

@memoized(watch=('x', 'y'))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y

instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
result3 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
result4 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
self.assertEquals(result3, result4)
self.assertNotEqual(result3, result1)
instance.x = 1
result5 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertNotEqual(result5, result4)
result6 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertEqual(result6, result5)
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 4)
self.assertNotEqual(result7, result6)
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
self.assertEqual(instance.watcher, 4)
self.assertEqual(result1, result8)