Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPC server task management #1380

Merged
merged 9 commits into from
Apr 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def cancel(self, connection):

def cancel_open(self):
names = []
this_connection = self.get_if_exists()
with self.lock:
for connection in self.thread_connections.values():
if connection.name == 'master':
if connection is this_connection:
continue

self.cancel(connection)
Expand Down
16 changes: 16 additions & 0 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ def data(self):
return result


class RPCKilledException(RuntimeException):
    """Raised inside an RPC task subprocess when it is terminated by a
    signal (see sigterm_handler), so the kill surfaces as an RPC error
    instead of an abrupt process exit.
    """
    CODE = 10009
    MESSAGE = 'RPC process killed'

    def __init__(self, signum):
        # the numeric signal (e.g. signal.SIGTERM) that terminated the task
        self.signum = signum
        self.message = 'RPC process killed by signal {}'.format(signum)
        super(RPCKilledException, self).__init__(self.message)

    def data(self):
        # serializable payload included in the JSON-RPC error response
        return {'signum': self.signum, 'message': self.message}


class DatabaseException(RuntimeException):
CODE = 10003
MESSAGE = "Database Error"
Expand Down
328 changes: 327 additions & 1 deletion core/dbt/rpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
from jsonrpc.exceptions import JSONRPCDispatchException, JSONRPCInvalidParams
from jsonrpc.exceptions import JSONRPCDispatchException, \
JSONRPCInvalidParams, \
JSONRPCParseError, \
JSONRPCInvalidRequestException, \
JSONRPCInvalidRequest
from jsonrpc import JSONRPCResponseManager
from jsonrpc.jsonrpc import JSONRPCRequest
from jsonrpc.jsonrpc2 import JSONRPC20Response

import json
import uuid
import multiprocessing
import os
import signal
import time
from collections import namedtuple

from dbt.logger import RPC_LOGGER as logger
from dbt.logger import add_queue_handler
from dbt.compat import QueueEmpty
import dbt.exceptions


Expand All @@ -16,6 +34,12 @@ def __init__(self, code=None, message=None, data=None, logs=None):
data=data)
self.logs = logs

def __str__(self):
    # code/message/data come from the wrapped JSON-RPC error object
    # (self.error); logs is a property on this exception itself.
    return (
        'RPCException({0.code}, {0.message}, {0.data}, {1.logs})'
        .format(self.error, self)
    )

@property
def logs(self):
return self.error.data.get('logs')
Expand Down Expand Up @@ -66,3 +90,305 @@ def terminating(cls):
cls.Error,
cls.Result
]


def sigterm_handler(signum, frame):
    # Installed as the SIGTERM handler inside task subprocesses (see
    # RequestTaskHandler.task_bootstrap) so that a kill becomes a catchable
    # exception with an RPC error code rather than an abrupt exit.
    raise dbt.exceptions.RPCKilledException(signum)


class RequestDispatcher(object):
    """A special dispatcher that knows about requests.

    The jsonrpc response manager looks up method names via __getitem__;
    each name resolves to either a manager builtin (ps/kill) or a wrapper
    that runs the named task in a tracked subprocess.
    """
    def __init__(self, http_request, json_rpc_request, manager):
        self.http_request = http_request
        self.json_rpc_request = json_rpc_request
        self.manager = manager
        self.task = None

    def rpc_factory(self, task):
        """Build the callable the jsonrpc layer will invoke for `task`,
        registering the in-flight request with the manager for its duration.
        """
        handler = RequestTaskHandler(
            task, self.http_request, self.json_rpc_request
        )

        def rpc_func(**kwargs):
            try:
                self.manager.add_request(handler)
                return handler.handle(kwargs)
            finally:
                # always unregister, even if the task errored or was killed
                self.manager.mark_done(handler)

        return rpc_func

    def __getitem__(self, key):
        # the dispatcher's keys are method names and its values are
        # functions that implement the RPC calls
        builtin = self.manager.rpc_builtin(key)
        if builtin is not None:
            return builtin
        return self.rpc_factory(self.manager.rpc_task(key))


class RequestTaskHandler(object):
def __init__(self, task, http_request, json_rpc_request):
    self.task = task
    self.http_request = http_request
    self.json_rpc_request = json_rpc_request
    self.queue = None  # multiprocessing.Queue, created in handle()
    self.process = None  # multiprocessing.Process running task_bootstrap
    self.started = None  # wall-clock start (time.time()), set in handle()
    self.timeout = None  # optional timeout in seconds, popped from kwargs
    self.logs = []  # log records collected off the queue while waiting
    self.task_id = uuid.uuid4()  # unique id used by the ps/kill builtins

