Skip to content

Commit

Permalink
increase the default epochs to 1000 for linear regression (#394)
Browse files Browse the repository at this point in the history
* increase the default epochs to 1000 for linear regression

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* make more parameters configuration in linear regression

Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt authored Aug 8, 2022
1 parent 23cbbb4 commit b760180
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class LinearRegressionParams implements MLAlgoParams {
public static final String DECAY_RATE_FIELD = "decay_rate";
public static final String EPOCHS_FIELD = "epochs";
public static final String BATCH_SIZE_FIELD = "batch_size";
public static final String LOGGING_INTERVAL_FIELD = "logging_interval";
public static final String SEED_FIELD = "seed";
public static final String TARGET_FIELD = "target";

Expand All @@ -58,11 +59,12 @@ public class LinearRegressionParams implements MLAlgoParams {
private Double decayRate;
private Integer epochs;
private Integer batchSize;
private Integer loggingInterval;
private Long seed;
private String target;

@Builder(toBuilder = true)
public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Long seed, String target) {
public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Integer loggingInterval, Long seed, String target) {
this.objectiveType = objectiveType;
this.optimizerType = optimizerType;
this.learningRate = learningRate;
Expand All @@ -74,6 +76,7 @@ public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimiz
this.decayRate = decayRate;
this.epochs = epochs;
this.batchSize = batchSize;
this.loggingInterval = loggingInterval;
this.seed = seed;
this.target = target;
}
Expand All @@ -96,6 +99,7 @@ public LinearRegressionParams(StreamInput in) throws IOException {
this.decayRate = in.readOptionalDouble();
this.epochs = in.readOptionalInt();
this.batchSize = in.readOptionalInt();
this.loggingInterval = in.readOptionalInt();
this.seed = in.readOptionalLong();
this.target = in.readOptionalString();
}
Expand All @@ -112,6 +116,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
Double decayRate = null;
Integer epochs = null;
Integer batchSize = null;
Integer loggingInterval = null;
Long seed = null;
String target = null;

Expand Down Expand Up @@ -154,6 +159,9 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
case BATCH_SIZE_FIELD:
batchSize = parser.intValue(false);
break;
case LOGGING_INTERVAL_FIELD:
loggingInterval = parser.intValue(false);
break;
case SEED_FIELD:
seed = parser.longValue(false);
break;
Expand All @@ -165,7 +173,7 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
break;
}
}
return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, seed, target);
return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, loggingInterval, seed, target);
}

@Override
Expand Down Expand Up @@ -201,6 +209,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalDouble(decayRate);
out.writeOptionalInt(epochs);
out.writeOptionalInt(batchSize);
out.writeOptionalInt(loggingInterval);
out.writeOptionalLong(seed);
out.writeOptionalString(target);
}
Expand Down Expand Up @@ -241,6 +250,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (batchSize != null) {
builder.field(BATCH_SIZE_FIELD, batchSize);
}
if (loggingInterval != null) {
builder.field(LOGGING_INTERVAL_FIELD, loggingInterval);
}
if (seed != null) {
builder.field(SEED_FIELD, seed);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.ml.engine.utils.TribuoUtil;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaDelta;
import org.tribuo.math.optimisers.AdaGrad;
Expand Down Expand Up @@ -58,15 +59,18 @@ public class LinearRegression implements Trainable, Predictable {
//RMSProp
private static final double DEFAULT_DECAY_RATE = 0.9;

private static final int DEFAULT_EPOCHS = 10;
private static final int DEFAULT_EPOCHS = 1000;
private static final int DEFAULT_INTERVAL = -1;
private static final int DEFAULT_BATCH_SIZE = 1;
private static final Long DEFAULT_SEED = Trainer.DEFAULT_SEED;

private LinearRegressionParams parameters;
private StochasticGradientOptimiser optimiser;
private RegressionObjective objective;

private long seed = System.currentTimeMillis();
private int loggingInterval;
private int minibatchSize;
private long seed;

public LinearRegression() {}

Expand Down Expand Up @@ -177,6 +181,14 @@ private void validateParameters() {
if (parameters.getBatchSize() != null && parameters.getBatchSize() < 0) {
throw new IllegalArgumentException("MiniBatchSize should not be negative.");
}

if (parameters.getLoggingInterval() != null && parameters.getLoggingInterval() < -1) {
throw new IllegalArgumentException("Invalid Logging intervals");
}

loggingInterval = Optional.ofNullable(parameters.getLoggingInterval()).orElse(DEFAULT_INTERVAL);
minibatchSize = Optional.ofNullable(parameters.getBatchSize()).orElse(DEFAULT_BATCH_SIZE);
seed = Optional.ofNullable(parameters.getSeed()).orElse(DEFAULT_SEED);
}

@Override
Expand All @@ -200,7 +212,7 @@ public Model train(DataFrame dataFrame) {
MutableDataset<Regressor> trainDataset = TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(),
"Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, parameters.getTarget());
Integer epochs = Optional.ofNullable(parameters.getEpochs()).orElse(DEFAULT_EPOCHS);
LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, DEFAULT_INTERVAL, DEFAULT_BATCH_SIZE, seed);
LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, loggingInterval, minibatchSize, seed);
org.tribuo.Model<Regressor> regressionModel = linearSGDTrainer.train(trainDataset);
Model model = new Model();
model.setName(FunctionName.LINEAR_REGRESSION.name());
Expand Down

0 comments on commit b760180

Please sign in to comment.