Skip to content

Commit

Permalink
[Feature] enhanced memoized on get_sqla_engine and other functions (a…
Browse files Browse the repository at this point in the history
…pache#3530)

* added watch to memoized

* added unit tests for memoized

* code style changes
  • Loading branch information
Mogball authored and michellethomas committed May 23, 2018
1 parent b888e5d commit dd54812
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 12 deletions.
7 changes: 5 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,8 @@ def get_effective_user(self, url, user_name=None):
effective_username = g.user.username
return effective_username

@utils.memoized(
watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
Expand Down Expand Up @@ -662,10 +664,10 @@ def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
return create_engine(url, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
return self.get_dialect().preparer.reserved_words

def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote
return self.get_dialect().identifier_preparer.quote

def get_df(self, sql, schema):
sql = sql.strip().strip(';')
Expand Down Expand Up @@ -813,6 +815,7 @@ def has_table(self, table):
return engine.has_table(
table.table_name, table.schema or None)

@utils.memoized
def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()
Expand Down
38 changes: 29 additions & 9 deletions superset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,38 +91,58 @@ def flasher(msg, severity=None):
logging.info(msg)


class memoized(object): # noqa
class _memoized(object): # noqa
"""Decorator that caches a function's return value each time it is called
If called later with the same arguments, the cached value is returned, and
not re-evaluated.
Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
"""

def __init__(self, func):
def __init__(self, func, watch=()):
self.func = func
self.cache = {}

def __call__(self, *args):
self.is_method = False
self.watch = watch

def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key)
if key in self.cache:
return self.cache[key]
try:
return self.cache[args]
except KeyError:
value = self.func(*args)
self.cache[args] = value
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
# Better to not cache than to blow up entirely.
return self.func(*args)
return self.func(*args, **kwargs)

def __repr__(self):
"""Return the function's docstring."""
return self.func.__doc__

def __get__(self, obj, objtype):
if not self.is_method:
self.is_method = True
"""Support instance methods."""
return functools.partial(self.__call__, obj)


def memoized(func=None, watch=None):
if func:
return _memoized(func)
else:
def wrapper(f):
return _memoized(f, watch)
return wrapper


def js_string_to_python(item):
return None if item in ('null', 'undefined') else item

Expand Down
76 changes: 75 additions & 1 deletion tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from superset.utils import (
base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser,
JSONEncodedDict, merge_extra_filters, parse_human_timedelta,
JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta,
SupersetException, validate_json, zlib_compress, zlib_decompress_to_string,
)

Expand Down Expand Up @@ -219,3 +219,77 @@ def test_validate_json(self):
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
with self.assertRaises(SupersetException):
validate_json(invalid)

def test_memoized_on_functions(self):
watcher = {'val': 0}

@memoized
def test_function(a, b, c):
watcher['val'] += 1
return a * b * c
result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(watcher['val'], 1)

def test_memoized_on_methods(self):

class test_class:
def __init__(self, num):
self.num = num
self.watcher = 0

@memoized
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.num

instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
instance.num = 10
self.assertEquals(result2, instance.test_method(1, 2, 3))

def test_memoized_on_methods_with_watches(self):

class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0

@memoized(watch=('x', 'y'))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y

instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
result3 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
result4 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
self.assertEquals(result3, result4)
self.assertNotEqual(result3, result1)
instance.x = 1
result5 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertNotEqual(result5, result4)
result6 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertEqual(result6, result5)
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 4)
self.assertNotEqual(result7, result6)
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
self.assertEqual(instance.watcher, 4)
self.assertEqual(result1, result8)

0 comments on commit dd54812

Please sign in to comment.