Skip to content

Commit

Permalink
DI-1113. Authentication: Enable user impersonation for Superset to Hi…
Browse files Browse the repository at this point in the history
…veServer2 using hive.server2.proxy.user (a.fernandez) (#3652)
  • Loading branch information
afernandez authored and mistercrunch committed Oct 17, 2017
1 parent 08f09b4 commit adef519
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 23 deletions.
54 changes: 54 additions & 0 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from sqlalchemy.sql import text
from flask_babel import lazy_gettext as _

from sqlalchemy.engine.url import make_url

from superset.utils import SupersetTemplateException
from superset.utils import QueryStatus
from superset import conf, cache_util, utils
Expand Down Expand Up @@ -184,6 +186,28 @@ def select_star(cls, my_db, table_name, schema=None, limit=100,
sql = sqlparse.format(sql, reindent=True)
return sql

@classmethod
def modify_url_for_impersonation(cls, url, impersonate_user, username):
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
:param url: SQLAlchemy URL object
:param impersonate_user: Bool indicating if impersonation is enabled
:param username: Effective username
"""
if impersonate_user is not None and username is not None:
url.username = username

@classmethod
def get_uri_for_impersonation(cls, uri, impersonate_user, username):
"""
Return a new URI string that allows for user impersonation.
:param uri: URI string
:param impersonate_user: Bool indicating if impersonation is enabled
:param username: Effective username
:return: New URI string
"""
return uri


class PostgresEngineSpec(BaseEngineSpec):
engine = 'postgresql'
Expand Down Expand Up @@ -677,6 +701,7 @@ def patch(cls):
hive.constants = patched_constants
hive.ttypes = patched_ttypes
hive.Cursor.fetch_logs = patched_hive.fetch_logs
hive.Connection = patched_hive.ConnectionProxyUser

@classmethod
@cache_util.memoized_func(
Expand Down Expand Up @@ -830,6 +855,35 @@ def _partition_query(
cls, table_name, limit=0, order_by=None, filters=None):
return "SHOW PARTITIONS {table_name}".format(**locals())

@classmethod
def modify_url_for_impersonation(cls, url, impersonate_user, username):
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
:param url: SQLAlchemy URL object
:param impersonate_user: Bool indicating if impersonation is enabled
:param username: Effective username
"""
if impersonate_user is not None and "auth" in url.query.keys() and username is not None:
url.query["hive_server2_proxy_user"] = username

@classmethod
def get_uri_for_impersonation(cls, uri, impersonate_user, username):
"""
Return a new URI string that allows for user impersonation.
:param uri: URI string
:param impersonate_user: Bool indicating if impersonation is enabled
:param username: Effective username
:return: New URI string
"""
new_uri = uri
url = make_url(uri)
backend_name = url.get_backend_name()

# Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS
if backend_name == "hive" and "auth" in url.query.keys() and\
impersonate_user is True and username is not None:
new_uri += "&hive_server2_proxy_user={0}".format(username)
return new_uri

class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
Expand Down
22 changes: 22 additions & 0 deletions superset/db_engines/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@
from thrift import Thrift


old_Connection = hive.Connection

# TODO
# Monkey-patch of PyHive project's pyhive/hive.py which needed to change the constructor.
# Submitted a pull request on October 13, 2017 and waiting for it to be merged.
# https://github.com/dropbox/PyHive/pull/165
class ConnectionProxyUser(hive.Connection):

def __init__(self, host=None, port=None, username=None, database='default', auth=None,
configuration=None, kerberos_service_name=None, password=None,
thrift_transport=None, hive_server2_proxy_user=None):
configuration = configuration or {}
if auth is not None and auth in ('LDAP', 'KERBEROS'):
if hive_server2_proxy_user is not None:
configuration["hive.server2.proxy.user"] = hive_server2_proxy_user
# restore the old connection class, otherwise, will recurse on its own __init__ method
hive.Connection = old_Connection
hive.Connection.__init__(self, host=host, port=port, username=username, database=database, auth=auth,
configuration=configuration, kerberos_service_name=kerberos_service_name, password=password,
thrift_transport=thrift_transport)


# TODO: contribute back to pyhive.
def fetch_logs(self, max_rows=1024,
orientation=ttypes.TFetchOrientation.FETCH_NEXT):
Expand Down
58 changes: 45 additions & 13 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from future.standard_library import install_aliases
from copy import copy
from datetime import datetime, date
from copy import deepcopy

import pandas as pd
import sqlalchemy as sqla
Expand Down Expand Up @@ -47,6 +48,7 @@
stats_logger = config.get('STATS_LOGGER')
metadata = Model.metadata # pylint: disable=no-member

PASSWORD_MASK = "X" * 10

def set_related_perm(mapper, connection, target): # noqa
src_class = target.cls_model
Expand Down Expand Up @@ -581,30 +583,56 @@ def backend(self):
url = make_url(self.sqlalchemy_uri_decrypted)
return url.get_backend_name()

@classmethod
def get_password_masked_url_from_uri(cls, uri):
url = make_url(uri)
return cls.get_password_masked_url(url)

@classmethod
def get_password_masked_url(cls, url):
url_copy = deepcopy(url)
if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
url_copy.password = PASSWORD_MASK
return url_copy

def set_sqlalchemy_uri(self, uri):
password_mask = "X" * 10
conn = sqla.engine.url.make_url(uri)
if conn.password != password_mask and not self.custom_password_store:
if conn.password != PASSWORD_MASK and not self.custom_password_store:
# do not over-write the password with the password mask
self.password = conn.password
conn.password = password_mask if conn.password else None
conn.password = PASSWORD_MASK if conn.password else None
self.sqlalchemy_uri = str(conn) # hides the password

def get_effective_user(self, url, user_name=None):
"""
Get the effective user, especially during impersonation.
:param url: SQL Alchemy URL object
:param user_name: Default username
:return: The effective username
"""
effective_username = None
if self.impersonate_user:
effective_username = url.username
if user_name:
effective_username = user_name
elif hasattr(g, 'user') and g.user.username:
effective_username = g.user.username
return effective_username

def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
extra = self.get_extra()
uri = make_url(self.sqlalchemy_uri_decrypted)
url = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
if nullpool:
params['poolclass'] = NullPool
uri = self.db_engine_spec.adjust_database_uri(uri, schema)
if self.impersonate_user:
eff_username = uri.username
if user_name:
eff_username = user_name
elif hasattr(g, 'user') and g.user.username:
eff_username = g.user.username
uri.username = eff_username
return create_engine(uri, **params)
url = self.db_engine_spec.adjust_database_uri(url, schema)
effective_username = self.get_effective_user(url, user_name)
self.db_engine_spec.modify_url_for_impersonation(url, self.impersonate_user, effective_username)

masked_url = self.get_password_masked_url(url)
logging.info("Database.get_sqla_engine(). Masked URL: {0}".format(masked_url))

return create_engine(url, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
Expand Down Expand Up @@ -688,6 +716,10 @@ def db_engine_spec(self):
return db_engine_specs.engines.get(
self.backend, db_engine_specs.BaseEngineSpec)

@classmethod
def get_db_engine_spec_for_backend(cls, backend):
return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)

def grains(self):
"""Defines time granularity database-specific expressions.
Expand Down
12 changes: 8 additions & 4 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def handle_error(msg):
session.merge(query)
session.commit()
logging.info("Set query to 'running'")
conn = None
try:
engine = database.get_sqla_engine(
schema=query.schema, nullpool=not ctask.request.called_directly, user_name=user_name)
Expand All @@ -187,20 +188,23 @@ def handle_error(msg):
data = db_engine_spec.fetch_data(cursor, query.limit)
except SoftTimeLimitExceeded as e:
logging.exception(e)
conn.close()
if conn is not None:
conn.close()
return handle_error(
"SQL Lab timeout. This environment's policy is to kill queries "
"after {} seconds.".format(SQLLAB_TIMEOUT))
except Exception as e:
logging.exception(e)
conn.close()
if conn is not None:
conn.close()
return handle_error(db_engine_spec.extract_error_message(e))

logging.info("Fetching cursor description")
cursor_description = cursor.description

conn.commit()
conn.close()
if conn is not None:
conn.commit()
conn.close()

if query.status == utils.QueryStatus.STOPPED:
return json.dumps(
Expand Down
1 change: 1 addition & 0 deletions superset/templates/superset/models/database/macros.html
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
data = JSON.stringify({
uri: $("#sqlalchemy_uri").val(),
name: $('#database_name').val(),
impersonate_user: $('#impersonate_user').is(':checked'),
extras: JSON.parse($("#extra").val()),
})
} catch(parse_error){
Expand Down
20 changes: 17 additions & 3 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flask_babel import lazy_gettext as _

from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from werkzeug.routing import BaseConverter

from superset import (
Expand Down Expand Up @@ -240,8 +241,10 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa
"(http://docs.sqlalchemy.org/en/rel_1_0/core/metadata.html"
"#sqlalchemy.schema.MetaData) call. ", True),
'impersonate_user': _(
"All the queries in Sql Lab are going to be executed "
"on behalf of currently authorized user."),
"If Presto, all the queries in SQL Lab are going to be executed as the currently logged on user "
"who must have permission to run them.<br/>"
"If Hive and hive.server2.enable.doAs is enabled, will run the queries as service account, "
"but impersonate the currently logged on user via hive.server2.proxy.user property."),
}
label_columns = {
'expose_in_sqllab': _("Expose in SQL Lab"),
Expand All @@ -256,7 +259,7 @@ class DatabaseView(SupersetModelView, DeleteMixin): # noqa
'extra': _("Extra"),
'allow_run_sync': _("Allow Run Sync"),
'allow_run_async': _("Allow Run Async"),
'impersonate_user': _("Impersonate queries to the database"),
'impersonate_user': _("Impersonate the logged on user")
}

def pre_add(self, db):
Expand Down Expand Up @@ -1421,8 +1424,10 @@ def add_slices(self, dashboard_id):
def testconn(self):
"""Tests a sqla connection"""
try:
username = g.user.username if g.user is not None else None
uri = request.json.get('uri')
db_name = request.json.get('name')
impersonate_user = request.json.get('impersonate_user')
if db_name:
database = (
db.session
Expand All @@ -1434,6 +1439,15 @@ def testconn(self):
# the password-masked uri was passed
# use the URI associated with this database
uri = database.sqlalchemy_uri_decrypted

url = make_url(uri)
db_engine = models.Database.get_db_engine_spec_for_backend(url.get_backend_name())
db_engine.patch()
uri = db_engine.get_uri_for_impersonation(uri, impersonate_user, username)
masked_url = database.get_password_masked_url_from_uri(uri)

logging.info("Superset.testconn(). Masked URL: {0}".format(masked_url))

connect_args = (
request.json
.get('extras', {})
Expand Down
9 changes: 6 additions & 3 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,15 @@ def test_misc(self):
assert self.get_resp('/health') == "OK"
assert self.get_resp('/ping') == "OK"

def test_testconn(self):
def test_testconn(self, username='admin'):
self.login(username=username)
database = self.get_main_database(db.session)

# validate that the endpoint works with the password-masked sqlalchemy uri
data = json.dumps({
'uri': database.safe_sqlalchemy_uri(),
'name': 'main'
'name': 'main',
'impersonate_user': False
})
response = self.client.post('/superset/testconn', data=data, content_type='application/json')
assert response.status_code == 200
Expand All @@ -291,7 +293,8 @@ def test_testconn(self):
# validate that the endpoint works with the decrypted sqlalchemy uri
data = json.dumps({
'uri': database.sqlalchemy_uri_decrypted,
'name': 'main'
'name': 'main',
'impersonate_user': False
})
response = self.client.post('/superset/testconn', data=data, content_type='application/json')
assert response.status_code == 200
Expand Down

0 comments on commit adef519

Please sign in to comment.