Skip to content
This repository has been archived by the owner on Jul 1, 2022. It is now read-only.

Concurrency improvements to RemoteControlledSampler #609

Merged
merged 6 commits into from
Apr 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ public class RemoteControlledSampler implements Sampler {
private final int maxOperations = 2000;
private final SamplingManager manager;

@Getter(AccessLevel.PACKAGE)
private Sampler sampler;
// initialized in constructor and updated from a single (poll timer) thread
// volatile to guarantee immediate visibility of the updated sampler to other threads (remove if not a requirement)
@Getter(AccessLevel.PACKAGE) // visible for testing
private volatile Sampler sampler;

// most of the time, toString here is called from the JaegerTracer, which holds this as well
@ToString.Exclude private final String serviceName;

@ToString.Exclude private final Timer pollTimer;
@ToString.Exclude private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
@ToString.Exclude private final Metrics metrics;

private RemoteControlledSampler(Builder builder) {
Expand All @@ -67,22 +68,21 @@ private RemoteControlledSampler(Builder builder) {
new TimerTask() {
@Override
public void run() {
updateSampler();
try {
updateSampler();
} catch (Exception e) { // keep the timer thread alive
log.error("Failed to update sampler", e);
}
}
},
0,
builder.poolingIntervalMs);
return;
}

public ReentrantReadWriteLock getLock() {
return lock;
builder.pollingIntervalMs);
}

