Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to prevent invalid ShardConsumer state transitions due to rejected executions #560

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;

import org.reactivestreams.Subscription;
Expand Down Expand Up @@ -164,6 +165,9 @@ public void executeLifecycle() {
} else if (needsInitialization) {
if (stateChangeFuture != null) {
if (stateChangeFuture.get()) {
// Task rejection during the subscribe() call will not be propagated back as it not executed
// in the context of the Scheduler thread. Hence we should not assume the subscription will
// always be successful.
subscribe();
needsInitialization = false;
}
Expand All @@ -177,6 +181,11 @@ public void executeLifecycle() {
//
} catch (ExecutionException e) {
throw new RuntimeException(e);
} catch (RejectedExecutionException e) {
// It is possible the tasks submitted to the executor service by the Scheduler thread might get rejected
// due to various reasons. Such failed executions must be captured and marked as failure to prevent
// the state transitions.
taskOutcome = TaskOutcome.FAILURE;
}

if (ConsumerStates.ShardConsumerState.PROCESSING.equals(currentState.state())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand Down Expand Up @@ -58,6 +61,7 @@
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
Expand All @@ -69,6 +73,7 @@
import software.amazon.kinesis.leases.ShardInfo;
import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
import software.amazon.kinesis.lifecycle.events.TaskExecutionListenerInput;
import software.amazon.kinesis.lifecycle.ConsumerStates.ShardConsumerState;
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.RecordsRetrieved;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
Expand All @@ -95,6 +100,10 @@ public class ShardConsumerTest {
@Mock
private ShutdownNotification shutdownNotification;
@Mock
private ConsumerState blockedOnParentsState;
@Mock
private ConsumerTask blockedOnParentsTask;
@Mock
private ConsumerState initialState;
@Mock
private ConsumerTask initializeTask;
Expand All @@ -111,6 +120,8 @@ public class ShardConsumerTest {
@Mock
private TaskResult processingTaskResult;
@Mock
private TaskResult blockOnParentsTaskResult;
@Mock
private ConsumerState shutdownCompleteState;
@Mock
private ShardConsumerArgument shardConsumerArgument;
Expand Down Expand Up @@ -441,6 +452,144 @@ public final void testInitializationStateUponFailure() throws Exception {
verify(initialState, never()).shutdownTransition(any());
}

/**
* Test method to verify consumer undergoes the transition WAITING_ON_PARENT_SHARDS -> INITIALIZING -> PROCESSING
*/
@SuppressWarnings("unchecked")
@Test
public final void testSuccessfulConsumerStateTransition() throws Exception {
ExecutorService directExecutorService = spy(executorService);

doAnswer(invocation -> directlyExecuteRunnable(invocation))
.when(directExecutorService).execute(any());

ShardConsumer consumer = new ShardConsumer(recordsPublisher, directExecutorService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);

mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);

int arbitraryExecutionCount = 3;
do {
Copy link
Contributor

@micah-jaffe micah-jaffe Jun 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider writing as regular while loop for readability

try {
consumer.executeLifecycle();
} catch (Exception e) {
// Suppress any exception like the scheduler.
fail("Unexpected exception while executing consumer lifecycle");
}
} while (--arbitraryExecutionCount > 0);

assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state());
verify(directExecutorService, times(2)).execute(any());
}

/**
* Test method to verify consumer does not transition to PROCESSING from WAITING_ON_PARENT_SHARDS when
* INITIALIZING tasks gets rejected.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToProcessingWhenInitializationFails() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);

mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);

// Failing the initialization task and all other attempts after that.
doAnswer(invocation -> directlyExecuteRunnable(invocation))
.doThrow(new RejectedExecutionException())
.when(failingService).execute(any());

int arbitraryExecutionCount = 5;
do {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change: Out of curiosity, specific reason for do..while instead of for loop?

try {
consumer.executeLifecycle();
} catch (Exception e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the test not fail if an exception is thrown?

// Suppress any exception like the scheduler.
fail("Unexpected exception while executing consumer lifecycle");
}
} while (--arbitraryExecutionCount > 0);

assertEquals(ShardConsumerState.INITIALIZING.consumerState().state(), consumer.currentState().state());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to validate the number of calls to the executorservice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the expectation on number of calls

verify(failingService, times(5)).execute(any());
}

/**
* Test method to verify consumer transition to PROCESSING from WAITING_ON_PARENT_SHARDS with
* intermittent INITIALIZING task rejections.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerTransitionsToProcessingWithIntermittentInitializationFailures() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);

mockSuccessfulUnblockOnParents();
mockSuccessfulInitializeWithFailureTransition();
mockSuccessfulProcessing(null);

// Failing the initialization task and few other attempts after that.
doAnswer(invocation -> directlyExecuteRunnable(invocation))
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
.doThrow(new RejectedExecutionException())
.doAnswer(invocation -> directlyExecuteRunnable(invocation))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not possible to use ArgumentCaptor here?

.when(failingService).execute(any());

int arbitraryExecutionCount = 6;
do {
try {
consumer.executeLifecycle();
} catch (Exception e) {
// Suppress any exception like the scheduler.
fail("Unexpected exception while executing consumer lifecycle");
}
} while (--arbitraryExecutionCount > 0);

assertEquals(ShardConsumerState.PROCESSING.consumerState().state(), consumer.currentState().state());
verify(failingService, times(5)).execute(any());
}

/**
* Test method to verify consumer does not transition to INITIALIZING when WAITING_ON_PARENT_SHARDS task rejected.
*/
@SuppressWarnings("unchecked")
@Test
public final void testConsumerNotTransitionsToInitializingWhenWaitingOnParentsFails() {
ExecutorService failingService = spy(executorService);
ShardConsumer consumer = new ShardConsumer(recordsPublisher, failingService, shardInfo,
logWarningForTaskAfterMillis, shardConsumerArgument, blockedOnParentsState,
t -> t, 1, taskExecutionListener, 0);

mockSuccessfulUnblockOnParentsWithFailureTransition();
mockSuccessfulInitializeWithFailureTransition();

// Failing the waiting_on_parents task and few other attempts after that.
doThrow(new RejectedExecutionException())
.when(failingService).execute(any());

int arbitraryExecutionCount = 5;
do {
try {
consumer.executeLifecycle();
} catch (Exception e) {
// Suppress any exception like the scheduler.
fail("Unexpected exception while executing consumer lifecycle");
}
} while (--arbitraryExecutionCount > 0);

assertEquals(ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState().state(), consumer.currentState().state());
verify(failingService, times(5)).execute(any());
}

/**
* Test method to verify consumer stays in INITIALIZING state when InitializationTask fails.
*/
Expand Down Expand Up @@ -742,6 +891,11 @@ private void mockSuccessfulProcessing(CyclicBarrier taskCallBarrier, CyclicBarri
when(processingState.state()).thenReturn(ConsumerStates.ShardConsumerState.PROCESSING);
}

private void mockSuccessfulInitializeWithFailureTransition() {
mockSuccessfulInitialize(null, null);
when(initialState.failureTransition()).thenReturn(initialState);
}

private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier) {
mockSuccessfulInitialize(taskCallBarrier, null);
}
Expand All @@ -763,6 +917,22 @@ private void mockSuccessfulInitialize(CyclicBarrier taskCallBarrier, CyclicBarri

}

private void mockSuccessfulUnblockOnParentsWithFailureTransition() {
mockSuccessfulUnblockOnParents();
when(blockedOnParentsState.failureTransition()).thenReturn(blockedOnParentsState);
}

private void mockSuccessfulUnblockOnParents() {
when(blockedOnParentsState.createTask(eq(shardConsumerArgument), any(), any())).thenReturn(blockedOnParentsTask);
when(blockedOnParentsState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
when(blockedOnParentsTask.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
when(blockedOnParentsTask.call()).thenAnswer(i -> blockOnParentsTaskResult);
when(blockOnParentsTaskResult.getException()).thenReturn(null);
when(blockedOnParentsState.requiresDataAvailability()).thenReturn(false);
when(blockedOnParentsState.successTransition()).thenReturn(initialState);
when(blockedOnParentsState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
}

private void awaitBarrier(CyclicBarrier barrier) throws Exception {
if (barrier != null) {
barrier.await();
Expand All @@ -773,4 +943,12 @@ private void awaitAndResetBarrier(CyclicBarrier barrier) throws Exception {
barrier.await();
barrier.reset();
}

private Object directlyExecuteRunnable(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
Runnable runnable = (Runnable) args[0];
runnable.run();
return null;
}

}