Skip to content

Commit

Permalink
Merge pull request #405 from kaituo/missing
Browse files Browse the repository at this point in the history
Fix confidence adjustment when all input values are missing
  • Loading branch information
kaituo committed Jul 12, 2024
2 parents 7158799 + 7eadb49 commit 0859252
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,31 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
*/
@Override
protected void updateTimestamps(long timestamp) {
if (previousTimeStamps[0] == previousTimeStamps[1]) {
/*
* For imputations done on timestamps other than the current one (specified by
* the timestamp parameter), the timestamp of the imputed tuple matches that of
* the input tuple, and we increment numberOfImputed. For imputations done at
* the current timestamp (if all input values are missing), the timestamp of the
* imputed tuple is the current timestamp, and we increment numberOfImputed.
*
* To check if imputed values are still present in the shingle, we use the first
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
* because previousTimeStamps has a size equal to the shingle size and is filled
* with the current timestamp. However, there are scenarios where we might miss
* decrementing numberOfImputed:
*
* 1. Not all values in the shingle are imputed. 2. We accumulated
* numberOfImputed when the current timestamp had missing values.
*
* As a result, this could cause the data quality measure to decrease
* continuously since we are always counting missing values that should
* eventually be reset to zero. The second condition <pre> timestamp >
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
* will decrement numberOfImputed when we move to a new timestamp, provided
* numberOfImputed is greater than zero.
*/
if (previousTimeStamps[0] == previousTimeStamps[1]
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
numberOfImputed = numberOfImputed - 1;
}
super.updateTimestamps(timestamp);
Expand Down Expand Up @@ -333,7 +357,10 @@ protected float[] generateShingle(double[] inputTuple, long timestamp, int[] mis
}
}

updateForest(changeForest, input, timestamp, forest, false);
// The last argument (isFullyImputed) is true when every value in inputTuple is missing.
// This keeps dataQuality decreasing for as long as we impute at every consecutive timestamp.
updateForest(changeForest, input, timestamp, forest,
missingValues != null ? missingValues.length == inputTuple.length : false);
if (changeForest) {
updateTimeStampDeviations(timestamp, lastInputTimeStamp);
transformer.updateDeviation(input, savedInput, missingValues);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public Preprocessor toModel(PreprocessorState state, long seed) {
preprocessor.setPreviousTimeStamps(state.getPreviousTimeStamps());
preprocessor.setNormalizeTime(state.isNormalizeTime());
preprocessor.setFastForward(state.isFastForward());
preprocessor.setNumberOfImputed(state.getNumberOfImputed());
return preprocessor;
}

Expand Down Expand Up @@ -94,6 +95,7 @@ public PreprocessorState toState(Preprocessor model) {
state.setTimeStampDeviationStates(getStates(model.getTimeStampDeviations(), deviationMapper));
state.setDataQualityStates(getStates(model.getDataQuality(), deviationMapper));
state.setFastForward(model.isFastForward());
state.setNumberOfImputed(model.getNumberOfImputed());
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ public class PreprocessorState implements Serializable {
private DeviationState[] dataQualityStates;
private DeviationState[] timeStampDeviationStates;
private boolean fastForward;
private int numberOfImputed;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.randomcutforest.parkservices;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;

public class MissingValueTest {

    /**
     * Verifies that data confidence drops while input values are missing (and
     * therefore imputed) and recovers once real observations resume, for every
     * {@link ImputationMethod}. Also checks that ZERO / FIXED_VALUES imputation
     * writes the expected value into the most recent shingle slot.
     */
    @ParameterizedTest
    @EnumSource(ImputationMethod.class)
    public void testConfidence(ImputationMethod method) {
        // Forest configuration: one-dimensional input shingled four wide.
        int shingleSize = 4;
        int numberOfTrees = 50;
        int sampleSize = 256;
        Precision precision = Precision.FLOAT_32;
        int baseDimensions = 1;
        int dimensions = baseDimensions * shingleSize;

        ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
                .dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
                .sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
                .fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
                .transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();

        // Deterministic uniform data in [200, 240).
        int size = 400;
        double min = 200.0;
        double max = 240.0;
        List<Double> randomDoubles = generateUniformRandomDoubles(size, min, max);

        double lastConfidence = 0;
        for (int count = 0; count < randomDoubles.size(); count++) {
            double[] point = new double[] { randomDoubles.get(count) };
            long newStamp = 100L * count;
            if (count >= 300 && count < 325) {
                // Drop observations: feed NaN with every input position flagged missing.
                AnomalyDescriptor result = forest.process(new double[] { Double.NaN }, newStamp,
                        generateIntArray(point.length));
                if (count > 300) {
                    // Confidence starts decreasing after one missing point.
                    assertTrue(result.getDataConfidence() < lastConfidence, "count " + count);
                }
                lastConfidence = result.getDataConfidence();

                // Undo the normalization on the newest shingle entry and check
                // the imputed value for the deterministic imputation methods.
                float[] rcfPoint = result.getRCFPoint();
                double scale = result.getScale()[0];
                double shift = result.getShift()[0];
                double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
                if (method == ImputationMethod.ZERO) {
                    assertEquals(0, actual[0], 0.001d);
                } else if (method == ImputationMethod.FIXED_VALUES) {
                    assertEquals(3.0d, actual[0], 0.001d);
                }
            } else {
                AnomalyDescriptor result = forest.process(point, newStamp);
                if ((count > 100 && count < 300) || count >= 326) {
                    // The first 65+ observations yield 0 confidence; after that,
                    // confidence starts increasing after one observed point.
                    assertTrue(result.getDataConfidence() > lastConfidence);
                }
                lastConfidence = result.getDataConfidence();
            }
        }
    }

    /** Returns {0, 1, ..., size-1}, used to mark every input position as missing. */
    public static int[] generateIntArray(int size) {
        int[] indices = new int[size];
        int i = 0;
        while (i < size) {
            indices[i] = i;
            i++;
        }
        return indices;
    }

    /** Generates {@code size} doubles drawn uniformly from [min, max) with a fixed seed. */
    public static List<Double> generateUniformRandomDoubles(int size, double min, double max) {
        Random rng = new Random(0);
        List<Double> values = new ArrayList<>(size);
        while (values.size() < size) {
            values.add(min + (max - min) * rng.nextDouble());
        }
        return values;
    }
}

0 comments on commit 0859252

Please sign in to comment.