Skip to content

Commit

Permalink
[GH-82] Enable concurrent stage in
Browse files Browse the repository at this point in the history
When `scratch` is enabled, use `nxf_parallel` to pull the input
files.
  • Loading branch information
jealous committed Sep 3, 2024
1 parent 4ea7d44 commit 627e00d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import nextflow.executor.SimpleFileCopyStrategy
import nextflow.file.FileSystemPathFactory
import nextflow.processor.TaskBean
import nextflow.util.Escape

import java.nio.file.Path
Expand All @@ -29,11 +30,27 @@ import java.nio.file.Path
class FloatFileCopyStrategy extends SimpleFileCopyStrategy {
private FloatConf conf

FloatFileCopyStrategy(FloatConf conf) {
super()
FloatFileCopyStrategy(FloatConf conf, TaskBean bean) {
super(bean)
this.conf = conf
}

@Override
String getStageInputFilesScript(Map<String,Path> inputFiles) {
def result = 'downloads=(true)\n'
result += super.getStageInputFilesScript(inputFiles) + '\n'
result += 'nxf_parallel "${downloads[@]}"\n'
return result
}

/**
* {@inheritDoc}
*/
@Override
String stageInputFile( Path path, String targetName ) {
return """downloads+=("${super.stageInputFile(path, targetName)}")"""
}

@Override
String getBeforeStartScript() {
def script = FloatBashLib.script(conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class FloatGridExecutor extends AbstractGridExecutor {

protected BashWrapperBuilder createBashWrapperBuilder(TaskRun task) {
final bean = new TaskBean(task)
final strategy = new FloatFileCopyStrategy(floatConf)
final strategy = new FloatFileCopyStrategy(floatConf, bean)
// creates the wrapper script
final builder = new BashWrapperBuilder(bean, strategy)
// job directives headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.memverge.nextflow

import nextflow.Session
import nextflow.processor.TaskBean
import nextflow.processor.TaskConfig
import nextflow.processor.TaskId
import nextflow.processor.TaskProcessor
Expand Down Expand Up @@ -85,6 +86,7 @@ class FloatBaseTest extends BaseTest {
task.processor = Mock(TaskProcessor)
task.processor.getSession() >> Mock(Session)
task.processor.getExecutor() >> exec
task.processor.getProcessEnvironment() >> [:]
task.config = conf
task.id = new TaskId(id)
task.index = taskSerial.incrementAndGet()
Expand All @@ -93,6 +95,13 @@ class FloatBaseTest extends BaseTest {
return task
}

def newTaskBean(FloatTestExecutor exec, int id, TaskConfig conf = null) {
def task = newTask(exec, id, conf)
def bean = new TaskBean(task)
bean.stageInMode = 'copy'
return bean
}

def jobID(TaskId id) {
return "${FloatConf.NF_JOB_ID}:$tJob-$id"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,54 @@ package com.memverge.nextflow

import java.nio.file.Paths

class FloatFileCopyStrategyTest extends BaseTest {
class FloatFileCopyStrategyTest extends FloatBaseTest {

def "get before script with conf"() {
given:
def conf = [
float: [maxParallelTransfers: 10]
]
def conf = [float: [maxParallelTransfers: 10]]
def exec = newTestExecutor()

when:
def fConf = FloatConf.getConf(conf)
def strategy = new FloatFileCopyStrategy(fConf)
def strategy = new FloatFileCopyStrategy(fConf, newTaskBean(exec, 1))
final script = strategy.beforeStartScript

then:
script.contains('cpus>10')
!script.contains('\nnull')
}

def "get stage input file script"() {
given:
def conf = [float:[]]
def exec = newTestExecutor()

when:
def fConf = FloatConf.getConf(conf)
def strategy = new FloatFileCopyStrategy(fConf, newTaskBean(exec, 1))
final script = strategy.getStageInputFilesScript(
['a': Paths.get('/target/A')])

then:
script.contains('downloads+=("cp -fRL /target/A a")')
script.contains('nxf_parallel')
!script.contains('\nnull')
}

def "get unstage output file script"() {
given:
def conf = [float:[]]
def exec = newTestExecutor()

when:
def fConf = FloatConf.getConf(conf)
def strategy = new FloatFileCopyStrategy(fConf)
def strategy = new FloatFileCopyStrategy(fConf, newTaskBean(exec, 1))
final script = strategy.getUnstageOutputFilesScript(
['a',], Paths.get('/target/A'))

then:
script.contains('eval "ls -1d a"')
script.contains('nxf_parallel')
script.contains('uploads+=("nxf_fs_move "$name" /target/A")')
}
}

0 comments on commit 627e00d

Please sign in to comment.