Skip to content

Use Okio's Pipe for stream piping #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions engine/src/main/java/de/gesellix/docker/engine/AttachConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,39 @@ public void setOnResponse(Closure<?> onResponse) {
setOnResponse(onResponse::call);
}

/**
* @deprecated Internal use only. Will eventually be removed.
*/
@Deprecated
public Object onSinkClosed(Response r) {
return callbacks.onSinkClosed.apply(r);
}

/**
* @deprecated Internal use only. Will eventually be removed.
*/
@Deprecated
public void setOnSinkClosed(Function<Response, ?> onSinkClosed) {
callbacks.onSinkClosed = onSinkClosed;
}

/**
* @see #setOnSinkClosed(Function)
* @deprecated Will be removed after migration from Groovy to plain Java
* @deprecated Internal use only. Will eventually be removed.
*/
@Deprecated
public void setOnSinkClosed(Closure<?> onSinkClosed) {
setOnSinkClosed(onSinkClosed::call);
}

public Object onSinkWritten(Response r) {
public Object onStdInConsumed(Response r) {
return callbacks.onSinkWritten.apply(r);
}

public Object onSinkWritten(Response r) {
return onStdInConsumed(r);
}

public void setOnSinkWritten(Function<Response, ?> onSinkWritten) {
callbacks.onSinkWritten = onSinkWritten;
}
Expand All @@ -90,10 +102,14 @@ public void setOnSinkWritten(Closure<?> onSinkWritten) {
setOnSinkWritten(onSinkWritten::call);
}

public Object onSourceConsumed() {
public Object onStdOutConsumed() {
return callbacks.onSourceConsumed.get();
}

public Object onSourceConsumed() {
return onStdOutConsumed();
}

public void setOnSourceConsumed(Supplier<?> onSourceConsumed) {
callbacks.onSourceConsumed = onSourceConsumed;
}
Expand Down Expand Up @@ -151,6 +167,10 @@ public static class Callbacks {

private Function<Exception, ?> onFailure = (Exception e) -> null;
private Function<Response, ?> onResponse = (Response r) -> null;
/**
* @deprecated Internal use only. Will eventually be removed.
*/
@Deprecated
private Function<Response, ?> onSinkClosed = (Response r) -> null;
private Function<Response, ?> onSinkWritten = (Response r) -> null;
private Supplier<?> onSourceConsumed = () -> null;
Expand Down
147 changes: 65 additions & 82 deletions engine/src/main/java/de/gesellix/docker/engine/OkResponseCallback.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Response;
import okio.BufferedSink;
import okio.Buffer;
import okio.Okio;
import okio.Pipe;
import okio.Sink;
import okio.Source;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BiConsumer;

public class OkResponseCallback implements Callback {

Expand All @@ -40,44 +41,61 @@ public void onFailure(Exception e) {
attachConfig.onFailure(e);
}

/** Reads all bytes from {@code source} and writes them to {@code sink}. */
private Long readAll(Source source, Sink sink) throws IOException {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: document credits, code has been taken from Okio's tests.

long result = 0L;
// Okio.buffer(sink).writeAll(source);
Buffer buffer = new Buffer();
for (long count; (count = source.read(buffer, 1024)) != -1L; result += count) {
sink.write(buffer, count);
}
return result;
}

/** Calls {@link #readAll} on a background thread. */
private Future<Long> readAllAsync(final Source source, final Sink sink) {
ExecutorService executor = Executors.newSingleThreadExecutor();
try {
return executor.submit(() -> readAll(source, sink));
}
finally {
executor.shutdown();
}
}

private Thread transfer(Source source, Sink sink, BiConsumer<Long, Long> onFinish) {
return new Thread(() -> {
Pipe p = new Pipe(1024);
try {
Future<Long> futureSink = readAllAsync(p.source(), sink);
Future<Long> futureSource = readAllAsync(source, p.sink());
Long read = futureSource.get();
p.sink().flush();
p.sink().close();
Long written = futureSink.get();
onFinish.accept(read, written);
}
catch (Exception e) {
log.warn("error", e);
onFailure(e);
}
});
}

@Override
public void onResponse(@NotNull final Call call, @NotNull final Response response) throws IOException {
TcpUpgradeVerificator.ensureTcpUpgrade(response);

if (attachConfig.getStreams().getStdin() != null) {
// pass input from the client via stdin and pass it to the output stream
// running it in an own thread allows the client to gain back control
final Source stdinSource = Okio.source(attachConfig.getStreams().getStdin());
Thread writer = new Thread(() -> {
try {
final BufferedSink bufferedSink = Okio.buffer(getConnectionProvider().getSink());
bufferedSink.writeAll(stdinSource);
bufferedSink.flush();
attachConfig.onSinkWritten(response);
CountDownLatch done = new CountDownLatch(1);
delayed(100, "writer", () -> {
try {
bufferedSink.close();
attachConfig.onSinkClosed(response);
}
catch (Exception e) {
log.warn("error", e);
}
return null;
}, done);
done.await(5, TimeUnit.SECONDS);
}
catch (InterruptedException e) {
log.debug("stdin->sink interrupted", e);
Thread.currentThread().interrupt();
}
catch (Exception e) {
onFailure(e);
}
finally {
log.trace("writer finished");
}
});
// client's stdin -> socket
Thread writer = transfer(
Okio.source(attachConfig.getStreams().getStdin()),
connectionProvider.getSink(),
(read, written) -> {
log.warn("read {}, written {}", read, written);
attachConfig.onStdInConsumed(response);
attachConfig.onSinkClosed(response);
});
writer.setName("stdin-writer " + call.request().url().encodedPath());
writer.start();
}
Expand All @@ -86,29 +104,14 @@ public void onResponse(@NotNull final Call call, @NotNull final Response respons
}

if (attachConfig.getStreams().getStdout() != null) {
final BufferedSink bufferedStdout = Okio.buffer(Okio.sink(attachConfig.getStreams().getStdout()));
Thread reader = new Thread(() -> {
try {
bufferedStdout.writeAll(getConnectionProvider().getSource());
bufferedStdout.flush();
CountDownLatch done = new CountDownLatch(1);
delayed(100, "reader", () -> {
attachConfig.onSourceConsumed();
return null;
}, done);
done.await(5, TimeUnit.SECONDS);
}
catch (InterruptedException e) {
log.debug("source->stdout interrupted", e);
Thread.currentThread().interrupt();
}
catch (Exception e) {
onFailure(e);
}
finally {
log.trace("reader finished");
}
});
// client's stdout <- socket
Thread reader = transfer(
connectionProvider.getSource(),
Okio.sink(attachConfig.getStreams().getStdout()),
(read, written) -> {
log.warn("read {}, written {}", read, written);
attachConfig.onStdOutConsumed();
});
reader.setName("stdout-reader " + call.request().url().encodedPath());
reader.start();
}
Expand All @@ -118,24 +121,4 @@ public void onResponse(@NotNull final Call call, @NotNull final Response respons

attachConfig.onResponse(response);
}

public static void delayed(long delay, String name, final Supplier<?> action, final CountDownLatch done) {
new Timer(true).schedule(new TimerTask() {
@Override
public void run() {
Thread.currentThread().setName("Delayed " + name + " action (" + Thread.currentThread().getName() + ")");
try {
action.get();
}
finally {
done.countDown();
cancel();
}
}
}, delay);
}

public ConnectionProvider getConnectionProvider() {
return connectionProvider;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ class OkDockerClientIntegrationSpec extends Specification {
def onSinkWritten = new CountDownLatch(1)
def onSourceConsumed = new CountDownLatch(1)

Response okResponse = null
def attachConfig = new AttachConfig()
attachConfig.streams.stdin = new PipedInputStream(stdin)
attachConfig.streams.stdout = stdout
attachConfig.onResponse = { Response res ->
okResponse = res
}
attachConfig.onSinkClosed = { Response response ->
log.info("[attach (interactive)] sink closed \n${stdout.toString()}")
onSinkClosed.countDown()
Expand All @@ -178,6 +182,7 @@ class OkDockerClientIntegrationSpec extends Specification {
stdin.write("$content\n".bytes)
stdin.flush()
stdin.close()
// okResponse?.close()
def sourceConsumed = onSourceConsumed.await(5, SECONDS)
def sinkWritten = onSinkWritten.await(5, SECONDS)
def sinkClosed = onSinkClosed.await(5, SECONDS)
Expand Down