Skip to content

Commit

Permalink
s2a: Cleanups to IntegrationTest
Browse files Browse the repository at this point in the history
Move unused and unimportant fields to local variables. pickUnusedPort()
is inherently racy, so avoid using it when unnecessary. The channel's
default executor is fine to use, but if you don't like it
directExecutor() would be an option too. But blocking stub doesn't even
use the executor for unary RPCs. Thread.join() does not propagate
exceptions from the Thread; it just waits for the thread to exit.
  • Loading branch information
ejona86 committed Sep 18, 2024
1 parent bdc0530 commit 9b0c19e
Showing 1 changed file with 27 additions and 59 deletions.
86 changes: 27 additions & 59 deletions s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@
import io.netty.handler.ssl.SslProvider;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import java.util.concurrent.FutureTask;
import java.util.logging.Logger;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSessionContext;
Expand Down Expand Up @@ -127,28 +125,21 @@ public final class IntegrationTest {
+ "-----END PRIVATE KEY-----";

private String s2aAddress;
private int s2aPort;
private Server s2aServer;
private String s2aDelayAddress;
private int s2aDelayPort;
private Server s2aDelayServer;
private String mtlsS2AAddress;
private int mtlsS2APort;
private Server mtlsS2AServer;
private int serverPort;
private String serverAddress;
private Server server;

@Before
public void setUp() throws Exception {
s2aPort = Utils.pickUnusedPort();
s2aServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build().start();
int s2aPort = s2aServer.getPort();
s2aAddress = "localhost:" + s2aPort;
s2aServer = ServerBuilder.forPort(s2aPort).addService(new FakeS2AServer()).build();
logger.info("S2A service listening on localhost:" + s2aPort);
s2aServer.start();

mtlsS2APort = Utils.pickUnusedPort();
mtlsS2AAddress = "localhost:" + mtlsS2APort;
File s2aCert = new File("src/test/resources/server_cert.pem");
File s2aKey = new File("src/test/resources/server_key.pem");
File rootCert = new File("src/test/resources/root_cert.pem");
Expand All @@ -158,24 +149,25 @@ public void setUp() throws Exception {
.trustManager(rootCert)
.clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
.build();
mtlsS2AServer =
NettyServerBuilder.forPort(mtlsS2APort, s2aCreds).addService(new FakeS2AServer()).build();
logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort);
mtlsS2AServer = NettyServerBuilder.forPort(0, s2aCreds).addService(new FakeS2AServer()).build();
mtlsS2AServer.start();
int mtlsS2APort = mtlsS2AServer.getPort();
mtlsS2AAddress = "localhost:" + mtlsS2APort;
logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort);

s2aDelayPort = Utils.pickUnusedPort();
int s2aDelayPort = Utils.pickUnusedPort();
s2aDelayAddress = "localhost:" + s2aDelayPort;
s2aDelayServer = ServerBuilder.forPort(s2aDelayPort).addService(new FakeS2AServer()).build();

serverPort = Utils.pickUnusedPort();
serverAddress = "localhost:" + serverPort;
server =
NettyServerBuilder.forPort(serverPort)
NettyServerBuilder.forPort(0)
.addService(new SimpleServiceImpl())
.sslContext(buildSslContext())
.build();
.build()
.start();
int serverPort = server.getPort();
serverAddress = "localhost:" + serverPort;
logger.info("Simple Service listening on localhost:" + serverPort);
server.start();
}

@After
Expand All @@ -193,28 +185,23 @@ public void tearDown() throws Exception {

@Test
public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception {
ExecutorService executor = Executors.newSingleThreadExecutor();
ChannelCredentials credentials =
S2AChannelCredentials.createBuilder(s2aAddress).setLocalSpiffeId("test-spiffe-id").build();
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build();
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();

assertThat(doUnaryRpc(executor, channel)).isTrue();
assertThat(doUnaryRpc(channel)).isTrue();
}

@Test
public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception {
ExecutorService executor = Executors.newSingleThreadExecutor();
ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aAddress).build();
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build();
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();

assertThat(doUnaryRpc(executor, channel)).isTrue();
assertThat(doUnaryRpc(channel)).isTrue();
}

@Test
public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception {
ExecutorService executor = Executors.newSingleThreadExecutor();
ChannelCredentials credentials =
MtlsToS2AChannelCredentials.createBuilder(
/* s2aAddress= */ mtlsS2AAddress,
Expand All @@ -224,41 +211,24 @@ public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Excepti
.build()
.setLocalSpiffeId("test-spiffe-id")
.build();
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build();
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();

assertThat(doUnaryRpc(executor, channel)).isTrue();
assertThat(doUnaryRpc(channel)).isTrue();
}

@Test
public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception {
DoUnaryRpc doUnaryRpc = new DoUnaryRpc();
doUnaryRpc.start();
ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aDelayAddress).build();
ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build();

FutureTask<Boolean> rpc = new FutureTask<>(() -> doUnaryRpc(channel));
new Thread(rpc).start();
Thread.sleep(2000);
s2aDelayServer.start();
doUnaryRpc.join();
}

private class DoUnaryRpc extends Thread {
@Override
public void run() {
ExecutorService executor = Executors.newSingleThreadExecutor();
ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aDelayAddress).build();
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build();
boolean result = false;
try {
result = doUnaryRpc(executor, channel);
} catch (InterruptedException e) {
logger.log(Level.SEVERE, "Failed to do unary rpc", e);
result = false;
}
assertThat(result).isTrue();
}
assertThat(rpc.get()).isTrue();
}

public static boolean doUnaryRpc(ExecutorService executor, ManagedChannel channel)
throws InterruptedException {
public static boolean doUnaryRpc(ManagedChannel channel) throws InterruptedException {
try {
SimpleServiceGrpc.SimpleServiceBlockingStub stub =
SimpleServiceGrpc.newBlockingStub(channel);
Expand All @@ -277,8 +247,6 @@ public static boolean doUnaryRpc(ExecutorService executor, ManagedChannel channe
} finally {
channel.shutdown();
channel.awaitTermination(1, SECONDS);
executor.shutdown();
executor.awaitTermination(1, SECONDS);
}
}

Expand Down Expand Up @@ -318,4 +286,4 @@ public void unaryRpc(SimpleRequest request, StreamObserver<SimpleResponse> obser
observer.onCompleted();
}
}
}
}

0 comments on commit 9b0c19e

Please sign in to comment.