diff --git a/client-it/src/test/java/io/streamnative/oxia/client/it/ClientReconnectIT.java b/client-it/src/test/java/io/streamnative/oxia/client/it/ClientReconnectIT.java new file mode 100644 index 00000000..6c25533f --- /dev/null +++ b/client-it/src/test/java/io/streamnative/oxia/client/it/ClientReconnectIT.java @@ -0,0 +1,102 @@ +/* + * Copyright © 2022-2024 StreamNative Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.oxia.client.it; + +import io.streamnative.oxia.client.api.AsyncOxiaClient; +import io.streamnative.oxia.client.api.GetResult; +import io.streamnative.oxia.client.api.OxiaClientBuilder; +import io.streamnative.oxia.testcontainers.OxiaContainer; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.output.Slf4jLogConsumer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.awaitility.Awaitility; + +@Testcontainers +@Slf4j +public class ClientReconnectIT { + + @Container + private static final OxiaContainer oxia = + new OxiaContainer(OxiaContainer.DEFAULT_IMAGE_NAME, 4, true) + .withLogConsumer(new Slf4jLogConsumer(log)); + + @Test + public void testReconnection() { + final AsyncOxiaClient client = + OxiaClientBuilder.create(oxia.getServiceAddress()).asyncClient().join(); + final String key = "1"; + final byte[] value = "1".getBytes(StandardCharsets.UTF_8); + + final long startTime = System.currentTimeMillis(); + final long elapse = 3000L; + while (true) { + try { + Thread.sleep(500); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + try { + client.put(key, value).get(1, TimeUnit.SECONDS); + } catch (Throwable ex) { + Assertions.fail("unexpected behaviour", ex); + } + + try { + final GetResult getResult = client.get("1").get(1, TimeUnit.SECONDS); + Assertions.assertArrayEquals(getResult.getValue(), value); + } catch (Throwable ex) { + Assertions.fail("unexpected behaviour", ex); + } + + if (System.currentTimeMillis() - startTime >= elapse) { + oxia.stop(); + + try { + Thread.sleep(3000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + oxia.start(); + + Awaitility.await() + .atMost(15, TimeUnit.SECONDS) + .untilAsserted( + () -> { + try { + client.put(key, value).get(1, TimeUnit.SECONDS); + } catch (Throwable ex) { + Assertions.fail("unexpected behaviour", ex); + } + + try { + final GetResult getResult = client.get("1").get(1, TimeUnit.SECONDS); + Assertions.assertArrayEquals(getResult.getValue(), value); + } catch (Throwable ex) { + Assertions.fail("unexpected behaviour", ex); + } + }); + break; + } + } + } +} diff --git a/client/src/main/java/io/streamnative/oxia/client/batch/WriteStreamWrapper.java b/client/src/main/java/io/streamnative/oxia/client/batch/WriteStreamWrapper.java index 8edea549..ee4eca2e 100644 --- a/client/src/main/java/io/streamnative/oxia/client/batch/WriteStreamWrapper.java +++ b/client/src/main/java/io/streamnative/oxia/client/batch/WriteStreamWrapper.java @@ -21,70 +21,66 @@ import io.streamnative.oxia.proto.WriteRequest; import io.streamnative.oxia.proto.WriteResponse; import java.util.ArrayDeque; +import java.util.Deque; import java.util.concurrent.CompletableFuture; import lombok.extern.slf4j.Slf4j; @Slf4j -public class WriteStreamWrapper { +public final class WriteStreamWrapper implements StreamObserver { private final StreamObserver clientStream; - - private final ArrayDeque> pendingWrites = new ArrayDeque<>(); - + private final Deque> pendingWrites = new ArrayDeque<>(); private volatile Throwable failed = null; public WriteStreamWrapper(OxiaClientGrpc.OxiaClientStub stub) { - this.clientStream = - stub.writeStream( - new StreamObserver<>() { - @Override - public void onNext(WriteResponse value) { - synchronized (WriteStreamWrapper.this) { - var future = pendingWrites.poll(); - if (future != null) { - future.complete(value); - } - } - } - - @Override - public void onError(Throwable t) { - synchronized (WriteStreamWrapper.this) { - if (!pendingWrites.isEmpty()) { - log.warn("Got Error", t); - } - pendingWrites.forEach(f -> f.completeExceptionally(t)); - pendingWrites.clear(); - failed = t; - } - } + this.clientStream = stub.writeStream(this); + } - @Override - public void onCompleted() {} - }); + public boolean isValid() { + return failed == null; } - public synchronized CompletableFuture send(WriteRequest request) { - if (failed != null) { - return CompletableFuture.failedFuture(failed); + @Override + public void onNext(WriteResponse value) { + synchronized (WriteStreamWrapper.this) { + final var future = pendingWrites.poll(); + if (future != null) { + future.complete(value); + } } + } - CompletableFuture future = new CompletableFuture<>(); - - try { - if (log.isDebugEnabled()) { - log.debug("Sending request {}", request); + @Override + public void onError(Throwable t) { + synchronized (WriteStreamWrapper.this) { + if (!pendingWrites.isEmpty()) { + log.warn("Got Error", t); } - clientStream.onNext(request); - pendingWrites.add(future); - } catch (Exception e) { - future.completeExceptionally(e); + pendingWrites.forEach(f -> f.completeExceptionally(t)); + pendingWrites.clear(); + failed = t; } - - return future; } - public boolean isValid() { - return failed == null; + @Override + public void onCompleted() {} + + public CompletableFuture send(WriteRequest request) { + synchronized (WriteStreamWrapper.this) { + if (failed != null) { + return CompletableFuture.failedFuture(failed); + } + final CompletableFuture future = new CompletableFuture<>(); + try { + if (log.isDebugEnabled()) { + log.debug("Sending request {}", request); + } + clientStream.onNext(request); + pendingWrites.add(future); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } } } diff --git a/client/src/main/java/io/streamnative/oxia/client/grpc/OxiaStub.java b/client/src/main/java/io/streamnative/oxia/client/grpc/OxiaStub.java index 5a207581..772b4786 100644 --- a/client/src/main/java/io/streamnative/oxia/client/grpc/OxiaStub.java +++ b/client/src/main/java/io/streamnative/oxia/client/grpc/OxiaStub.java @@ -98,16 +98,19 @@ public OxiaClientGrpc.OxiaClientStub async() { Metadata.Key.of("shard-id", Metadata.ASCII_STRING_MARSHALLER); public WriteStreamWrapper writeStream(long streamId) { - return writeStreams.computeIfAbsent( + return writeStreams.compute( streamId, - key -> { - Metadata headers = new Metadata(); - headers.put(NAMESPACE_KEY, namespace); - headers.put(SHARD_ID_KEY, String.format("%d", streamId)); + (key, stream) -> { + if (stream == null || !stream.isValid()) { + Metadata headers = new Metadata(); + headers.put(NAMESPACE_KEY, namespace); + headers.put(SHARD_ID_KEY, String.format("%d", streamId)); - OxiaClientGrpc.OxiaClientStub stub = - asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)); - return new WriteStreamWrapper(stub); + OxiaClientGrpc.OxiaClientStub stub = + asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)); + return new WriteStreamWrapper(stub); + } + return stream; }); } diff --git a/testcontainers/src/main/java/io/streamnative/oxia/testcontainers/OxiaContainer.java b/testcontainers/src/main/java/io/streamnative/oxia/testcontainers/OxiaContainer.java index a00c71f6..1eb4a8b7 100644 --- a/testcontainers/src/main/java/io/streamnative/oxia/testcontainers/OxiaContainer.java +++ b/testcontainers/src/main/java/io/streamnative/oxia/testcontainers/OxiaContainer.java @@ -17,8 +17,11 @@ import static lombok.AccessLevel.PRIVATE; +import java.io.IOException; +import java.net.ServerSocket; import java.time.Duration; import lombok.NonNull; +import lombok.SneakyThrows; import lombok.With; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -38,18 +41,29 @@ public class OxiaContainer extends GenericContainer { DockerImageName.parse("streamnative/oxia:main"); public OxiaContainer(@NonNull DockerImageName imageName) { - this(imageName, DEFAULT_SHARDS); + this(imageName, DEFAULT_SHARDS, false); } + public OxiaContainer(@NonNull DockerImageName imageName, int shards) { + this(imageName, shards, false); + } + + @SneakyThrows @SuppressWarnings("resource") - OxiaContainer(@NonNull DockerImageName imageName, int shards) { + public OxiaContainer(@NonNull DockerImageName imageName, int shards, boolean fixedServicePort) { super(imageName); this.imageName = imageName; this.shards = shards; if (shards <= 0) { throw new IllegalArgumentException("shards must be greater than zero"); } - addExposedPorts(OXIA_PORT, METRICS_PORT); + if (fixedServicePort) { + int freePort = findFreePort(); + addFixedExposedPort(freePort, OXIA_PORT); + addExposedPorts(METRICS_PORT); + } else { + addExposedPorts(OXIA_PORT, METRICS_PORT); + } setCommand("oxia", "standalone", "--shards=" + shards); waitingFor( Wait.forHttp("/metrics") @@ -58,6 +72,17 @@ public OxiaContainer(@NonNull DockerImageName imageName) { .withStartupTimeout(Duration.ofSeconds(30))); } + private static int findFreePort() throws IOException { + for (int i = 10000; i <= 20000; i++) { + try (ServerSocket socket = new ServerSocket(i)) { + return i; + } catch (Throwable ignore) { + + } + } + throw new IOException("No free port found in the specified range"); + } + public String getServiceAddress() { return getHost() + ":" + getMappedPort(OXIA_PORT); }