Energy forces #278

Merged: 37 commits, Sep 20, 2024
Changes shown are from 8 commits.

Commits (37)
7b56afd
Energy and Force Prediction changes (loss function in base, and optio…
RylieWeaver Sep 6, 2024
a277ee5
comments and renamings
RylieWeaver Sep 6, 2024
fd248b5
Black formatting and make computational graph fixes
RylieWeaver Sep 6, 2024
c3eaedc
fix loss weighting
RylieWeaver Sep 6, 2024
5136b89
black formatting
RylieWeaver Sep 6, 2024
65f6f1d
black formatting
RylieWeaver Sep 6, 2024
3395484
black formatting
RylieWeaver Sep 6, 2024
c5d0c7b
Fix DIMEStack testing issues, and compute_grad_energy default in conf…
RylieWeaver Sep 6, 2024
fc5a554
LJ example added first draft
RylieWeaver Sep 16, 2024
0bce3d5
formatting
RylieWeaver Sep 16, 2024
271e36a
formatting
RylieWeaver Sep 16, 2024
9de9677
formatting
RylieWeaver Sep 16, 2024
a904def
formatting
RylieWeaver Sep 16, 2024
12c042e
formatting
RylieWeaver Sep 16, 2024
0f3fe54
formatting
RylieWeaver Sep 16, 2024
4df57c7
take out images and adjust unnecessary import
RylieWeaver Sep 16, 2024
4779309
revert to SiLU in DimeNet
RylieWeaver Sep 16, 2024
456ebf7
don't create plots by default
RylieWeaver Sep 16, 2024
20f2e41
file cleanup and dataset test
RylieWeaver Sep 17, 2024
2da742e
unnecessary import
RylieWeaver Sep 17, 2024
70206dc
file cleanup
RylieWeaver Sep 17, 2024
7032acd
add some tests back in
RylieWeaver Sep 17, 2024
f98fb3e
Restore all tests
RylieWeaver Sep 17, 2024
d66de16
formatting
RylieWeaver Sep 17, 2024
206b52f
check dataset things
RylieWeaver Sep 17, 2024
9980834
formatting
RylieWeaver Sep 17, 2024
fa10c20
Revise paths to be more succinct and Use radius from config
RylieWeaver Sep 17, 2024
42551b8
file restructuring and using hydra radius graph function
RylieWeaver Sep 17, 2024
ee20cf7
formatting
RylieWeaver Sep 17, 2024
1fa7766
renaming
RylieWeaver Sep 17, 2024
e8a4a4d
remove qm9 test
RylieWeaver Sep 17, 2024
5d358a3
remove qm9 test
RylieWeaver Sep 17, 2024
06dad51
smaller number samples for test
RylieWeaver Sep 17, 2024
393976c
Update examples/LennardJones/LJ_inference_plots.py
allaffa Sep 17, 2024
85f44ab
use info function
RylieWeaver Sep 17, 2024
81992af
Unnecessary __init__ file
RylieWeaver Sep 17, 2024
bc93b9b
Unnecessary json args
RylieWeaver Sep 17, 2024
hydragnn/models/Base.py (55 changes: 55 additions, 0 deletions)

@@ -15,6 +15,7 @@
from torch_geometric.nn import global_mean_pool, BatchNorm
from torch.nn import GaussianNLLLoss
from torch.utils.checkpoint import checkpoint
import torch_scatter
from hydragnn.utils.model import activation_function_selection, loss_function_selection
import sys
from hydragnn.utils.distributed import get_device
@@ -355,6 +356,60 @@ def loss(self, pred, value, head_index):
elif self.ilossweights_hyperp == 1:
return self.loss_hpweighted(pred, value, head_index, var=var)

def energy_force_loss(self, pred, data):
# Asserts
assert (
data.pos is not None and data.energy is not None and data.forces is not None
), "data.pos, data.energy, data.forces must be provided for energy-force loss. Check your dataset creation and naming."
assert (
data.pos.requires_grad
), "data.pos does not have grad, so force predictions cannot be computed. Check that data.pos has grad set to true before prediction."
assert (
self.num_heads == 1 and self.head_type[0] == "node"
), "Force predictions are only supported for models with one head that predict nodal energy. Check your num_heads and head_types."
# Initialize loss
tot_loss = 0
tasks_loss = []
# Energies
node_energy_pred = pred[0]
graph_energy_pred = torch_scatter.scatter_add(
node_energy_pred, data.batch, dim=0
).float()
graph_energy_true = data.energy
energy_loss_weight = self.loss_weights[
0
] # There should only be one loss-weight for energy
tot_loss += (
self.loss_function(graph_energy_pred, graph_energy_true)
* energy_loss_weight
)
tasks_loss.append(self.loss_function(graph_energy_pred, graph_energy_true))
# Forces
forces_true = data.forces.float()
forces_pred = torch.autograd.grad(
graph_energy_pred,
data.pos,
grad_outputs=torch.ones_like(graph_energy_pred),
retain_graph=graph_energy_pred.requires_grad, # Retain graph only if needed (it will be needed during training, but not during validation/testing)
create_graph=True,
)[0].float()
Comment on lines +389 to +395

Collaborator:

Why is grad_outputs used here? This would lead to a sum of the gradients, right?

RylieWeaver (Collaborator, PR author), Sep 19, 2024:
I don't think so.

Here's my understanding of it:
The grad_outputs argument essentially serves as a weight to multiply the gradients. This is necessary since graph_energy_pred is a higher dimension than just a scalar. The torch.ones_like() here will assign a weight of 1 to each gradient, which I believe is what we want. Otherwise, there would be a force prediction which is c*grad_E where c is a constant not equal to 1, which is nonphysical.

Specifically to your question:
No, I don't believe that it results in a sum. The output shape of forces_pred will be in the same shape as data.pos. The grad_outputs specifically is for weighting the gradients before a sum, but that extra step of summing is outside of autograd calculation I believe.

Does this answer your question?

@pzhanggit

Collaborator:

> [quotes RylieWeaver's reply above in full]

Thanks. So just to verify if my understanding is correct: grad_outputs is needed because we're calculating gradients of a batch of samples here. And since grad_outputs should be a sequence of length matching output containing the “vector” in vector-Jacobian product, it provides a way to aggregate/sum the gradients across all the samples in the batch. This is equivalent to calculate the gradients iteratively for each sample, since the cross-gradients between samples would be zero.

RylieWeaver (Collaborator, PR author), Sep 19, 2024:
@pzhanggit

> grad_outputs is needed because we're calculating gradients of a batch of samples here

Yep, that's right.

> And since grad_outputs should be a sequence of length matching output containing the “vector” in vector-Jacobian product

I think so. Yes, grad_outputs should be a sequence, and it should be the same shape as whatever you're predicting. It then multiplies the matrix of gradients (row-wise, not normal matrix multiplication), which I think is the vector-Jacobian product you're referring to.

> it provides a way to aggregate/sum the gradients across all the samples in the batch.

I think yes. It's scaling those gradients, which would be relevant in an aggregation/sum. Although, it does not do that aggregation/sum.

> This is equivalent to calculate the gradients iteratively for each sample, since the cross-gradients between samples would be zero.

Yep.

Are there any parts still unclear?

Collaborator:

> It's scaling those gradients, which would be relevant in an aggregation/sum.

Scaling those gradients?

allaffa (Collaborator), Sep 19, 2024:
@pzhanggit
The scaling is done with torch.ones_like(graph_energy_pred)
In our case, we use vectors of ones because we do NOT want to scale. However, you could apply any multiplying factor (or even provide a customized vector with different values for each entry).

RylieWeaver (Collaborator, PR author):

@pzhanggit @allaffa ^ on what Max said. Also, feel free to @ me so I reply faster :)

Collaborator:
@pzhanggit @RylieWeaver

I personally found the explanation given by ChatGPT very useful. I attach it here hoping that you will find it useful too.

[Six screenshots of the ChatGPT explanation attached, taken 2024-09-19 between 5:46 and 5:47 PM.]

RylieWeaver (Collaborator, PR author):
Yep. Wow, the o1-preview is great

Collaborator:
thanks.
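
To make the batching question above concrete, here is a minimal, self-contained sketch (illustrative only, not part of this PR's diff). It uses a toy quadratic energy and torch.Tensor.index_add in place of torch_scatter.scatter_add, and shows that grad_outputs=torch.ones_like(graph_energy) yields per-atom gradients with the same shape as pos, and that the single batched call equals the per-graph gradients summed, since cross-graph terms are zero.

import torch

# Two toy "graphs" of 3 atoms each, like one batch from a PyG DataLoader.
pos = torch.randn(6, 3, requires_grad=True)   # atom positions
batch = torch.tensor([0, 0, 0, 1, 1, 1])      # graph id of each atom

node_energy = (pos ** 2).sum(dim=1)           # toy per-node energy
# Pool node energies per graph (index_add stands in for torch_scatter.scatter_add).
graph_energy = torch.zeros(2).index_add(0, batch, node_energy)

# Batched force evaluation: one call, ones as the vector in the vector-Jacobian product.
forces = -torch.autograd.grad(
    graph_energy,
    pos,
    grad_outputs=torch.ones_like(graph_energy),
    retain_graph=True,
)[0]
print(forces.shape)  # torch.Size([6, 3]): per-atom forces, not a summed scalar

# Same answer one graph at a time, because cross-gradients between graphs are zero.
f0 = -torch.autograd.grad(graph_energy[0], pos, retain_graph=True)[0]
f1 = -torch.autograd.grad(graph_energy[1], pos)[0]
assert torch.allclose(forces, f0 + f1)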

assert (
forces_pred is not None
), "No gradients were found for data.pos. Does your model use positions for prediction?"
forces_pred = -forces_pred
force_loss_weight = (
energy_loss_weight
* torch.mean(torch.abs(graph_energy_true))
/ (torch.mean(torch.abs(forces_true)) + 1e-8)
) # Weight force loss and graph energy equally
tot_loss += (
self.loss_function(forces_pred, forces_true) * force_loss_weight
) # Have force-weight be the complement to energy-weight
## FixMe: current loss functions require the number of heads to be the number of things being predicted
## so, we need to do loss calculation manually without calling the other functions.

return tot_loss, tasks_loss

def loss_nll(self, pred, value, head_index, var=None):
# negative log likelihood loss
# uncertainty to weigh losses in https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf
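As a quick numerical illustration of the force_loss_weight formula in energy_force_loss above (toy numbers, not from the PR): with an energy loss weight of 1.0, graph energies of magnitude about 10, and forces of magnitude about 2, the force term is scaled up by a factor of about 5 so both terms contribute on the same scale.

import torch

graph_energy_true = torch.tensor([9.0, 11.0])   # toy graph energies, mean |E| = 10
forces_true = torch.full((6, 3), 2.0)            # toy forces, mean |F| = 2

energy_loss_weight = 1.0
force_loss_weight = (
    energy_loss_weight
    * torch.mean(torch.abs(graph_energy_true))
    / (torch.mean(torch.abs(forces_true)) + 1e-8)
)
print(force_loss_weight)  # ~5.0: force errors are rescaled to the energy scale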
hydragnn/models/DIMEStack.py (4 changes: 2 additions, 2 deletions)

@@ -14,7 +14,7 @@

import torch
from torch import Tensor
from torch.nn import Identity, SiLU
from torch.nn import Identity, SiLU, Sigmoid

from torch_geometric.nn import Linear, Sequential
from torch_geometric.nn.models.dimenet import (
@@ -101,7 +101,7 @@ def get_conv(self, input_dim, output_dim):
out_emb_channels=self.out_emb_size,
out_channels=output_dim,
num_layers=1,
act=SiLU(),
act=Sigmoid(), # Sigmoid instead of SiLU here promotes stability when we have a linear decoder at the start of convolution (especially for random data examples in test_graphs.py)
allaffa marked this conversation as resolved.
output_initializer="glorot_orthogonal",
)
return Sequential(
hydragnn/run_training.py (1 change: 1 addition, 0 deletions)

@@ -175,6 +175,7 @@ def _(config: dict, use_deepspeed=False):
plot_hist_solution,
create_plots,
use_deepspeed=use_deepspeed,
compute_grad_energy=config["NeuralNetwork"]["Training"]["compute_grad_energy"],
)

save_model(model, optimizer, log_name, use_deepspeed=use_deepspeed)
hydragnn/train/train_validate_test.py (59 changes: 49 additions, 10 deletions)

@@ -66,6 +66,7 @@ def train_validate_test(
plot_hist_solution=False,
create_plots=False,
use_deepspeed=False,
compute_grad_energy=False,
):
num_epoch = config["Training"]["num_epoch"]
EarlyStop = (
@@ -162,6 +163,7 @@
verbosity,
profiler=prof,
use_deepspeed=use_deepspeed,
compute_grad_energy=compute_grad_energy,
)
tr.stop("train")
tr.disable()
@@ -172,14 +174,19 @@
continue

val_loss, val_taskserr = validate(
val_loader, model, verbosity, reduce_ranks=True
val_loader,
model,
verbosity,
reduce_ranks=True,
compute_grad_energy=compute_grad_energy,
)
test_loss, test_taskserr, true_values, predicted_values = test(
test_loader,
model,
verbosity,
reduce_ranks=True,
return_samples=plot_hist_solution,
compute_grad_energy=compute_grad_energy,
)
scheduler.step(val_loss)
if writer is not None:
@@ -434,7 +441,15 @@ def gather_tensor_ranks(head_values):
return head_values


def train(loader, model, opt, verbosity, profiler=None, use_deepspeed=False):
def train(
loader,
model,
opt,
verbosity,
profiler=None,
use_deepspeed=False,
compute_grad_energy=False,
):
if profiler is None:
profiler = Profiler()

@@ -492,8 +507,13 @@ def train(loader, model, opt, verbosity, profiler=None, use_deepspeed=False):
data = data.to(get_device())
if trace_level > 0:
tr.stop("h2d", **syncopt)
pred = model(data)
loss, tasks_loss = model.module.loss(pred, data.y, head_index)
if compute_grad_energy: # for force and energy prediction
data.pos.requires_grad = True
pred = model(data)
loss, tasks_loss = model.module.energy_force_loss(pred, data)
else:
pred = model(data)
loss, tasks_loss = model.module.loss(pred, data.y, head_index)
if trace_level > 0:
tr.start("forward_sync", **syncopt)
MPI.COMM_WORLD.Barrier()
@@ -541,7 +561,7 @@


@torch.no_grad()
def validate(loader, model, verbosity, reduce_ranks=True):
def validate(loader, model, verbosity, reduce_ranks=True, compute_grad_energy=False):

total_error = torch.tensor(0.0, device=get_device())
tasks_error = torch.zeros(model.module.num_heads, device=get_device())
@@ -565,8 +585,14 @@ def validate(loader, model, verbosity, reduce_ranks=True):
loader.dataset.ddstore.epoch_end()
head_index = get_head_indices(model, data)
data = data.to(get_device())
pred = model(data)
error, tasks_loss = model.module.loss(pred, data.y, head_index)
if compute_grad_energy: # for force and energy prediction
with torch.enable_grad():
data.pos.requires_grad = True
pred = model(data)
error, tasks_loss = model.module.energy_force_loss(pred, data)
else:
pred = model(data)
error, tasks_loss = model.module.loss(pred, data.y, head_index)
total_error += error * data.num_graphs
num_samples_local += data.num_graphs
for itask in range(len(tasks_loss)):
@@ -585,7 +611,14 @@


@torch.no_grad()
def test(loader, model, verbosity, reduce_ranks=True, return_samples=True):
def test(
loader,
model,
verbosity,
reduce_ranks=True,
return_samples=True,
compute_grad_energy=False,
):

total_error = torch.tensor(0.0, device=get_device())
tasks_error = torch.zeros(model.module.num_heads, device=get_device())
@@ -612,8 +645,14 @@ def test(loader, model, verbosity, reduce_ranks=True, return_samples=True):
loader.dataset.ddstore.epoch_end()
head_index = get_head_indices(model, data)
data = data.to(get_device())
pred = model(data)
error, tasks_loss = model.module.loss(pred, data.y, head_index)
if compute_grad_energy: # for force and energy prediction
with torch.enable_grad():
data.pos.requires_grad = True
pred = model(data)
error, tasks_loss = model.module.energy_force_loss(pred, data)
else:
pred = model(data)
error, tasks_loss = model.module.loss(pred, data.y, head_index)
## FIXME: temporary
if int(os.getenv("HYDRAGNN_DUMP_TESTDATA", "0")) == 1:
if model.module.var_output:
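The validate() and test() hunks above keep their @torch.no_grad() decorators but wrap the forward pass in torch.enable_grad() and set data.pos.requires_grad = True, so forces can still be taken as -dE/dpos at evaluation time. Here is a minimal sketch of that pattern, with a toy energy standing in for the model (illustrative only, not HydraGNN API):

import torch

def toy_energy(pos):
    # Stand-in for a model's nodal energy head pooled over the graph.
    return (pos ** 2).sum()

@torch.no_grad()  # evaluation loop: no gradients kept for model parameters
def evaluate(pos_batch):
    force_rmse = []
    for pos in pos_batch:
        with torch.enable_grad():              # re-enable autograd locally
            pos = pos.clone().requires_grad_(True)
            energy = toy_energy(pos)
            forces = -torch.autograd.grad(energy, pos)[0]
        force_rmse.append(forces.pow(2).mean().sqrt().item())
    return sum(force_rmse) / len(force_rmse)

print(evaluate([torch.randn(3, 3), torch.randn(4, 3)]))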
hydragnn/utils/config_utils.py (11 changes: 8 additions, 3 deletions)

@@ -106,6 +106,9 @@ def update_config(config, train_loader, val_loader, test_loader):

if "conv_checkpointing" not in config["NeuralNetwork"]["Training"]:
config["NeuralNetwork"]["Training"]["conv_checkpointing"] = False

if "compute_grad_energy" not in config["NeuralNetwork"]["Training"]:
config["NeuralNetwork"]["Training"]["compute_grad_energy"] = False
return config


@@ -260,9 +263,11 @@ def get_log_name_config(config):
+ str(config["NeuralNetwork"]["Training"]["batch_size"])
+ "-data-"
+ config["Dataset"]["name"][
: config["Dataset"]["name"].rfind("_")
if config["Dataset"]["name"].rfind("_") > 0
else None
: (
config["Dataset"]["name"].rfind("_")
if config["Dataset"]["name"].rfind("_") > 0
else None
)
]
+ "-node_ft-"
+ "".join(
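With the default added to update_config above, gradient-based force training stays opt-in. A user would flip the flag in the Training section of the JSON config; a hedged sketch follows (the config filename is hypothetical, while the key path comes from this PR):

import json

# Hypothetical config path; any HydraGNN-style JSON config with a
# NeuralNetwork/Training section would do.
with open("examples/LennardJones/LJ.json") as f:
    config = json.load(f)

config["NeuralNetwork"]["Training"]["compute_grad_energy"] = True  # defaults to False otherwise

with open("examples/LennardJones/LJ.json", "w") as f:
    json.dump(config, f, indent=2)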