Skip to content

Commit

Permalink
intermediate commit, implementing gradient wrt node height for smooth…
Browse files Browse the repository at this point in the history
… skygrid
  • Loading branch information
xji3 committed Jul 17, 2023
1 parent 385ca76 commit 8540f36
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions src/dr/evomodel/coalescent/smooth/SmoothSkygridLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ public SmoothSkygridLikelihood(String name,
this.tmpA = new double[trees.get(0).getNodeCount()];
this.tmpB = new double[trees.get(0).getNodeCount()];
this.tmpC = new double[trees.get(0).getNodeCount()];
this.tmpADerivOverS = new double[trees.get(0).getNodeCount()];
this.tmpBDerivOverS = new double[trees.get(0).getNodeCount()];
this.tmpCDerivOverS = new double[trees.get(0).getNodeCount()];
this.tmpD = new double[gridPointParameter.getDimension()];
this.tmpE = new double[gridPointParameter.getDimension()];
this.tmpF = new double[gridPointParameter.getDimension()];
Expand Down Expand Up @@ -270,8 +273,11 @@ private double getLineageCountDifference(int intervalIndex, BigFastTreeIntervals
}

private double[] tmpA;
private double[] tmpADerivOverS;
private double[] tmpB;
private double[] tmpBDerivOverS;
private double[] tmpC;
private double[] tmpCDerivOverS;
private double[] tmpD;
private double[] tmpE;
private double[] tmpF;
Expand Down Expand Up @@ -372,6 +378,49 @@ private void calculateTmpSums() {
}
}

private void calculateTmpSumDerivatives() {
if (!tmpSumsKnown) {
calculateTmpSums();
}

TreeModel tree = trees.get(0);
final double startTime = 0;
final double endTime = tree.getNodeHeight(tree.getRoot());
final int maxGridIndex = getMaxGridIndex(gridPointParameter, endTime);

for (int i = 0; i < uniqueTimes; i++) {
final double timeI = tmpTimes[i];
double sum = 0;
for (int j = 0; j < uniqueTimes; j++) {
if (j != i) {
final double timeJ = tmpTimes[j];
final double lineageCountEffect = tmpLineageEffect[j];
final double thisInverse = smoothFunction.getInverseOneMinusExponential(timeJ - timeI, smoothRate.getParameterValue(0));
sum += lineageCountEffect * thisInverse * (1 - thisInverse);
}
}
tmpADerivOverS[i] = - sum;
}

for (int i = 0; i < uniqueTimes; i++) {
final double timeI = tmpTimes[i];
double sum = 0;
for (int k = 0; k < maxGridIndex; k++) {
final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k));
final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1));
final double gridTime = gridPointParameter.getParameterValue(k);
final double thisInverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0));
sum += (nextPopSizeInverse - currentPopSizeInverse) * thisInverse * (1 - thisInverse);
}
tmpBDerivOverS[i] = -sum;
}

for (int i = 0; i < uniqueTimes; i++) {
final double timeI = tmpTimes[i];
tmpCDerivOverS[i] = smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0));
}
}

protected double calculateLogLikelihood() {
assert(trees.size() == 1);
if (!likelihoodKnown) {
Expand Down Expand Up @@ -406,6 +455,26 @@ protected double calculateLogLikelihood() {
return logLikelihood;
}

private double[] getGradientWrtNodeHeightNew() {
if (!likelihoodKnown) {
calculateLogLikelihood();
}
TreeModel tree = trees.get(0);
final double startTime = 0;
final double endTime = tree.getNodeHeight(tree.getRoot());
double[] gradient = new double[tree.getInternalNodeCount()];
getGradientWrtNodeHeightFromSingleIntegration(startTime, endTime, gradient);

double lineageEffectSquaredSum = 0;
for (int i = 0; i < uniqueTimes; i++) {
lineageEffectSquaredSum += tmpLineageEffect[i] * tmpLineageEffect[i];
}
getGradientWrtNodeHeightFromDoubleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient);

getGradientWrtNodeHeightFromTripleIntegration(startTime, endTime, getMaxGridIndex(gridPointParameter, endTime), gradient);
return gradient;
}

double getTripleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) {
double tripleIntegrationSum = 0;
for (int i = 0; i < uniqueTimes; i++) {
Expand Down Expand Up @@ -449,6 +518,42 @@ protected double calculateLogLikelihood() {
return tripleIntegrationSum + tripleWithQuadraticIntegrationSum;
}

private void getGradientWrtNodeHeightFromTripleIntegration(double startTime, double endTime, int maxGridIndex,
double[] gradient) {
for (int i = 0; i < uniqueTimes; i++) {
final double lineageCountEffect = tmpLineageEffect[i];
final double timeI = tmpTimes[i];
gradient[i] += lineageCountEffect * (tmpADerivOverS[i] * tmpB[i] * tmpC[i] + tmpA[i] * tmpBDerivOverS[i] * tmpC[i] + tmpA[i] * tmpB[i] * tmpCDerivOverS[i]);
for (int k = 0; k < maxGridIndex; k++) {
final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k));
final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1));
final double gridTime = gridPointParameter.getParameterValue(k);
final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0));