/**
* Updates {@link #sampler} to a new sampler when it is different.
*/
void updateSampler() {
void updateSampler() { // visible for testing
SamplingStrategyResponse response;
try {
response = manager.getSamplingStrategy(serviceName);
Expand Down Expand Up @@ -117,29 +117,27 @@ private void updateRateLimitingOrProbabilisticSampler(SamplingStrategyResponse r
return;
}

synchronized (this) {
if (!this.sampler.equals(sampler)) {
this.sampler = sampler;
metrics.samplerUpdated.inc(1);
}
if (!this.sampler.equals(sampler)) {
this.sampler = sampler;
metrics.samplerUpdated.inc(1);
}
}

private synchronized void updatePerOperationSampler(OperationSamplingParameters samplingParameters) {
if (sampler instanceof PerOperationSampler) {
if (((PerOperationSampler) sampler).update(samplingParameters)) {
private void updatePerOperationSampler(OperationSamplingParameters samplingParameters) {
Sampler currentSampler = sampler;
if (currentSampler instanceof PerOperationSampler) {
if (((PerOperationSampler) currentSampler).update(samplingParameters)) {
metrics.samplerUpdated.inc(1);
}
} else {
sampler = new PerOperationSampler(maxOperations, samplingParameters);
metrics.samplerUpdated.inc(1);
}
}

@Override
public SamplingStatus sample(String operation, long id) {
synchronized (this) {
return sampler.sample(operation, id);
}
return sampler.sample(operation, id);
}

@Override
Expand All @@ -149,32 +147,22 @@ public boolean equals(Object sampler) {
}
if (sampler instanceof RemoteControlledSampler) {
RemoteControlledSampler remoteSampler = ((RemoteControlledSampler) sampler);
synchronized (this) {
ReentrantReadWriteLock.ReadLock readLock = remoteSampler.getLock().readLock();
readLock.lock();
try {
return this.sampler.equals(remoteSampler.sampler);
} finally {
readLock.unlock();
}
}
return this.sampler.equals(remoteSampler.sampler);
}
return false;
}

@Override
public void close() {
synchronized (this) {
pollTimer.cancel();
}
pollTimer.cancel();
}

public static class Builder {
private final String serviceName;
private SamplingManager samplingManager;
private Sampler initialSampler;
private Metrics metrics;
private int poolingIntervalMs = DEFAULT_POLLING_INTERVAL_MS;
private int pollingIntervalMs = DEFAULT_POLLING_INTERVAL_MS;

public Builder(String serviceName) {
this.serviceName = serviceName;
Expand All @@ -196,7 +184,7 @@ public Builder withMetrics(Metrics metrics) {
}

public Builder withPollingInterval(int pollingIntervalMs) {
this.poolingIntervalMs = pollingIntervalMs;
this.pollingIntervalMs = pollingIntervalMs;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,42 @@
import io.jaegertracing.internal.clock.Clock;
import io.jaegertracing.internal.clock.SystemClock;

import java.util.concurrent.atomic.AtomicLong;

public class RateLimiter {
private final double creditsPerNanosecond;
private final Clock clock;
private double balance;
private double maxBalance;
private long lastTick;
private final double creditsPerNanosecond;
private final long maxBalance; // max balance in nano ticks
private final AtomicLong debit; // last op nano time less remaining balance

public RateLimiter(double creditsPerSecond, double maxBalance) {
this(creditsPerSecond, maxBalance, new SystemClock());
}

public RateLimiter(double creditsPerSecond, double maxBalance, Clock clock) {
this.clock = clock;
this.balance = maxBalance;
this.maxBalance = maxBalance;
this.creditsPerNanosecond = creditsPerSecond / 1.0e9;
this.maxBalance = (long) (maxBalance / creditsPerNanosecond);
this.debit = new AtomicLong(clock.currentNanoTicks() - this.maxBalance);
}

public boolean checkCredit(double itemCost) {
long currentTime = clock.currentNanoTicks();
double elapsedTime = currentTime - lastTick;
lastTick = currentTime;
balance += elapsedTime * creditsPerNanosecond;
if (balance > maxBalance) {
balance = maxBalance;
}
if (balance >= itemCost) {
balance -= itemCost;
return true;
}
return false;
long cost = (long) (itemCost / creditsPerNanosecond);
long credit;
long currentDebit;
long balance;
do {
currentDebit = debit.get();
credit = clock.currentNanoTicks();
balance = credit - currentDebit;
if (balance > maxBalance) {
balance = maxBalance;
}
balance -= cost;
if (balance < 0) {
return false;
}
} while (!debit.compareAndSet(currentDebit, credit - balance));
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -33,6 +34,9 @@
import io.jaegertracing.spi.SamplingManager;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -52,6 +56,7 @@ public class RemoteControlledSamplerTest {
@Before
public void setUp() throws Exception {
metrics = new Metrics(new InMemoryMetricsFactory());
// TODO this starts the timer with mocks not yet configured, causing NPEs; refactor to .build() from tests
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(samplingManager)
.withInitialSampler(initialSampler)
Expand Down Expand Up @@ -107,6 +112,7 @@ public void testUpdateToPerOperationSamplerReplacesProbabilisticSampler() throws

@Test
public void testUpdatePerOperationSamplerUpdatesExistingPerOperationSampler() throws Exception {
undertest.close();
PerOperationSampler perOperationSampler = mock(PerOperationSampler.class);
OperationSamplingParameters parameters = mock(OperationSamplingParameters.class);
when(samplingManager.getSamplingStrategy(SERVICE_NAME)).thenReturn(
Expand Down Expand Up @@ -138,6 +144,23 @@ public void testUnparseableResponse() throws Exception {
assertEquals(initialSampler, undertest.getSampler());
}

@Test
public void testUpdateFailureKeepsTimerRunning() throws InterruptedException {
undertest.close();
CountDownLatch latch = new CountDownLatch(3);
SamplingManager failingManager = serviceName -> {
latch.countDown();
throw new RuntimeException("test update failure");
};
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(failingManager)
.withInitialSampler(initialSampler)
.withMetrics(metrics)
.withPollingInterval(1)
.build();
assertTrue(latch.await(1, TimeUnit.SECONDS));
}

@Test
public void testSample() throws Exception {
undertest.sample("op", 1L);
Expand All @@ -160,6 +183,7 @@ public void testEquals() {

@Test
public void testDefaultProbabilisticSampler() {
undertest.close();
undertest = new RemoteControlledSampler.Builder(SERVICE_NAME)
.withSamplingManager(samplingManager)
.withInitialSampler(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@

package io.jaegertracing.internal.utils;

import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import io.jaegertracing.internal.clock.Clock;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Test;

public class RateLimiterTest {
RateLimiter limiter;

private static class MockClock implements Clock {

Expand Down Expand Up @@ -127,4 +131,59 @@ public void testRateLimiterMaxBalance() {
assertTrue(limiter.checkCredit(1.0));
assertFalse(limiter.checkCredit(1.0));
}

/**
* Validates rate limiter behavior with {@link System#nanoTime()}-like (non-zero) initial nano ticks.
*/
@Test
public void testRateLimiterInitial() {
MockClock clock = new MockClock();
clock.timeNanos = TimeUnit.MILLISECONDS.toNanos(-1_000_000);
RateLimiter limiter = new RateLimiter(1000, 100, clock);

assertTrue(limiter.checkCredit(100)); // consume initial (max) balance
assertFalse(limiter.checkCredit(1));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(49); // add 49 credits
assertFalse(limiter.checkCredit(50));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1); // add one credit
assertTrue(limiter.checkCredit(50)); // consume accrued balance
assertFalse(limiter.checkCredit(1));

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1_000_000); // add a lot of credits (max out balance)
assertTrue(limiter.checkCredit(1)); // take one credit

clock.timeNanos += TimeUnit.MILLISECONDS.toNanos(1_000_000); // add a lot of credits (max out balance)
assertFalse(limiter.checkCredit(101)); // can't consume more than max balance
assertTrue(limiter.checkCredit(100)); // consume max balance
assertFalse(limiter.checkCredit(1));
}

/**
* Validates concurrent credit check correctness.
*/
@Test
public void testRateLimiterConcurrency() {
int numWorkers = ForkJoinPool.getCommonPoolParallelism();
int creditsPerWorker = 1000;
MockClock clock = new MockClock();
RateLimiter limiter = new RateLimiter(1, numWorkers * creditsPerWorker, clock);

AtomicInteger count = new AtomicInteger();
for (int w = 0; w < numWorkers; ++w) {
ForkJoinPool.commonPool().execute(() -> {
for (int i = 0; i < creditsPerWorker * 2; ++i) {
if (limiter.checkCredit(1)) {
count.getAndIncrement(); // count allowed operations
}
}
});
}
ForkJoinPool.commonPool().awaitQuiescence(1, TimeUnit.SECONDS);

assertEquals("Exactly the allocated number of credits must be consumed", numWorkers * creditsPerWorker,count.get());
assertFalse(limiter.checkCredit(1));
}

}