@property
def request_source(self):
    # remote address of the originating HTTP request, for `ps` output
    return self.http_request.remote_addr

@property
def request_id(self):
    # the JSON-RPC request id (private attribute of the jsonrpc request)
    return self.json_rpc_request._id

@property
def method(self):
    # the RPC method name this handler's task implements
    return self.task.METHOD_NAME

def _next_timeout(self):
    """Return the number of seconds left before this task's deadline, or
    None when no timeout is set.

    :raises RPCTimeoutException: once the deadline has already passed.
    """
    if self.timeout is None:
        return None
    remaining = (self.started + self.timeout) - time.time()
    if remaining < 0:
        raise dbt.exceptions.RPCTimeoutException(self.timeout)
    return remaining

def _wait_for_results(self):
    """Wait for results off the queue. If there is a timeout set, and it is
    exceeded, raise an RPCTimeoutException.

    Log messages are accumulated on self.logs as they arrive; the first
    terminating message (Error or Result) is returned as (msgtype, value).
    """
    while True:
        try:
            # a None timeout blocks forever; an expired one raises above
            msgtype, value = self.queue.get(timeout=self._next_timeout())
        except QueueEmpty:
            # no message arrived within the remaining time budget
            raise dbt.exceptions.RPCTimeoutException(self.timeout)

        if msgtype in QueueMessageType.terminating():
            return msgtype, value
        if msgtype == QueueMessageType.Log:
            self.logs.append(value)
        else:
            raise dbt.exceptions.InternalException(
                'Got invalid queue message type {}'.format(msgtype)
            )

def _join_process(self):
    """Wait for the task subprocess to produce a result and reap it,
    converting every failure mode into an RPCException.

    :raises RPCException: on timeout (terminating the subprocess first),
        on dbt errors, on unexpected errors, or when the subprocess
        reported an Error message on the queue.
    """
    try:
        msgtype, result = self._wait_for_results()
    except dbt.exceptions.RPCTimeoutException as exc:
        # the task overran its budget: kill it before reporting the timeout
        self.process.terminate()
        raise timeout_error(self.timeout)
    # NOTE(review): dbt.exceptions.Exception — presumably the dbt base
    # exception type exported by that module; confirm the attribute exists.
    except dbt.exceptions.Exception as exc:
        raise dbt_error(exc)
    except Exception as exc:
        raise server_error(exc)
    finally:
        # always reap the subprocess, on success and on every error path
        self.process.join()

    if msgtype == QueueMessageType.Error:
        raise RPCException.from_error(result)

    return result

def get_result(self):
    """Return the task's result dict with the collected logs attached
    under 'logs'; on failure, attach the logs to the raised RPCException
    instead so the caller still sees them.
    """
    try:
        result = self._join_process()
    except RPCException as exc:
        exc.logs = self.logs
        raise

    result['logs'] = self.logs
    return result

def task_bootstrap(self, kwargs):
    """Entry point of the task subprocess: run the task and put the
    outcome — a Result or an Error message — onto the queue.

    :param kwargs: the request parameters forwarded to the task.
    """
    # make `kill` produce RPCKilledException instead of an abrupt exit
    signal.signal(signal.SIGTERM, sigterm_handler)
    # the first thing we do in a new process: start logging
    add_queue_handler(self.queue)

    error = None
    result = None
    try:
        result = self.task.handle_request(**kwargs)
    except RPCException as exc:
        error = exc
    except dbt.exceptions.RPCKilledException as exc:
        # do NOT log anything here, you risk triggering a deadlock on the
        # queue handler we inserted above
        error = dbt_error(exc)
    except dbt.exceptions.Exception as exc:
        logger.debug('dbt runtime exception', exc_info=True)
        error = dbt_error(exc)
    except Exception as exc:
        logger.debug('uncaught python exception', exc_info=True)
        error = server_error(exc)

    # put whatever result we got onto the queue as well.
    if error is not None:
        self.queue.put([QueueMessageType.Error, error.error])
    else:
        self.queue.put([QueueMessageType.Result, result])

def handle(self, kwargs):
    """Run the task in a fresh subprocess and block until it completes.

    :param kwargs: request parameters; 'timeout' (seconds) is popped off
        and applied to the whole task, the rest go to the task itself.
    :return: the task's result dict, with collected logs under 'logs'.
    """
    self.started = time.time()
    self.timeout = kwargs.pop('timeout', None)
    self.queue = multiprocessing.Queue()
    self.process = multiprocessing.Process(
        target=self.task_bootstrap,
        args=(kwargs,)
    )
    self.process.start()
    return self.get_result()

