Skip to content

Commit

Permalink
fix: ensure cert refresh recovers from sleep
Browse files Browse the repository at this point in the history
  • Loading branch information
ttosta-google committed Jan 20, 2024
1 parent ffccf0c commit c1263ad
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 11 deletions.
12 changes: 11 additions & 1 deletion core/src/main/java/com/google/cloud/sql/core/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,17 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
}

DefaultConnectionInfoCache getConnection(ConnectionConfig config) {
return instances.computeIfAbsent(config, k -> createConnectionInfo(config));
DefaultConnectionInfoCache instance =
instances.computeIfAbsent(config, k -> createConnectionInfo(config));

// If the client certificate has expired (as when the computer goes to
// sleep, and the refresh cycle cannot run), force a refresh immediately.
// The TLS handshake will not fail on an expired client certificate. It's
// not until the first read where the client cert error will be surfaced.
// So check that the certificate is valid before proceeding.
instance.refreshIfExpired();

return instance;
}

private DefaultConnectionInfoCache createConnectionInfo(ConnectionConfig config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ void forceRefresh() {
this.refresher.forceRefresh();
}

void refreshIfExpired() {
this.refresher.refreshIfExpired();
}

ListenableFuture<ConnectionInfo> getNext() {
return refresher.getNext();
}
Expand Down
55 changes: 45 additions & 10 deletions core/src/main/java/com/google/cloud/sql/core/Refresher.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
/** Handles periodic refresh operations for an instance. */
class Refresher {
private static final Logger logger = LoggerFactory.getLogger(Refresher.class);
private static final long DEFAULT_CONNECT_TIMEOUT_MS = 45000;

private final ListeningScheduledExecutorService executor;

Expand All @@ -58,6 +59,9 @@ class Refresher {
@GuardedBy("connectionInfoGuard")
private boolean closed;

@GuardedBy("connectionInfoGuard")
private boolean triggerNextRefresh = true;

/**
* Create a new refresher.
*
Expand All @@ -71,10 +75,29 @@ class Refresher {
ListeningScheduledExecutorService executor,
Supplier<ListenableFuture<ConnectionInfo>> refreshOperation,
AsyncRateLimiter rateLimiter) {
this(name, executor, refreshOperation, rateLimiter, true);
}

/**
* Create a new refresher.
*
* @param name the name of what is being refreshed, for logging.
* @param executor the executor to schedule refresh tasks.
* @param refreshOperation The supplier that refreshes the data.
* @param rateLimiter The rate limiter.
* @param triggerNextRefresh The next refresh operation should be triggered.
*/
Refresher(
String name,
ListeningScheduledExecutorService executor,
Supplier<ListenableFuture<ConnectionInfo>> refreshOperation,
AsyncRateLimiter rateLimiter,
boolean triggerNextRefresh) {
this.name = name;
this.executor = executor;
this.refreshOperation = refreshOperation;
this.rateLimiter = rateLimiter;
this.triggerNextRefresh = triggerNextRefresh;
synchronized (connectionInfoGuard) {
forceRefresh();
this.current = this.next;
Expand Down Expand Up @@ -156,6 +179,18 @@ void forceRefresh() {
}
}

/** Force a new refresh of the instance data if the client certificate has expired. */
void refreshIfExpired() {
ConnectionInfo info = getConnectionInfo(DEFAULT_CONNECT_TIMEOUT_MS);
if (Instant.now().isAfter(info.getExpiration())) {
logger.debug(
String.format(
"[%s] Client certificate has expired. Starting next refresh operation immediately.",
name));
forceRefresh();
}
}

/**
* Triggers an update of internal information obtained from the Cloud SQL Admin API, returning a
* future that resolves once a valid T has been acquired. This sets up a chain of futures that
Expand Down Expand Up @@ -202,15 +237,6 @@ private ListenableFuture<ConnectionInfo> handleRefreshResult(
long secondsToRefresh =
refreshCalculator.calculateSecondsUntilNextRefresh(Instant.now(), info.getExpiration());

logger.debug(
String.format(
"[%s] Refresh Operation: Next operation scheduled at %s.",
name,
Instant.now()
.plus(secondsToRefresh, ChronoUnit.SECONDS)
.truncatedTo(ChronoUnit.SECONDS)
.toString()));

synchronized (connectionInfoGuard) {
// Refresh completed successfully, reset forceRefreshRunning.
refreshRunning = false;
Expand All @@ -219,7 +245,16 @@ private ListenableFuture<ConnectionInfo> handleRefreshResult(

// Now update nextInstanceData to perform a refresh after the
// scheduled delay
if (!closed) {
if (!closed && triggerNextRefresh) {
logger.debug(
String.format(
"[%s] Refresh Operation: Next operation scheduled at %s.",
name,
Instant.now()
.plus(secondsToRefresh, ChronoUnit.SECONDS)
.truncatedTo(ChronoUnit.SECONDS)
.toString()));

next =
Futures.scheduleAsync(
this::startRefreshAttempt, secondsToRefresh, TimeUnit.SECONDS, executor);
Expand Down
39 changes: 39 additions & 0 deletions core/src/test/java/com/google/cloud/sql/core/RefresherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,45 @@ public void testClosedCloudSqlInstanceDataStopsRefreshTasks() throws Exception {
assertThat(r.getNext().isCancelled()).isTrue();
}

@Test
public void testCloudSqlRefreshesTokenIfExpired() throws Exception {
ExampleData initialData = new ExampleData(Instant.now().minus(2, ChronoUnit.SECONDS));
ExampleData data = new ExampleData(Instant.now().plus(1, ChronoUnit.HOURS));

AtomicInteger refreshCount = new AtomicInteger();
final PauseCondition refresh1 = new PauseCondition();

Refresher r =
new Refresher(
"RefresherTest.testCloudSqlRefreshesTokenIfExpired",
executorService,
() -> {
int c = refreshCount.get();
ExampleData refreshResult = data;
switch (c) {
case 0:
// refresh 0 should return initialData immediately
refreshResult = initialData;
break;
}
// refresh 2 and on should return data immediately
refreshCount.incrementAndGet();
return Futures.immediateFuture(refreshResult);
},
rateLimiter,
false);

// Get the first data that is about to expire
refresh1.waitForCondition(() -> r.getConnectionInfo(TEST_TIMEOUT_MS) == initialData, 1000L);
assertThat(refreshCount.get()).isEqualTo(1);

r.refreshIfExpired();

// getConnectionInfo again, and assert the refresh operation completed.
refresh1.waitForCondition(() -> r.getConnectionInfo(TEST_TIMEOUT_MS) == data, 1000L);
assertThat(refreshCount.get()).isEqualTo(2);
}

private static class ExampleData extends ConnectionInfo {

ExampleData(Instant expiration) {
Expand Down

0 comments on commit c1263ad

Please sign in to comment.