diff --git a/dask_drmaa/adaptive.py b/dask_drmaa/adaptive.py
index fbc124c..d5ddfca 100644
--- a/dask_drmaa/adaptive.py
+++ b/dask_drmaa/adaptive.py
@@ -32,7 +32,7 @@ class Adaptive(object):
     ...     def scale_down(self, workers):
     ...         """ Remove worker addresses from cluster """
     '''
-    def __init__(self, scheduler=None, cluster=None, interval=1000, startup_cost=1):
+    def __init__(self, cluster=None, scheduler=None, interval=1000, startup_cost=1):
         self.cluster = cluster
         if scheduler is None:
             scheduler = cluster.scheduler
@@ -85,7 +85,7 @@ def _adapt(self):
                     memory.append(m)
 
             if memory:
-                workers = self.cluster.start_workers(1, memory=max(memory) * 2)
+                workers = self.cluster.start_workers(1, memory=max(memory) * 4)
             else:
                 workers = self.cluster.start_workers(1)
             logger.info("Starting workers due to resource constraints: %s", workers)
diff --git a/dask_drmaa/core.py b/dask_drmaa/core.py
index 806e64d..e5b7d79 100644
--- a/dask_drmaa/core.py
+++ b/dask_drmaa/core.py
@@ -4,6 +4,7 @@
 import sys
 
 import drmaa
+from toolz import merge
 from tornado.ioloop import PeriodicCallback
 
 from distributed import LocalCluster
@@ -21,18 +22,19 @@ def get_session():
     return _global_session[0]
 
 
+default_template = {
+    'remoteCommand': os.path.join(sys.exec_prefix, 'bin', 'dask-worker'),
+    'jobName': 'dask-worker',
+    'outputPath': ':%s/out' % os.getcwd(),
+    'errorPath': ':%s/err' % os.getcwd(),
+    'workingDirectory': os.getcwd(),
+    'nativeSpecification': '',
+    'args': []
+    }
+
+
 class DRMAACluster(object):
-    def __init__(self,
-                 jobName='dask-worker',
-                 remoteCommand=os.path.join(sys.exec_prefix, 'bin', 'dask-worker'),
-                 args=(),
-                 outputPath=':%s/out' % os.getcwd(),
-                 errorPath=':%s/err' % os.getcwd(),
-                 workingDirectory = os.getcwd(),
-                 nativeSpecification='',
-                 max_runtime='1:00:00', #1 hour
-                 cleanup_interval=1000,
-                 **kwargs):
+    def __init__(self, template=None, cleanup_interval=1000, hostname=None, **kwargs):
         """
         Dask workers launched by a DRMAA-compatible cluster
 
@@ -50,8 +52,6 @@ def __init__(self,
             Where dask-worker runs, defaults to current directory
         nativeSpecification: string
            Options native to the job scheduler
-        max_runtime: string
-            Maximum runtime of worker jobs in format ``"HH:MM:SS"``
 
         Examples
         --------
@@ -66,18 +66,11 @@ def __init__(self,
        >>> future.result() # doctest: +SKIP
        11
        """
 
-        logger.info("Start local scheduler")
+        self.hostname = hostname or socket.gethostname()
+        logger.info("Start local scheduler at %s", self.hostname)
         self.local_cluster = LocalCluster(n_workers=0, **kwargs)
-        logger.info("Initialize connection to job scheduler")
-        self.jobName = jobName
-        self.remoteCommand = remoteCommand
-        self.args = ['%s:%d' % (socket.gethostname(),
-                                self.local_cluster.scheduler.port)] + list(args)
-        self.outputPath = outputPath
-        self.errorPath = errorPath
-        self.nativeSpecification = nativeSpecification
-        self.max_runtime = max_runtime
+        self.template = merge(default_template, template or {})
 
         self._cleanup_callback = PeriodicCallback(callback=self.cleanup_closed_workers,
                                                   callback_time=cleanup_interval,
@@ -92,39 +85,42 @@ def scheduler(self):
         return self.local_cluster.scheduler
 
     @property
     def scheduler_address(self):
-        return self.scheduler.address
-
-    def createJobTemplate(self, nativeSpecification=''):
-        wt = get_session().createJobTemplate()
-        wt.jobName = self.jobName
-        wt.remoteCommand = self.remoteCommand
-        wt.args = self.args
-        wt.outputPath = self.outputPath
-        wt.errorPath = self.errorPath
-        wt.nativeSpecification = self.nativeSpecification + ' ' + nativeSpecification
-        return wt
+        return '%s:%d' % (self.hostname, self.scheduler.port)
+
+    def create_job_template(self, **kwargs):
+        template = self.template.copy()
+        if kwargs:
+            template.update(kwargs)
+        template['args'] = [self.scheduler_address] + template['args']
+
+        jt = get_session().createJobTemplate()
+        valid_attributes = dir(jt)
+
+        for key, value in template.items():
+            if key not in valid_attributes:
+                raise ValueError("Invalid job template attribute %s" % key)
+            setattr(jt, key, value)
+
+        return jt
 
     def start_workers(self, n=1, **kwargs):
         with log_errors():
-            wt = self.createJobTemplate(**kwargs)
-
-            ids = get_session().runBulkJobs(wt, 1, n, 1)
-            logger.info("Start %d workers. Job ID: %s", len(ids), ids[0].split('.')[0])
-            self.workers.update({jid: kwargs for jid in ids})
-            global_workers.update(ids)
+            with self.create_job_template(**kwargs) as jt:
+                ids = get_session().runBulkJobs(jt, 1, n, 1)
+                logger.info("Start %d workers. Job ID: %s", len(ids), ids[0].split('.')[0])
+                self.workers.update({jid: kwargs for jid in ids})
 
     def stop_workers(self, worker_ids, sync=False):
-        worker_ids = list(worker_ids)
-        for wid in worker_ids:
+        if isinstance(worker_ids, str):
+            worker_ids = [worker_ids]
+
+        for wid in list(worker_ids):
             try:
                 get_session().control(wid, drmaa.JobControlAction.TERMINATE)
             except drmaa.errors.InvalidJobException:
                 pass
             self.workers.pop(wid)
-            with ignoring(KeyError):
-                global_workers.remove(wid)
-
         logger.info("Stop workers %s", worker_ids)
         if sync:
             get_session().synchronize(worker_ids, dispose=True)
@@ -159,18 +155,10 @@ def __str__(self):
 
     __repr__ = __str__
 
 
-global_workers = set()
-
 def remove_workers():
-    if not get_session():
-        return
-
-    for wid in global_workers:
-        try:
-            get_session().control(wid, drmaa.JobControlAction.TERMINATE)
-        except drmaa.errors.InvalidJobException:
-            pass
+    get_session().control(drmaa.Session.JOB_IDS_SESSION_ALL,
+                          drmaa.JobControlAction.TERMINATE)
 
 
 import atexit
diff --git a/dask_drmaa/sge.py b/dask_drmaa/sge.py
index 0ec11da..feec9ed 100644
--- a/dask_drmaa/sge.py
+++ b/dask_drmaa/sge.py
@@ -3,32 +3,36 @@
 class SGECluster(DRMAACluster):
     default_memory = None
-    default_memory_fraction = 0.6
-    def createJobTemplate(self, nativeSpecification='', cpus=1, memory=None,
-                          memory_fraction=None):
+
+    def create_job_template(self, nativeSpecification='', cpus=1, memory=None,
+                            memory_fraction=0.5):
         memory = memory or self.default_memory
-        memory_fraction = memory_fraction or self.default_memory_fraction
+        template = self.template.copy()
+
+        ns = template['nativeSpecification']
+        args = template['args']
+
+        args = [self.scheduler_address] + template['args']
 
-        args = self.args
-        ns = self.nativeSpecification
         if nativeSpecification:
             ns = ns + nativeSpecification
 
         if memory:
-            args = args + ['--memory-limit', str(memory * memory_fraction)]
-            args = args + ['--resources', 'memory=%f' % (memory * 0.8)]
+            args = args + ['--memory-limit', str(memory * (1 - memory_fraction))]
+            args = args + ['--resources', 'memory=%f' % (memory * memory_fraction)]
             ns += ' -l h_vmem=%dG' % int(memory / 1e9) # / cpus
 
         if cpus:
             args = args + ['--nprocs', '1', '--nthreads', str(cpus)]
             # ns += ' -l TODO=%d' % (cpu + 1)
 
-        ns += ' -l h_rt={}'.format(self.max_runtime)
+        template['nativeSpecification'] = ns
+        template['args'] = args
+
+        jt = get_session().createJobTemplate()
+        valid_attributes = dir(jt)
 
-        wt = get_session().createJobTemplate()
-        wt.jobName = self.jobName
-        wt.remoteCommand = self.remoteCommand
-        wt.args = args
-        wt.outputPath = self.outputPath
-        wt.errorPath = self.errorPath
-        wt.nativeSpecification = ns
+        for key, value in template.items():
+            if key not in valid_attributes:
+                raise ValueError("Invalid job template attribute %s" % key)
+            setattr(jt, key, value)
-        return wt
+        return jt
diff --git a/dask_drmaa/tests/test_adaptive.py b/dask_drmaa/tests/test_adaptive.py
index 5e096aa..8454df1 100644
--- a/dask_drmaa/tests/test_adaptive.py
+++ b/dask_drmaa/tests/test_adaptive.py
@@ -11,7 +11,7 @@
 
 def test_adaptive_memory(loop):
     with SGECluster(scheduler_port=0, cleanup_interval=100) as cluster:
-        adapt = Adaptive(cluster=cluster)
+        adapt = Adaptive(cluster)
         with Client(cluster, loop=loop) as client:
             future = client.submit(inc, 1, resources={'memory': 1e9})
             assert future.result() == 2
@@ -34,7 +34,7 @@
 
 def test_adaptive_normal_tasks(loop):
     with SGECluster(scheduler_port=0) as cluster:
-        adapt = Adaptive(cluster=cluster)
+        adapt = Adaptive(cluster)
         with Client(cluster, loop=loop) as client:
             future = client.submit(inc, 1)
             assert future.result() == 2
@@ -43,7 +43,7 @@
 @pytest.mark.parametrize('interval', [50, 1000])
 def test_dont_over_request(loop, interval):
     with SGECluster(scheduler_port=0) as cluster:
-        adapt = Adaptive(cluster=cluster)
+        adapt = Adaptive(cluster)
         with Client(cluster, loop=loop) as client:
             future = client.submit(inc, 1)
             assert future.result() == 2
@@ -56,7 +56,7 @@
 
 def test_request_more_than_one(loop):
     with SGECluster(scheduler_port=0) as cluster:
-        adapt = Adaptive(cluster=cluster)
+        adapt = Adaptive(cluster)
         with Client(cluster, loop=loop) as client:
             futures = client.map(slowinc, range(1000), delay=0.2)
             while len(cluster.scheduler.workers) < 3:
@@ -71,7 +71,7 @@ def test_dont_request_if_idle(loop):
                 sleep(0.1)
             futures = client.map(slowinc, range(1000), delay=0.2,
                                  workers=first(cluster.scheduler.workers))
-            adapt = Adaptive(cluster=cluster, interval=2000)
+            adapt = Adaptive(cluster, interval=2000)
 
             for i in range(60):
                 sleep(0.1)
@@ -80,7 +80,7 @@
 
 def test_dont_request_if_not_enough_tasks(loop):
     with SGECluster(scheduler_port=0) as cluster:
-        adapt = Adaptive(cluster=cluster)
+        adapt = Adaptive(cluster)
         with Client(cluster, loop=loop) as client:
             cluster.scheduler.task_duration['slowinc'] = 1000
             future = client.submit(slowinc, 1, delay=1000)
@@ -92,7 +92,7 @@
 
 def test_dont_request_on_many_short_tasks(loop):
     with SGECluster(scheduler_port=0) as cluster:
-        adapt = Adaptive(cluster=cluster, interval=50, startup_cost=10)
+        adapt = Adaptive(cluster, interval=50, startup_cost=10)
         with Client(cluster, loop=loop) as client:
             cluster.scheduler.task_duration['slowinc'] = 0.001
             futures = client.map(slowinc, range(1000), delay=0.001)
diff --git a/dask_drmaa/tests/test_core.py b/dask_drmaa/tests/test_core.py
index 56c784f..f731122 100644
--- a/dask_drmaa/tests/test_core.py
+++ b/dask_drmaa/tests/test_core.py
@@ -58,3 +58,37 @@ def test_multiple_overlapping_clusters(loop):
 
                     assert future_1.result() == 2
                     assert future_2.result() == 3
+
+
+def test_stop_single_worker(loop):
+    with DRMAACluster(scheduler_port=0) as cluster:
+        with Client(cluster, loop=loop) as client:
+            cluster.start_workers(2)
+            future = client.submit(lambda x: x + 1, 1)
+            assert future.result() == 2
+
+            a, b = cluster.workers
+            cluster.stop_workers(a)
+
+            start = time()
+            while len(client.ncores()) != 1:
+                sleep(0.2)
+                assert time() < start + 60
+
+
+@pytest.mark.xfail(reason="Need mapping from worker addresses to job ids")
+def test_stop_workers_politely(loop):
+    with DRMAACluster(scheduler_port=0) as cluster:
+        with Client(cluster, loop=loop) as client:
+            cluster.start_workers(2)
+
+            while len(client.ncores()) < 2:
+                sleep(0.1)
+
+            futures = client.scatter(list(range(10)))
+
+            a, b = cluster.workers
+            cluster.stop_workers(a)
+
+            data = client.gather(futures)
+            assert data == list(range(10))
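
A minimal usage sketch of the reworked API, for reference alongside the patch. It assumes the classes are imported from the modules touched above (dask_drmaa.sge, dask_drmaa.adaptive), that the machine can reach a DRMAA-compatible SGE scheduler, and that the queue name passed through nativeSpecification is a placeholder:

    from distributed import Client

    from dask_drmaa.adaptive import Adaptive
    from dask_drmaa.sge import SGECluster

    # Per-job settings now live in a single template dict; keys given here are
    # merged on top of default_template (remoteCommand, outputPath, ...).
    cluster = SGECluster(template={'nativeSpecification': '-q example.q'})

    # Adaptive now takes the cluster as its first positional argument and
    # falls back to cluster.scheduler when no scheduler is given explicitly.
    adapt = Adaptive(cluster)

    with Client(cluster) as client:
        futures = client.map(lambda x: x + 1, range(100))
        print(sum(client.gather(futures)))

    # stop_workers() now accepts a single DRMAA job id as well as a list of ids.
    cluster.stop_workers(list(cluster.workers))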