Skip to content

Commit

Permalink
Fix Confidence Adjustment for Larger Shingle Sizes
Browse files Browse the repository at this point in the history
This PR addresses further adjustments to the confidence calculation issue discussed in PR 405. While PR 405 successfully resolved the issue for a shingle size of 4, it did not achieve the same results for larger shingle sizes like 8.

Key Changes
1. Refinement of seenValues Calculation:
* Previously, the formula increased confidence even as numImputed (number of imputations seen) increased because seenValues (all values seen) also increased.
* This PR fixes the issue by counting only non-imputed values as seenValues.
2. Upper Bound for numImputed:
* The numImputed is now upper bounded to the shingle size.
* The impute fraction calculation, which uses numberOfImputed * 1.0 / shingleSize, now ensures the fraction does not exceed 1.
3. Decrementing numberOfImputed:
* The numberOfImputed is decremented when there is no imputation.
* Previously, numberOfImputed remained unchanged when there is an imputation as there was both an increment and a decrement, keeping the imputation fraction constant. This PR ensures the imputation fraction accurately reflects the current state. This adjustment ensures that the forest update decision, which relies on the imputation fraction, functions correctly. The forest is updated only when the imputation fraction is below the threshold of 0.5.

Testing
* Added test scenarios with various shingle sizes to verify the changes.

Signed-off-by: Kaituo Li <kaituo@amazon.com>
  • Loading branch information
kaituo committed Jul 31, 2024
1 parent 07aab4a commit 691071c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
return point;
}

/**
* the timestamps are now used to calculate the number of imputed tuples in the
* shingle
*
* @param timestamp the timestamp of the current input
*/
@Override
protected void updateTimestamps(long timestamp) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the first
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp. However, there are scenarios where we might miss
* decrementing numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. The second condition <pre> timestamp >
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
* will decrement numberOfImputed when we move to a new timestamp, provided
* numberOfImputed is greater than zero.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
}

/**
* decides if the forest should be updated, this is needed for imputation on the
* fly. The main goal of this function is to avoid runaway sequences where a
Expand All @@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) {
*/
protected boolean updateAllowed() {
double fraction = numberOfImputed * 1.0 / (shingleSize);
if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
if (fraction > 1) {
fraction = 1;
}
if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
&& (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) {
// this shingle is disconnected from the previously seen values
// these transformations will have little meaning
Expand All @@ -144,6 +109,7 @@ protected boolean updateAllowed() {
// two different points).
return false;
}

dataQuality[0].update(1 - fraction);
return (fraction < useImputedFraction && internalTimeStamp >= shingleSize);
}
Expand All @@ -168,7 +134,9 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu
updateShingle(input, scaledInput);
updateTimestamps(timestamp);
if (isFullyImputed) {
numberOfImputed = numberOfImputed + 1;
numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
} else if (numberOfImputed > 0) {
numberOfImputed = numberOfImputed - 1;
}
if (changeForest) {
if (forest.isInternalShinglingEnabled()) {
Expand All @@ -190,7 +158,9 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi
return;
}
generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest);
++valuesSeen;
if (missing == null || missing.length != point.length) {
++valuesSeen;
}
}

protected double getTimeFactor(Deviation deviation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,8 @@ public void setLastScore(double[] score) {
}

void validateIgnore(double[] shift, int length) {
checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
checkArgument(shift.length == length,
() -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
for (double element : shift) {
checkArgument(element >= 0, "has to be non-negative");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,38 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;

public class MissingValueTest {
private static class EnumAndValueProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
.flatMap(method -> Stream.of(4, 8, 16) // Example shingle sizes
.map(shingleSize -> Arguments.of(method, shingleSize)));
}
}

@ParameterizedTest
@EnumSource(ImputationMethod.class)
public void testConfidence(ImputationMethod method) {
@ArgumentsSource(EnumAndValueProvider.class)
public void testConfidence(ImputationMethod method, int shingleSize) {
// Create and populate a random cut forest

int shingleSize = 4;
int numberOfTrees = 50;
int sampleSize = 256;
Precision precision = Precision.FLOAT_32;
Expand All @@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) {
long count = 0;

int dimensions = baseDimensions * shingleSize;
ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true)
.dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
.sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
.fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
.transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
.forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true);

if (method == ImputationMethod.FIXED_VALUES) {
// we cannot pass fillValues when the method is not fixed values. Otherwise, we
// will impute
// filled in values irregardless of imputation method
forestBuilder.fillValues(new double[] { 3 });
}

ThresholdedRandomCutForest forest = forestBuilder.build();

// Define the size and range
int size = 400;
Expand All @@ -75,18 +97,36 @@ public void testConfidence(ImputationMethod method) {
float[] rcfPoint = result.getRCFPoint();
double scale = result.getScale()[0];
double shift = result.getShift()[0];
double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift };
if (method == ImputationMethod.ZERO) {
assertEquals(0, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.FIXED_VALUES) {
assertEquals(3.0d, actual[0], 0.001d);
if (count == 300) {
assertTrue(result.getAnomalyGrade() > 0);
}
} else if (method == ImputationMethod.PREVIOUS) {
assertEquals(0, result.getAnomalyGrade(), 0.001d,
"count: " + count + " actual: " + Arrays.toString(actual));
}
} else {
AnomalyDescriptor result = forest.process(point, newStamp);
if ((count > 100 && count < 300) || count >= 326) {
// after 325, we have a period of confidence decreasing. After that, confidence
// starts increasing again.
int backupPoint = 325 + shingleSize * 3 / 4;
if ((count > 100 && count < 300) || count >= backupPoint) {
// The first 65+ observations gives 0 confidence.
// Confidence start increasing after 1 observed point
assertTrue(result.getDataConfidence() > lastConfidence);
assertTrue(result.getDataConfidence() > lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
} else if (count < backupPoint && count > 300) {
assertTrue(result.getDataConfidence() < lastConfidence,
String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
result.getDataConfidence(), lastConfidence));
}
lastConfidence = result.getDataConfidence();
}
Expand Down

0 comments on commit 691071c

Please sign in to comment.