Shard end v2 (#624)
* Revalidate that the current shard is closed before shutting down the ShardConsumer

* Rename method

* Force the lease to be lost before shutting down in the zombie state

* Add comments for shard-end-related unit tests
ychunxue authored and Cory-Bradshaw committed Oct 23, 2019
1 parent 6fa9bab commit 1f68648
Showing 8 changed files with 312 additions and 34 deletions.
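At a high level, the change makes SHARD_END shutdowns self-verifying: before a ShardConsumer checkpoints at SHARD_END, it re-lists the stream's shards and confirms the shard is actually closed, i.e. some shard in the fresh listing names it as a parent. If that cannot be confirmed, the shutdown reason is downgraded to LEASE_LOST and the lease is dropped so another worker can contend for it. A minimal sketch of that decision rule, using simplified stand-in types rather than the real KCL classes:

    import java.util.List;

    // Simplified stand-ins for illustration; the real logic lives in ShutdownTask below.
    enum ShutdownReason { SHARD_END, LEASE_LOST }

    class SimpleShard {
        final String parentShardId;
        final String adjacentParentShardId;

        SimpleShard(String parentShardId, String adjacentParentShardId) {
            this.parentShardId = parentShardId;
            this.adjacentParentShardId = adjacentParentShardId;
        }
    }

    final class ShardEndRevalidation {
        // SHARD_END is honored only when the fresh listing proves the shard is
        // closed, i.e. some shard names it as parent or adjacent parent.
        // Otherwise the caller drops the lease and shuts down with LEASE_LOST.
        static ShutdownReason revalidate(ShutdownReason requested, String shardId,
                List<SimpleShard> latestShards) {
            if (requested != ShutdownReason.SHARD_END) {
                return requested;
            }
            boolean confirmedClosed = latestShards != null && latestShards.stream()
                    .anyMatch(s -> shardId.equals(s.parentShardId)
                            || shardId.equals(s.adjacentParentShardId));
            return confirmedClosed ? ShutdownReason.SHARD_END : ShutdownReason.LEASE_LOST;
        }
    }

Note the conservative default: an empty or unavailable listing downgrades to LEASE_LOST rather than risking a premature SHARD_END checkpoint.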
Scheduler.java
@@ -588,7 +588,7 @@ protected ShardConsumer buildConsumer(@NonNull final ShardInfo shardInfo,
checkpoint);
ShardConsumerArgument argument = new ShardConsumerArgument(shardInfo,
streamName,
-leaseRefresher,
+leaseCoordinator,
executorService,
cache,
shardRecordProcessorFactory.shardRecordProcessor(),
HierarchicalShardSyncer.java
@@ -78,10 +78,21 @@ public synchronized void checkAndCreateLeaseForNewShards(@NonNull final ShardDetector shardDetector,
final boolean cleanupLeasesOfCompletedShards, final boolean ignoreUnexpectedChildShards,
final MetricsScope scope) throws DependencyException, InvalidStateException,
ProvisionedThroughputException, KinesisClientLibIOException {
-final List<Shard> shards = getShardList(shardDetector);
-log.debug("Num shards: {}", shards.size());
+final List<Shard> latestShards = getShardList(shardDetector);
+checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher, initialPosition, cleanupLeasesOfCompletedShards,
+ignoreUnexpectedChildShards, scope, latestShards);
+}

-final Map<String, Shard> shardIdToShardMap = constructShardIdToShardMap(shards);
+// Provide a pre-collected list of shards to avoid calling the ListShards API again.
+public synchronized void checkAndCreateLeaseForNewShards(@NonNull final ShardDetector shardDetector,
+final LeaseRefresher leaseRefresher, final InitialPositionInStreamExtended initialPosition, final boolean cleanupLeasesOfCompletedShards,
+final boolean ignoreUnexpectedChildShards, final MetricsScope scope, List<Shard> latestShards) throws DependencyException, InvalidStateException,
+ProvisionedThroughputException, KinesisClientLibIOException {
+if (!CollectionUtils.isNullOrEmpty(latestShards)) {
+log.debug("Num shards: {}", latestShards.size());
+}
+
+final Map<String, Shard> shardIdToShardMap = constructShardIdToShardMap(latestShards);
final Map<String, Set<String>> shardIdToChildShardIdsMap = constructShardIdToChildShardIdsMap(
shardIdToShardMap);
final Set<String> inconsistentShardIds = findInconsistentShardIds(shardIdToChildShardIdsMap, shardIdToShardMap);
@@ -91,8 +102,7 @@ public synchronized void checkAndCreateLeaseForNewShards(@NonNull final ShardDetector shardDetector,

final List<Lease> currentLeases = leaseRefresher.listLeases();

-final List<Lease> newLeasesToCreate = determineNewLeasesToCreate(shards, currentLeases, initialPosition,
-inconsistentShardIds);
+final List<Lease> newLeasesToCreate = determineNewLeasesToCreate(latestShards, currentLeases, initialPosition, inconsistentShardIds);
log.debug("Num new leases to create: {}", newLeasesToCreate.size());
for (Lease lease : newLeasesToCreate) {
long startTime = System.currentTimeMillis();
@@ -104,14 +114,13 @@ public synchronized void checkAndCreateLeaseForNewShards(@NonNull final ShardDetector shardDetector,
MetricsUtil.addSuccessAndLatency(scope, "CreateLease", success, startTime, MetricsLevel.DETAILED);
}
}

final List<Lease> trackedLeases = new ArrayList<>(currentLeases);
trackedLeases.addAll(newLeasesToCreate);
-cleanupGarbageLeases(shardDetector, shards, trackedLeases, leaseRefresher);
+cleanupGarbageLeases(shardDetector, latestShards, trackedLeases, leaseRefresher);
if (cleanupLeasesOfCompletedShards) {
-cleanupLeasesOfFinishedShards(currentLeases, shardIdToShardMap, shardIdToChildShardIdsMap, trackedLeases,
-leaseRefresher);
+cleanupLeasesOfFinishedShards(currentLeases, shardIdToShardMap, shardIdToChildShardIdsMap, trackedLeases, leaseRefresher);
}

}
// CHECKSTYLE:ON CyclomaticComplexity

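The practical effect of the new overload above: a caller that has already fetched the shard list (as ShutdownTask now does during shard-end revalidation) can pass it in and skip a second ListShards round trip. A sketch of such a call site, assuming the detector, refresher, and metrics scope are already wired up as in the KCL worker (class and variable names here are illustrative, not from the commit):

    import java.util.List;

    import software.amazon.awssdk.services.kinesis.model.Shard;
    import software.amazon.kinesis.common.InitialPositionInStreamExtended;
    import software.amazon.kinesis.leases.HierarchicalShardSyncer;
    import software.amazon.kinesis.leases.LeaseRefresher;
    import software.amazon.kinesis.leases.ShardDetector;
    import software.amazon.kinesis.metrics.MetricsScope;

    class ShardSyncSketch {
        // Illustrative only: fetch the shard list once, then reuse it for lease sync.
        void syncWithPrefetchedShards(HierarchicalShardSyncer syncer, ShardDetector shardDetector,
                LeaseRefresher leaseRefresher, InitialPositionInStreamExtended initialPosition,
                MetricsScope scope) throws Exception {
            List<Shard> latestShards = shardDetector.listShards(); // single ListShards call
            syncer.checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher, initialPosition,
                    true /* cleanupLeasesOfCompletedShards */, false /* ignoreUnexpectedChildShards */,
                    scope, latestShards); // the pre-fetched list is reused here
        }
    }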
ConsumerStates.java
@@ -135,7 +135,7 @@ static class BlockedOnParentState implements ConsumerState {
@Override
public ConsumerTask createTask(ShardConsumerArgument consumerArgument, ShardConsumer consumer, ProcessRecordsInput input) {
return new BlockOnParentShardTask(consumerArgument.shardInfo(),
-consumerArgument.leaseRefresher(),
+consumerArgument.leaseCoordinator().leaseRefresher(),
consumerArgument.parentShardPollIntervalMillis());
}

@@ -492,7 +492,7 @@ public ConsumerTask createTask(ShardConsumerArgument argument, ShardConsumer consumer,
argument.initialPositionInStream(),
argument.cleanupLeasesOfCompletedShards(),
argument.ignoreUnexpectedChildShards(),
-argument.leaseRefresher(),
+argument.leaseCoordinator(),
argument.taskBackoffTimeMillis(),
argument.recordsPublisher(),
argument.hierarchicalShardSyncer(),
ShardConsumerArgument.java
@@ -21,6 +21,7 @@
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
@@ -42,7 +43,7 @@ public class ShardConsumerArgument {
@NonNull
private final String streamName;
@NonNull
-private final LeaseRefresher leaseRefresher;
+private final LeaseCoordinator leaseCoordinator;
@NonNull
private final ExecutorService executorService;
@NonNull
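The ShardConsumerArgument and ConsumerStates changes are plumbing for the above: the argument object now carries the whole LeaseCoordinator instead of a bare LeaseRefresher, and tasks that only need lease-table access derive it from the coordinator. In simplified form (stand-in interfaces, not the real KCL signatures):

    // Stand-in interfaces sketching the new relationship: ShutdownTask needs the
    // coordinator (to look up and drop its held lease), while tasks such as
    // BlockOnParentShardTask still only need the refresher it exposes.
    interface LeaseRefresher {
        // lease-table CRUD operations
    }

    interface LeaseCoordinator {
        LeaseRefresher leaseRefresher(); // for tasks that only read/write leases
        // plus getCurrentlyHeldLease(...) and dropLease(...), used by ShutdownTask
    }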
ShutdownTask.java
@@ -16,12 +16,18 @@

import com.google.common.annotations.VisibleForTesting;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.kinesis.annotations.KinesisClientInternalApi;
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.Lease;
import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
@@ -36,6 +42,10 @@
import software.amazon.kinesis.retrieval.RecordsPublisher;
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
* Task for invoking the ShardRecordProcessor shutdown() callback.
*/
@@ -61,7 +71,7 @@ public class ShutdownTask implements ConsumerTask {
private final boolean cleanupLeasesOfCompletedShards;
private final boolean ignoreUnexpectedChildShards;
@NonNull
-private final LeaseRefresher leaseRefresher;
+private final LeaseCoordinator leaseCoordinator;
private final long backoffTimeMillis;
@NonNull
private final RecordsPublisher recordsPublisher;
@@ -88,20 +98,38 @@ public TaskResult call() {

try {
try {
ShutdownReason localReason = reason;
List<Shard> latestShards = null;
/*
* Revalidate that the current shard is closed before shutting down the shard consumer with reason
* SHARD_END. If the shard cannot be confirmed closed, shut down with reason LEASE_LOST instead, so
* that active workers can contend for this shard's lease.
*/
if (localReason == ShutdownReason.SHARD_END) {
latestShards = shardDetector.listShards();

// If latestShards is empty, or no shard in it names this shard as a parent, the shard end cannot
// be confirmed: force LEASE_LOST so the consumer shuts down without a SHARD_END checkpoint.
if (CollectionUtils.isNullOrEmpty(latestShards) || !isShardInContextParentOfAny(latestShards)) {
localReason = ShutdownReason.LEASE_LOST;
dropLease();
log.info("Forcing the lease to be lost before shutting down the consumer for Shard: " + shardInfo.shardId());
}
}

// If we reached end of the shard, set sequence number to SHARD_END.
-if (reason == ShutdownReason.SHARD_END) {
+if (localReason == ShutdownReason.SHARD_END) {
recordProcessorCheckpointer
.sequenceNumberAtShardEnd(recordProcessorCheckpointer.largestPermittedCheckpointValue());
recordProcessorCheckpointer.largestPermittedCheckpointValue(ExtendedSequenceNumber.SHARD_END);
}

log.debug("Invoking shutdown() for shard {}, concurrencyToken {}. Shutdown reason: {}",
-shardInfo.shardId(), shardInfo.concurrencyToken(), reason);
-final ShutdownInput shutdownInput = ShutdownInput.builder().shutdownReason(reason)
+shardInfo.shardId(), shardInfo.concurrencyToken(), localReason);
+final ShutdownInput shutdownInput = ShutdownInput.builder().shutdownReason(localReason)
.checkpointer(recordProcessorCheckpointer).build();
final long startTime = System.currentTimeMillis();
try {
-if (reason == ShutdownReason.SHARD_END) {
+if (localReason == ShutdownReason.SHARD_END) {
shardRecordProcessor.shardEnded(ShardEndedInput.builder().checkpointer(recordProcessorCheckpointer).build());
ExtendedSequenceNumber lastCheckpointValue = recordProcessorCheckpointer.lastCheckpointValue();
if (lastCheckpointValue == null
@@ -123,11 +151,11 @@ public TaskResult call() {
MetricsUtil.addLatency(scope, RECORD_PROCESSOR_SHUTDOWN_METRIC, startTime, MetricsLevel.SUMMARY);
}

-if (reason == ShutdownReason.SHARD_END) {
+if (localReason == ShutdownReason.SHARD_END) {
log.debug("Looking for child shards of shard {}", shardInfo.shardId());
// create leases for the child shards
-hierarchicalShardSyncer.checkAndCreateLeaseForNewShards(shardDetector, leaseRefresher,
-initialPositionInStream, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, scope);
+hierarchicalShardSyncer.checkAndCreateLeaseForNewShards(shardDetector, leaseCoordinator.leaseRefresher(),
+initialPositionInStream, cleanupLeasesOfCompletedShards, ignoreUnexpectedChildShards, scope, latestShards);
log.debug("Finished checking for child shards of shard {}", shardInfo.shardId());
}

@@ -169,4 +197,26 @@ public ShutdownReason getReason() {
return reason;
}

private boolean isShardInContextParentOfAny(List<Shard> shards) {
for (Shard shard : shards) {
if (isChildShardOfShardInContext(shard)) {
return true;
}
}
return false;
}

private boolean isChildShardOfShardInContext(Shard shard) {
return (StringUtils.equals(shard.parentShardId(), shardInfo.shardId())
|| StringUtils.equals(shard.adjacentParentShardId(), shardInfo.shardId()));
}

private void dropLease() {
final Lease currentLease = leaseCoordinator.getCurrentlyHeldLease(shardInfo.shardId());
if (currentLease != null) {
leaseCoordinator.dropLease(currentLease);
log.warn("Dropped lease for shutting down ShardConsumer: {}", currentLease.leaseKey());
}
}

}
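To make the revalidation concrete: with the AWS SDK v2 Shard model, the check above amounts to scanning the fresh listing for any shard that points back at the one being shut down. A runnable toy example (shard ids invented for illustration):

    import java.util.Arrays;
    import java.util.List;

    import software.amazon.awssdk.services.kinesis.model.Shard;

    public class ShardEndCheckDemo {
        public static void main(String[] args) {
            String currentShardId = "shardId-000000000001"; // invented id
            List<Shard> latestShards = Arrays.asList(
                    Shard.builder().shardId("shardId-000000000003")
                            .parentShardId("shardId-000000000001").build(), // child of current shard
                    Shard.builder().shardId("shardId-000000000004")
                            .adjacentParentShardId("shardId-000000000002").build());

            // Mirrors isShardInContextParentOfAny: the shard is confirmed closed
            // only if some shard in the listing names it as (adjacent) parent.
            boolean confirmedClosed = latestShards.stream().anyMatch(s ->
                    currentShardId.equals(s.parentShardId())
                            || currentShardId.equals(s.adjacentParentShardId()));

            System.out.println(confirmedClosed
                    ? "SHARD_END confirmed: checkpoint at SHARD_END"
                    : "Not confirmed: drop the lease and shut down with LEASE_LOST");
        }
    }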
HierarchicalShardSyncerTest.java
@@ -171,6 +171,9 @@ public void testBootstrapShardLeasesAtLatest() throws Exception {
testCheckAndCreateLeasesForShardsIfMissing(INITIAL_POSITION_LATEST);
}

/**
* Test checkAndCreateLeaseForNewShards while not providing a pre-fetched list of shards
*/
@Test
public void testCheckAndCreateLeasesForShardsIfMissingAtLatest() throws Exception {
final List<Shard> shards = constructShardListForGraphA();
@@ -205,6 +208,74 @@ public void testCheckAndCreateLeasesForShardsIfMissingAtLatest() throws Exception {

}

/**
* Test checkAndCreateLeaseForNewShards with a pre-fetched list of shards. In this scenario, shardDetector.listShards()
* should never be called.
*/
@Test
public void testCheckAndCreateLeasesForShardsWithShardList() throws Exception {
final List<Shard> latestShards = constructShardListForGraphA();

final ArgumentCaptor<Lease> leaseCaptor = ArgumentCaptor.forClass(Lease.class);
when(shardDetector.listShards()).thenReturn(latestShards);
when(dynamoDBLeaseRefresher.listLeases()).thenReturn(Collections.emptyList());
when(dynamoDBLeaseRefresher.createLeaseIfNotExists(leaseCaptor.capture())).thenReturn(true);

hierarchicalShardSyncer
.checkAndCreateLeaseForNewShards(shardDetector, dynamoDBLeaseRefresher, INITIAL_POSITION_LATEST,
cleanupLeasesOfCompletedShards, false, SCOPE, latestShards);

final Set<String> expectedShardIds = new HashSet<>(
Arrays.asList("shardId-4", "shardId-8", "shardId-9", "shardId-10"));

final List<Lease> requestLeases = leaseCaptor.getAllValues();
final Set<String> requestLeaseKeys = requestLeases.stream().map(Lease::leaseKey).collect(Collectors.toSet());
final Set<ExtendedSequenceNumber> extendedSequenceNumbers = requestLeases.stream().map(Lease::checkpoint)
.collect(Collectors.toSet());

assertThat(requestLeases.size(), equalTo(expectedShardIds.size()));
assertThat(requestLeaseKeys, equalTo(expectedShardIds));
assertThat(extendedSequenceNumbers.size(), equalTo(1));

extendedSequenceNumbers.forEach(seq -> assertThat(seq, equalTo(ExtendedSequenceNumber.LATEST)));

verify(shardDetector, never()).listShards();
verify(dynamoDBLeaseRefresher, times(expectedShardIds.size())).createLeaseIfNotExists(any(Lease.class));
verify(dynamoDBLeaseRefresher, never()).deleteLease(any(Lease.class));
}

/**
* Test checkAndCreateLeaseForNewShards with an empty list of shards. In this scenario, shardDetector.listShards()
* should never be called.
*/
@Test
public void testCheckAndCreateLeasesForShardsWithEmptyShardList() throws Exception {
final List<Shard> shards = constructShardListForGraphA();

final ArgumentCaptor<Lease> leaseCaptor = ArgumentCaptor.forClass(Lease.class);
when(shardDetector.listShards()).thenReturn(shards);
when(dynamoDBLeaseRefresher.listLeases()).thenReturn(Collections.emptyList());
when(dynamoDBLeaseRefresher.createLeaseIfNotExists(leaseCaptor.capture())).thenReturn(true);

hierarchicalShardSyncer
.checkAndCreateLeaseForNewShards(shardDetector, dynamoDBLeaseRefresher, INITIAL_POSITION_LATEST,
cleanupLeasesOfCompletedShards, false, SCOPE, new ArrayList<Shard>());

final Set<String> expectedShardIds = new HashSet<>();

final List<Lease> requestLeases = leaseCaptor.getAllValues();
final Set<String> requestLeaseKeys = requestLeases.stream().map(Lease::leaseKey).collect(Collectors.toSet());
final Set<ExtendedSequenceNumber> extendedSequenceNumbers = requestLeases.stream().map(Lease::checkpoint)
.collect(Collectors.toSet());

assertThat(requestLeases.size(), equalTo(expectedShardIds.size()));
assertThat(extendedSequenceNumbers.size(), equalTo(0));

verify(shardDetector, never()).listShards();
verify(dynamoDBLeaseRefresher, never()).createLeaseIfNotExists(any(Lease.class));
verify(dynamoDBLeaseRefresher, never()).deleteLease(any(Lease.class));
}

@Test
public void testCheckAndCreateLeasesForNewShardsAtTrimHorizon() throws Exception {
testCheckAndCreateLeaseForShardsIfMissing(constructShardListForGraphA(), INITIAL_POSITION_TRIM_HORIZON);
@@ -1035,7 +1106,11 @@ public void testDetermineNewLeasesToCreateGraphBNoInitialLeasesAtTimestamp() {

/*
* Helper method to construct a shard list for graph A. Graph A is defined below. Shard structure (y-axis is
-epochs): 0 1 2 3 4 5- shards till epoch 102 \ / \ / | | 6 7 4 5- shards from epoch 103 - 205 \ / | /\ 8 4 9 10 -
+epochs):  0  1  2  3  4  5  - shards till epoch 102
+           \ /    \ /   |  |
+            6      7    4  5  - shards from epoch 103 - 205
+             \    /     |  /\
+              8         4  9 10 -
* shards from epoch 206 (open - no ending sequenceNumber)
*/
private List<Shard> constructShardListForGraphA() {
ConsumerStatesTest.java
@@ -43,6 +43,7 @@
import software.amazon.kinesis.checkpoint.ShardRecordProcessorCheckpointer;
import software.amazon.kinesis.common.InitialPositionInStream;
import software.amazon.kinesis.common.InitialPositionInStreamExtended;
import software.amazon.kinesis.leases.LeaseCoordinator;
import software.amazon.kinesis.leases.LeaseRefresher;
import software.amazon.kinesis.leases.ShardDetector;
import software.amazon.kinesis.leases.ShardInfo;
@@ -55,6 +56,8 @@
import software.amazon.kinesis.retrieval.AggregatorUtil;
import software.amazon.kinesis.retrieval.RecordsPublisher;


@RunWith(MockitoJUnitRunner.class)
public class ConsumerStatesTest {
private static final String STREAM_NAME = "TestStream";
@@ -73,6 +76,8 @@ public class ConsumerStatesTest {
@Mock
private ShardInfo shardInfo;
@Mock
private LeaseCoordinator leaseCoordinator;
@Mock
private LeaseRefresher leaseRefresher;
@Mock
private Checkpointer checkpointer;
@@ -109,7 +114,7 @@ public class ConsumerStatesTest {

@Before
public void setup() {
-argument = new ShardConsumerArgument(shardInfo, STREAM_NAME, leaseRefresher, executorService, recordsPublisher,
+argument = new ShardConsumerArgument(shardInfo, STREAM_NAME, leaseCoordinator, executorService, recordsPublisher,
shardRecordProcessor, checkpointer, recordProcessorCheckpointer, parentShardPollIntervalMillis,
taskBackoffTimeMillis, skipShardSyncAtWorkerInitializationIfLeasesExist, listShardsBackoffTimeInMillis,
maxListShardsRetryAttempts, shouldCallProcessRecordsEvenForEmptyRecordList, idleTimeInMillis,
@@ -127,6 +132,7 @@ public void setup() {
@Test
public void blockOnParentStateTest() {
ConsumerState state = ShardConsumerState.WAITING_ON_PARENT_SHARDS.consumerState();
when(leaseCoordinator.leaseRefresher()).thenReturn(leaseRefresher);

ConsumerTask task = state.createTask(argument, consumer, null);

@@ -309,7 +315,7 @@ public void shuttingDownStateTest() {
assertThat(task, shutdownTask(ShardRecordProcessorCheckpointer.class, "recordProcessorCheckpointer",
equalTo(recordProcessorCheckpointer)));
assertThat(task, shutdownTask(ShutdownReason.class, "reason", equalTo(reason)));
-assertThat(task, shutdownTask(LEASE_REFRESHER_CLASS, "leaseRefresher", equalTo(leaseRefresher)));
+assertThat(task, shutdownTask(LeaseCoordinator.class, "leaseCoordinator", equalTo(leaseCoordinator)));
assertThat(task, shutdownTask(InitialPositionInStreamExtended.class, "initialPositionInStream",
equalTo(initialPositionInStream)));
assertThat(task,