gradient[i] += (nextPopSizeInverse - currentPopSizeInverse) * tmpD[k] * (tmpE[k] - lineageCountEffect * tmpEInverse ) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect;
}


final double startTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0));
final double endTimeInverse = smoothFunction.getInverseOnePlusExponential(timeI - startTime, smoothRate.getParameterValue(0));
final double commonSecondTermMultiplier = startTimeInverse - endTimeInverse;
final double commonSecondTermMultiplierDerivativeOverS = - startTimeInverse * (1 - startTimeInverse) + endTimeInverse * (1 - endTimeInverse);

for (int k = 0; k < maxGridIndex; k++) {
final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k));
final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1));
final double gridTime = gridPointParameter.getParameterValue(k);
final double inverse = smoothFunction.getInverseOneMinusExponential(gridTime - timeI, smoothRate.getParameterValue(0));
final double inverseDerivativeOverS = -inverse * (1 - inverse);
gradient[i] += (nextPopSizeInverse - currentPopSizeInverse)
* (inverseDerivativeOverS * commonSecondTermMultiplier + inverse * commonSecondTermMultiplierDerivativeOverS +
2 * (1 - inverse) * inverseDerivativeOverS * tmpC[i] + (2.0 - inverse) * inverse * tmpCDerivOverS[i] +
2 * (1 - inverse) * (-inverseDerivativeOverS) * tmpD[k]);
}
}

}

double getDoubleIntegration(double startTime, double endTime, int maxGridIndex, double lineageEffectSquaredSum) {
double firstDoubleIntegrationOffDiagonalSum = 0;
double firstDoubleIntegrationDiagonalSum = 0;
Expand Down Expand Up @@ -482,6 +587,34 @@ protected double calculateLogLikelihood() {

}

private void getGradientWrtNodeHeightFromDoubleIntegration(double startTime, double endTime, int maxGridIndex,
double[] gradient) {
final double firstPopSize = Math.exp(-logPopSizeParameter.getParameterValue(0));
for (int i = 0; i < uniqueTimes; i++) {
final double lineageCountEffect = tmpLineageEffect[i];
final double timeI = tmpTimes[i];
//firstDoubleIntegrationOffDiagonalSum += lineageCountEffect * tmpA[i] * tmpC[i];
gradient[i] += -lineageCountEffect * (tmpA[i] * tmpCDerivOverS[i] + tmpADerivOverS[i] * tmpC[i]) * firstPopSize;

//firstDoubleIntegrationDiagonalSum
gradient[i] += lineageCountEffect * lineageCountEffect
* (smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0))
+ (smoothFunction.getDerivative(timeI, endTime, 0, 1, smoothRate.getParameterValue(0))
- smoothFunction.getDerivative(timeI, startTime, 0, 1, smoothRate.getParameterValue(0)) / smoothRate.getParameterValue(0))
) * -0.5 * firstPopSize;

gradient[i] += 0.5 * tmpLineageEffect[i] * (tmpB[i] * tmpCDerivOverS[i] + tmpBDerivOverS[i] * tmpC[i]);

for (int k = 0; k < maxGridIndex; k++) {
final double currentPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k));
final double nextPopSizeInverse = Math.exp(-logPopSizeParameter.getParameterValue(k + 1));
final double gridTime = gridPointParameter.getParameterValue(k);
final double tmpEInverse = smoothFunction.getInverseOneMinusExponential(timeI - gridTime, smoothRate.getParameterValue(0));
gradient[i] += 0.5 * tmpD[k] * (nextPopSizeInverse - currentPopSizeInverse) * tmpEInverse * (1 - tmpEInverse) * lineageCountEffect;
}
}
}

private double getSingleIntegration(double startTime, double endTime) {
double singleIntegration = 0;
for (int i = 0; i < uniqueTimes; i++) {
Expand All @@ -493,6 +626,15 @@ private double getSingleIntegration(double startTime, double endTime) {
return singleIntegration;
}

private void getGradientWrtNodeHeightFromSingleIntegration(double startTime, double endTime, double[] gradient) {
for (int i = 0; i < uniqueTimes; i++) {
final double timeI = tmpTimes[i];
final double lineageCountEffectI = tmpLineageEffect[i];
gradient[i] += lineageCountEffectI * smoothFunction.getSingleIntegrationDerivative(startTime, endTime, timeI, smoothRate.getParameterValue(0))
* 0.5 * Math.exp(-logPopSizeParameter.getParameterValue(0));
}
}

protected void handleModelChangedEvent(Model model, Object object, int index) {
super.handleModelChangedEvent(model, object, index);
tmpSumsKnown = false;
Expand Down

0 comments on commit 8540f36

Please sign in to comment.