From 97f09dd60dbebff06a1624b069c3dd062e3566b3 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Fri, 5 Aug 2022 13:15:49 -0700 Subject: [PATCH 1/2] increase the default epochs to 1000 for linear regression Signed-off-by: Xun Zhang --- .../ml/engine/algorithms/regression/LinearRegression.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index 4a2b07aa90..42256355ff 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -58,7 +58,7 @@ public class LinearRegression implements Trainable, Predictable { //RMSProp private static final double DEFAULT_DECAY_RATE = 0.9; - private static final int DEFAULT_EPOCHS = 10; + private static final int DEFAULT_EPOCHS = 1000; private static final int DEFAULT_INTERVAL = -1; private static final int DEFAULT_BATCH_SIZE = 1; From 6a6974796a68dd942e4d933b90c2b6f71b46492e Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 8 Aug 2022 09:47:42 -0700 Subject: [PATCH 2/2] make more parameters configurable in linear regression Signed-off-by: Xun Zhang --- .../regression/LinearRegressionParams.java | 16 ++++++++++++++-- .../algorithms/regression/LinearRegression.java | 16 ++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java index b97cd3449b..a8962fc8ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java +++ 
b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java @@ -44,6 +44,7 @@ public class LinearRegressionParams implements MLAlgoParams { public static final String DECAY_RATE_FIELD = "decay_rate"; public static final String EPOCHS_FIELD = "epochs"; public static final String BATCH_SIZE_FIELD = "batch_size"; + public static final String LOGGING_INTERVAL_FIELD = "logging_interval"; public static final String SEED_FIELD = "seed"; public static final String TARGET_FIELD = "target"; @@ -58,11 +59,12 @@ public class LinearRegressionParams implements MLAlgoParams { private Double decayRate; private Integer epochs; private Integer batchSize; + private Integer loggingInterval; private Long seed; private String target; @Builder(toBuilder = true) - public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Long seed, String target) { + public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Integer loggingInterval, Long seed, String target) { this.objectiveType = objectiveType; this.optimizerType = optimizerType; this.learningRate = learningRate; @@ -74,6 +76,7 @@ public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimiz this.decayRate = decayRate; this.epochs = epochs; this.batchSize = batchSize; + this.loggingInterval = loggingInterval; this.seed = seed; this.target = target; } @@ -96,6 +99,7 @@ public LinearRegressionParams(StreamInput in) throws IOException { this.decayRate = in.readOptionalDouble(); this.epochs = in.readOptionalInt(); this.batchSize = in.readOptionalInt(); + this.loggingInterval = 
in.readOptionalInt(); this.seed = in.readOptionalLong(); this.target = in.readOptionalString(); } @@ -112,6 +116,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { Double decayRate = null; Integer epochs = null; Integer batchSize = null; + Integer loggingInterval = null; Long seed = null; String target = null; @@ -154,6 +159,9 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { case BATCH_SIZE_FIELD: batchSize = parser.intValue(false); break; + case LOGGING_INTERVAL_FIELD: + loggingInterval = parser.intValue(false); + break; case SEED_FIELD: seed = parser.longValue(false); break; @@ -165,7 +173,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { break; } } - return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, seed, target); + return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, loggingInterval, seed, target); } @Override @@ -201,6 +209,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(decayRate); out.writeOptionalInt(epochs); out.writeOptionalInt(batchSize); + out.writeOptionalInt(loggingInterval); out.writeOptionalLong(seed); out.writeOptionalString(target); } @@ -241,6 +250,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (batchSize != null) { builder.field(BATCH_SIZE_FIELD, batchSize); } + if (loggingInterval != null) { + builder.field(LOGGING_INTERVAL_FIELD, loggingInterval); + } if (seed != null) { builder.field(SEED_FIELD, seed); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index 42256355ff..db524e5b63 100644 --- 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -21,6 +21,7 @@ import org.opensearch.ml.engine.utils.TribuoUtil; import org.tribuo.MutableDataset; import org.tribuo.Prediction; +import org.tribuo.Trainer; import org.tribuo.math.StochasticGradientOptimiser; import org.tribuo.math.optimisers.AdaDelta; import org.tribuo.math.optimisers.AdaGrad; @@ -61,12 +62,15 @@ public class LinearRegression implements Trainable, Predictable { private static final int DEFAULT_EPOCHS = 1000; private static final int DEFAULT_INTERVAL = -1; private static final int DEFAULT_BATCH_SIZE = 1; + private static final Long DEFAULT_SEED = Trainer.DEFAULT_SEED; private LinearRegressionParams parameters; private StochasticGradientOptimiser optimiser; private RegressionObjective objective; - private long seed = System.currentTimeMillis(); + private int loggingInterval; + private int minibatchSize; + private long seed; public LinearRegression() {} @@ -177,6 +181,14 @@ private void validateParameters() { if (parameters.getBatchSize() != null && parameters.getBatchSize() < 0) { throw new IllegalArgumentException("MiniBatchSize should not be negative."); } + + if (parameters.getLoggingInterval() != null && parameters.getLoggingInterval() < -1) { + throw new IllegalArgumentException("Invalid Logging intervals"); + } + + loggingInterval = Optional.ofNullable(parameters.getLoggingInterval()).orElse(DEFAULT_INTERVAL); + minibatchSize = Optional.ofNullable(parameters.getBatchSize()).orElse(DEFAULT_BATCH_SIZE); + seed = Optional.ofNullable(parameters.getSeed()).orElse(DEFAULT_SEED); } @Override @@ -200,7 +212,7 @@ public Model train(DataFrame dataFrame) { MutableDataset trainDataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, 
parameters.getTarget()); Integer epochs = Optional.ofNullable(parameters.getEpochs()).orElse(DEFAULT_EPOCHS); - LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, DEFAULT_INTERVAL, DEFAULT_BATCH_SIZE, seed); + LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, loggingInterval, minibatchSize, seed); org.tribuo.Model regressionModel = linearSGDTrainer.train(trainDataset); Model model = new Model(); model.setName(FunctionName.LINEAR_REGRESSION.name());