This repository has been archived by the owner on Feb 10, 2021. It is now read-only.

Various changes #19

Merged 7 commits on Jan 23, 2017
4 changes: 2 additions & 2 deletions dask_drmaa/adaptive.py
@@ -32,7 +32,7 @@ class Adaptive(object):
... def scale_down(self, workers):
... """ Remove worker addresses from cluster """
'''
def __init__(self, scheduler=None, cluster=None, interval=1000, startup_cost=1):
def __init__(self, cluster=None, scheduler=None, interval=1000, startup_cost=1):
self.cluster = cluster
if scheduler is None:
scheduler = cluster.scheduler
@@ -85,7 +85,7 @@ def _adapt(self):
memory.append(m)

if memory:
workers = self.cluster.start_workers(1, memory=max(memory) * 2)
workers = self.cluster.start_workers(1, memory=max(memory) * 4)
else:
workers = self.cluster.start_workers(1)
logger.info("Starting workers due to resource constraints: %s", workers)
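For reference, a minimal sketch of the reordered `Adaptive` call from this diff: the cluster is now the first positional argument and the scheduler is taken from `cluster.scheduler` when omitted. The imports match this repository; the rest is illustrative and assumes a working DRMAA/SGE environment.

```python
# Sketch only, assuming an SGE/DRMAA environment is available.
from dask_drmaa import SGECluster
from dask_drmaa.adaptive import Adaptive

cluster = SGECluster(scheduler_port=0)  # extra kwargs go to the local scheduler
adapt = Adaptive(cluster)               # cluster first, positionally;
                                        # scheduler defaults to cluster.scheduler
```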
98 changes: 43 additions & 55 deletions dask_drmaa/core.py
@@ -4,6 +4,7 @@
import sys

import drmaa
from toolz import merge
from tornado.ioloop import PeriodicCallback

from distributed import LocalCluster
@@ -21,18 +22,19 @@ def get_session():
return _global_session[0]


default_template = {
'remoteCommand': os.path.join(sys.exec_prefix, 'bin', 'dask-worker'),
'jobName': 'dask-worker',
'outputPath': ':%s/out' % os.getcwd(),
'errorPath': ':%s/err' % os.getcwd(),
'workingDirectory': os.getcwd(),
'nativeSpecification': '',
'args': []
}


class DRMAACluster(object):
def __init__(self,
jobName='dask-worker',
remoteCommand=os.path.join(sys.exec_prefix, 'bin', 'dask-worker'),
args=(),
outputPath=':%s/out' % os.getcwd(),
errorPath=':%s/err' % os.getcwd(),
workingDirectory = os.getcwd(),
nativeSpecification='',
max_runtime='1:00:00', #1 hour
cleanup_interval=1000,
**kwargs):
def __init__(self, template=None, cleanup_interval=1000, hostname=None, **kwargs):
"""
Dask workers launched by a DRMAA-compatible cluster

@@ -50,8 +52,6 @@ def __init__(self,
Where dask-worker runs, defaults to current directory
nativeSpecification: string
Options native to the job scheduler
max_runtime: string
Maximum runtime of worker jobs in format ``"HH:MM:SS"``

Examples
--------
@@ -66,18 +66,11 @@
>>> future.result() # doctest: +SKIP
11
"""
logger.info("Start local scheduler")
self.hostname = hostname or socket.gethostname()
logger.info("Start local scheduler at %s", self.hostname)
self.local_cluster = LocalCluster(n_workers=0, **kwargs)
logger.info("Initialize connection to job scheduler")

self.jobName = jobName
self.remoteCommand = remoteCommand
self.args = ['%s:%d' % (socket.gethostname(),
self.local_cluster.scheduler.port)] + list(args)
self.outputPath = outputPath
self.errorPath = errorPath
self.nativeSpecification = nativeSpecification
self.max_runtime = max_runtime
self.template = merge(default_template, template or {})

self._cleanup_callback = PeriodicCallback(callback=self.cleanup_closed_workers,
callback_time=cleanup_interval,
@@ -92,39 +85,42 @@ def scheduler(self):

@property
def scheduler_address(self):
return self.scheduler.address

def createJobTemplate(self, nativeSpecification=''):
wt = get_session().createJobTemplate()
wt.jobName = self.jobName
wt.remoteCommand = self.remoteCommand
wt.args = self.args
wt.outputPath = self.outputPath
wt.errorPath = self.errorPath
wt.nativeSpecification = self.nativeSpecification + ' ' + nativeSpecification
return wt
return '%s:%d' % (self.hostname, self.scheduler.port)

def create_job_template(self, **kwargs):
template = self.template.copy()
if kwargs:
template.update(kwargs)
template['args'] = [self.scheduler_address] + template['args']

jt = get_session().createJobTemplate()
valid_attributes = dir(jt)

for key, value in template.items():
if key not in valid_attributes:
raise ValueError("Invalid job template attribute %s" % key)
setattr(jt, key, value)

return jt

def start_workers(self, n=1, **kwargs):
with log_errors():
wt = self.createJobTemplate(**kwargs)

ids = get_session().runBulkJobs(wt, 1, n, 1)
logger.info("Start %d workers. Job ID: %s", len(ids), ids[0].split('.')[0])
self.workers.update({jid: kwargs for jid in ids})
global_workers.update(ids)
with self.create_job_template(**kwargs) as jt:
ids = get_session().runBulkJobs(jt, 1, n, 1)
logger.info("Start %d workers. Job ID: %s", len(ids), ids[0].split('.')[0])
self.workers.update({jid: kwargs for jid in ids})

def stop_workers(self, worker_ids, sync=False):
worker_ids = list(worker_ids)
for wid in worker_ids:
if isinstance(worker_ids, str):
worker_ids = [worker_ids]

for wid in list(worker_ids):
try:
get_session().control(wid, drmaa.JobControlAction.TERMINATE)
except drmaa.errors.InvalidJobException:
pass
self.workers.pop(wid)

with ignoring(KeyError):
global_workers.remove(wid)

logger.info("Stop workers %s", worker_ids)
if sync:
get_session().synchronize(worker_ids, dispose=True)
@@ -159,18 +155,10 @@ def __str__(self):
__repr__ = __str__


global_workers = set()


def remove_workers():
if not get_session():
return

for wid in global_workers:
try:
get_session().control(wid, drmaa.JobControlAction.TERMINATE)
except drmaa.errors.InvalidJobException:
pass
get_session().control(drmaa.Session.JOB_IDS_SESSION_ALL,
drmaa.JobControlAction.TERMINATE)


import atexit
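To illustrate the new template-based constructor above: the user-supplied `template` dict is merged over `default_template` (later keys win), and `create_job_template()` raises `ValueError` for keys that are not valid drmaa `JobTemplate` attributes. A hedged usage sketch; the `jobName` and `nativeSpecification` values are made up, not prescribed by this PR.

```python
# Sketch only; assumes a working DRMAA installation.
from dask_drmaa import DRMAACluster

with DRMAACluster(scheduler_port=0,
                  template={'jobName': 'my-dask-worker',       # overrides default_template
                            'nativeSpecification': '-q all.q'  # illustrative queue option
                            }) as cluster:
    cluster.start_workers(2)   # submits a bulk job via runBulkJobs
    # Unknown template keys would raise ValueError in create_job_template()
```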
38 changes: 21 additions & 17 deletions dask_drmaa/sge.py
@@ -3,32 +3,36 @@

class SGECluster(DRMAACluster):
default_memory = None
default_memory_fraction = 0.6
def createJobTemplate(self, nativeSpecification='', cpus=1, memory=None,
memory_fraction=None):

def create_job_template(self, nativeSpecification='', cpus=1, memory=None,
memory_fraction=0.5):
memory = memory or self.default_memory
memory_fraction = memory_fraction or self.default_memory_fraction
template = self.template.copy()

ns = template['nativeSpecification']
args = template['args']

args = [self.scheduler_address] + template['args']

args = self.args
ns = self.nativeSpecification
if nativeSpecification:
ns = ns + nativeSpecification
if memory:
args = args + ['--memory-limit', str(memory * memory_fraction)]
args = args + ['--resources', 'memory=%f' % (memory * 0.8)]
args = args + ['--memory-limit', str(memory * (1 - memory_fraction))]
args = args + ['--resources', 'memory=%f' % (memory * memory_fraction)]
ns += ' -l h_vmem=%dG' % int(memory / 1e9) # / cpus
if cpus:
args = args + ['--nprocs', '1', '--nthreads', str(cpus)]
# ns += ' -l TODO=%d' % (cpu + 1)

ns += ' -l h_rt={}'.format(self.max_runtime)
template['nativeSpecification'] = ns
template['args'] = args

jt = get_session().createJobTemplate()
valid_attributes = dir(jt)

wt = get_session().createJobTemplate()
wt.jobName = self.jobName
wt.remoteCommand = self.remoteCommand
wt.args = args
wt.outputPath = self.outputPath
wt.errorPath = self.errorPath
wt.nativeSpecification = ns
for key, value in template.items():
if key not in valid_attributes:
raise ValueError("Invalid job template attribute %s" % key)
setattr(jt, key, value)

return wt
return jt
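As a worked example of the memory arithmetic in the reworked `SGECluster.create_job_template` above (a sketch, assuming an SGE queue that honours `h_vmem`): with `memory=2e9` and the default `memory_fraction=0.5`, half the request becomes the worker's `--memory-limit`, half is advertised as a `memory` resource, and the full amount is requested from SGE.

```python
# Sketch only; requires an actual SGE/DRMAA setup.
from dask_drmaa import SGECluster

with SGECluster(scheduler_port=0) as cluster:
    # kwargs pass through start_workers() to create_job_template(), giving
    # --memory-limit 1e9, --resources memory=1e9, --nthreads 2,
    # and nativeSpecification ' -l h_vmem=2G'
    cluster.start_workers(1, memory=2e9, cpus=2)
```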
14 changes: 7 additions & 7 deletions dask_drmaa/tests/test_adaptive.py
@@ -11,7 +11,7 @@

def test_adaptive_memory(loop):
with SGECluster(scheduler_port=0, cleanup_interval=100) as cluster:
adapt = Adaptive(cluster=cluster)
adapt = Adaptive(cluster)
with Client(cluster, loop=loop) as client:
future = client.submit(inc, 1, resources={'memory': 1e9})
assert future.result() == 2
@@ -34,7 +34,7 @@ def test_adaptive_memory(loop):

def test_adaptive_normal_tasks(loop):
with SGECluster(scheduler_port=0) as cluster:
adapt = Adaptive(cluster=cluster)
adapt = Adaptive(cluster)
with Client(cluster, loop=loop) as client:
future = client.submit(inc, 1)
assert future.result() == 2
@@ -43,7 +43,7 @@ def test_adaptive_normal_tasks(loop):
@pytest.mark.parametrize('interval', [50, 1000])
def test_dont_over_request(loop, interval):
with SGECluster(scheduler_port=0) as cluster:
adapt = Adaptive(cluster=cluster)
adapt = Adaptive(cluster)
with Client(cluster, loop=loop) as client:
future = client.submit(inc, 1)
assert future.result() == 2
@@ -56,7 +56,7 @@ def test_dont_over_request(loop, interval):

def test_request_more_than_one(loop):
with SGECluster(scheduler_port=0) as cluster:
adapt = Adaptive(cluster=cluster)
adapt = Adaptive(cluster)
with Client(cluster, loop=loop) as client:
futures = client.map(slowinc, range(1000), delay=0.2)
while len(cluster.scheduler.workers) < 3:
@@ -71,7 +71,7 @@ def test_dont_request_if_idle(loop):
sleep(0.1)
futures = client.map(slowinc, range(1000), delay=0.2,
workers=first(cluster.scheduler.workers))
adapt = Adaptive(cluster=cluster, interval=2000)
adapt = Adaptive(cluster, interval=2000)

for i in range(60):
sleep(0.1)
@@ -80,7 +80,7 @@

def test_dont_request_if_not_enough_tasks(loop):
with SGECluster(scheduler_port=0) as cluster:
adapt = Adaptive(cluster=cluster)
adapt = Adaptive(cluster)
with Client(cluster, loop=loop) as client:
cluster.scheduler.task_duration['slowinc'] = 1000
future = client.submit(slowinc, 1, delay=1000)
@@ -92,7 +92,7 @@ def test_dont_request_if_not_enough_tasks(loop):

def test_dont_request_on_many_short_tasks(loop):
with SGECluster(scheduler_port=0) as cluster:
adapt = Adaptive(cluster=cluster, interval=50, startup_cost=10)
adapt = Adaptive(cluster, interval=50, startup_cost=10)
with Client(cluster, loop=loop) as client:
cluster.scheduler.task_duration['slowinc'] = 0.001
futures = client.map(slowinc, range(1000), delay=0.001)
34 changes: 34 additions & 0 deletions dask_drmaa/tests/test_core.py
@@ -58,3 +58,37 @@ def test_multiple_overlapping_clusters(loop):

assert future_1.result() == 2
assert future_2.result() == 3


def test_stop_single_worker(loop):
with DRMAACluster(scheduler_port=0) as cluster:
with Client(cluster, loop=loop) as client:
cluster.start_workers(2)
future = client.submit(lambda x: x + 1, 1)
assert future.result() == 2

a, b = cluster.workers
cluster.stop_workers(a)

start = time()
while len(client.ncores()) != 1:
sleep(0.2)
assert time() < start + 60


@pytest.mark.xfail(reason="Need mapping from worker addresses to job ids")
def test_stop_workers_politely(loop):
with DRMAACluster(scheduler_port=0) as cluster:
with Client(cluster, loop=loop) as client:
cluster.start_workers(2)

while len(client.ncores()) < 2:
sleep(0.1)

futures = client.scatter(list(range(10)))

a, b = cluster.workers
cluster.stop_workers(a)

data = client.gather(futures)
assert data == list(range(10))
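Finally, a short sketch of the `stop_workers` behaviour exercised by `test_stop_single_worker` above: a bare job-id string is now wrapped in a list, and `sync=True` blocks on drmaa `synchronize` until the jobs are disposed. Sketch only; assumes a live DRMAA cluster.

```python
# Sketch only; assumes a working DRMAA environment.
from dask_drmaa import DRMAACluster

with DRMAACluster(scheduler_port=0) as cluster:
    cluster.start_workers(2)
    a, b = cluster.workers                 # DRMAA job ids (dict keys)
    cluster.stop_workers(a)                # a single id string is accepted
    cluster.stop_workers([b], sync=True)   # wait until the job is cleaned up
```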