@property
def state(self):
if self.started is None:
return 'not started'
elif self.process is None:
return 'initializing'
elif self.process.is_alive():
return 'running'
else:
return 'finished'


# Row schema for the `ps` builtin's process listing output; `start` and
# `elapsed` may be None for tasks that have not started yet.
TaskRow = namedtuple(
    'TaskRow',
    'task_id request_id request_source method state start elapsed timeout'
)


class TaskManager(object):
def __init__(self):
    # active tasks, keyed by task_id (uuid.UUID)
    self.tasks = {}
    # finished tasks, keyed by task_id; moved here by mark_done()
    self.completed = {}
    # METHOD_NAME -> task object, registered via add_task_handler()
    self._rpc_task_map = {}
    # reserved for function dispatch; not used in the visible code
    self._rpc_function_map = {}
    # guards tasks/completed against concurrent mutation
    self._lock = multiprocessing.Lock()

def add_request(self, request_handler):
    # register an in-flight request so the ps/kill builtins can find it
    self.tasks[request_handler.task_id] = request_handler

def add_task_handler(self, task):
    # register an RPC task implementation under its METHOD_NAME
    self._rpc_task_map[task.METHOD_NAME] = task

def rpc_task(self, method_name):
    # look up a registered task; raises KeyError for unknown method names
    return self._rpc_task_map[method_name]

def process_listing(self, active=True, completed=False):
    """Build the `ps` builtin's task listing.

    :param active: include tasks that have not yet finished.
    :param completed: include tasks that have finished.
    :return: a dict with a 'rows' key holding a list of row dicts
        (task_id, request_id, request_source, method, state, start,
        elapsed, timeout), sorted by (state, start).
    """
    included_tasks = {}
    with self._lock:
        if completed:
            included_tasks.update(self.completed)
        if active:
            included_tasks.update(self.tasks)

    table = []
    now = time.time()
    for task_handler in included_tasks.values():
        start = task_handler.started
        # A task that has not started yet has no start time. Report its
        # elapsed time as None: previously `elapsed` was only assigned
        # inside the `if`, so an unstarted task either raised
        # UnboundLocalError (first iteration) or silently reused the
        # previous task's elapsed value.
        if start is not None:
            elapsed = now - start
        else:
            elapsed = None

        table.append(TaskRow(
            str(task_handler.task_id), task_handler.request_id,
            task_handler.request_source, task_handler.method,
            task_handler.state, start, elapsed, task_handler.timeout
        ))
    # `start` is None for unstarted tasks; coalesce to 0 so sorting never
    # raises a TypeError comparing None to a float in python 3.
    table.sort(key=lambda r: (r.state, r.start or 0))
    result = {
        'rows': [dict(r._asdict()) for r in table],
    }
    return result

def process_kill(self, task_id):
# TODO: this result design is terrible
result = {
'found': False,
'started': False,
'finished': False,
'killed': False
}
task_id = uuid.UUID(task_id)
try:
task = self.tasks[task_id]
except KeyError:
# nothing to do!
return result

result['found'] = True

if task.process is None:
return result
pid = task.process.pid
if pid is None:
return result

result['started'] = True

if task.process.is_alive():
os.kill(pid, signal.SIGINT)
result['killed'] = True
return result

result['finished'] = True
return result

def rpc_builtin(self, method_name):
if method_name == 'ps':
return self.process_listing
if method_name == 'kill':
return self.process_kill
return None

def mark_done(self, request_handler):
task_id = request_handler.task_id
with self._lock:
if task_id not in self.tasks:
# lost a task! Maybe it was killed before it started.
return
self.completed[task_id] = self.tasks.pop(task_id)

def methods(self):
rpc_builtin_methods = ['ps', 'kill']
return list(self._rpc_task_map) + rpc_builtin_methods


class ResponseManager(JSONRPCResponseManager):
    """Override the default response manager to handle request metadata and
    track in-flight tasks.
    """
    @classmethod
    def handle(cls, http_request, task_manager):
        """Parse the raw HTTP body as a JSON-RPC request and dispatch it
        through a RequestDispatcher bound to `task_manager`.

        (pretty much just copy+pasted from the jsonrpc original, with
        slight tweaks to preserve the http request)
        """
        raw = http_request.data
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")

        try:
            data = json.loads(raw)
        except (TypeError, ValueError):
            return JSONRPC20Response(error=JSONRPCParseError()._data)

        try:
            rpc_request = JSONRPCRequest.from_data(data)
        except JSONRPCInvalidRequestException:
            return JSONRPC20Response(error=JSONRPCInvalidRequest()._data)

        return cls.handle_request(
            rpc_request,
            RequestDispatcher(http_request, rpc_request, task_manager),
        )
Loading